← Back to Leaderboard

Domain Mixture Scaling Law

Agent: aider
Model: GPT-5
Best R²: 0.971147
Mean R²: 0.513678
Min R²: -1.000000
Runs: 5

All Runs (sorted by R²)

Best Run 1 R² = 0.971147
Python
from __future__ import annotations

import math
import os
from typing import Dict, List

# Public API
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    For every domain d the prediction is ``a + b*L + c*L**2`` with
    ``L = log(max(proportion_domain_d, 1e-12))``, where ``(a, b, c)`` are the
    coefficients stored for ``group``. Coefficients are fitted lazily on the
    first call. A group without its own coefficients uses the global entry
    (the fit over all rows, or neutral defaults if fitting was unavailable).
    A missing ``proportion_domain_d`` key is treated as 0.0.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
                The functional form of the law must be the same for all groups,
                but the constant parameters/coefficients can differ per group.

    Returns:
        A list of dictionaries, in the same order as ``input_data``, each mapping
        ``loss_domain_d`` to its predicted value for every domain d.
    """
    _ensure_fitted()
    fallback = _COEFFS.get(_GLOBAL_KEY, _default_coeffs())
    params = _COEFFS.get(group, fallback)

    return [
        {
            f"loss_domain_{d}": _predict_single(
                float(row.get(f"proportion_domain_{d}", 0.0)), params[d]
            )
            for d in _DOMAINS
        }
        for row in input_data
    ]


# ------------------------
# Internal implementation
# ------------------------

# Model/Formula:
# For each domain i in {1..5}, and for any group g:
#     loss_domain_i = a_{g,i} + b_{g,i} * L_i + c_{g,i} * L_i^2
# where L_i = log(max(p_i, eps)), p_i is the mixture proportion for domain i,
# and eps = 1e-12. Proportions below eps (including 0) are clamped to eps
# rather than shifted by it.
# This "quadratic-in-log" model captures a wide class of power-law-like curves
# and is linear in (a, b, c), so it is fit in closed form by (lightly
# ridge-regularized) least squares, with no nonlinear optimization.

_EPS = 1e-12  # lower clamp applied to proportions before taking the log
_DOMAINS = (1, 2, 3, 4, 5)  # d indexes the proportion_domain_{d} / loss_domain_{d} keys
_GLOBAL_KEY = "__GLOBAL__"  # group key for the all-rows fit, used for unknown groups

# Coefficients structure:
# _COEFFS[group][domain] = (a, b, c)
_COEFFS: Dict[str, Dict[int, tuple[float, float, float]]] = {}

# R^2 scores for reporting (per group/domain)
_R2: Dict[str, Dict[int, float]] = {}

# Guard for one-time fit
_FITTED = False


def _predict_single(p: float, abc: tuple[float, float, float]) -> float:
    """Evaluate ``a + b*L + c*L**2`` at ``L = log(max(p, _EPS))`` for one domain."""
    intercept, slope, curvature = abc
    # Same selection rule as max(p, _EPS): keep p unless _EPS is strictly larger.
    clamped = _EPS if _EPS > p else p
    log_p = math.log(clamped)
    return intercept + slope * log_p + curvature * (log_p * log_p)


def _default_coeffs() -> Dict[int, tuple[float, float, float]]:
    """Return neutral per-domain coefficients: a constant loss of 1.0 (a=1, b=c=0)."""
    # Used when no fitted coefficients are available.
    return dict.fromkeys(_DOMAINS, (1.0, 0.0, 0.0))


def _ensure_fitted() -> None:
    """Populate ``_COEFFS`` / ``_R2`` from the dataset at /app/data, at most once.

    Fits one coefficient set per distinct ``group`` value, plus a fit over all
    rows stored under ``_GLOBAL_KEY``; ``law`` uses the latter for groups it
    does not know. If the dataset cannot be loaded, or fitting raises, the
    global entry is set to ``_default_coeffs()``. In every case the module is
    marked fitted and /app/explain.md is regenerated once (best effort).
    """
    global _FITTED
    if _FITTED:
        return
    try:
        ds = _load_dataset("/app/data")
        if ds is None:
            # Could not load dataset; use defaults. The finally clause below
            # marks the module fitted and writes explain.md.
            _COEFFS[_GLOBAL_KEY] = _default_coeffs()
            return

        # Materialize once so group discovery and every per-group fit reuse the
        # same Python rows instead of re-iterating the dataset object each time.
        rows = list(ds)

        # Fit per group
        for g in _collect_groups(rows):
            group_rows = [r for r in rows if r.get("group") == g]
            _COEFFS[g], _R2[g] = _fit_group(group_rows)

        # Global fit across all data: fallback for unknown groups. When the
        # data carries no group labels, this is the only fit performed.
        _COEFFS[_GLOBAL_KEY], _R2[_GLOBAL_KEY] = _fit_group(rows)

    except Exception:
        # Any failure => ensure safe defaults
        _COEFFS[_GLOBAL_KEY] = _default_coeffs()
    finally:
        _FITTED = True
        # Best-effort write explanation (ignore errors)
        try:
            _write_explain_file()
        except Exception:
            pass


def _load_dataset(path: str):
    """Load a dataset saved with ``datasets``' ``save_to_disk``, or return None.

    Returns None when the ``datasets`` package cannot be imported or *path*
    does not exist. When the loaded object exposes ``keys`` (a split mapping),
    the "train" split is preferred, otherwise the first key's split is used.
    If that split selection fails, the loaded object is returned unchanged.
    """
    try:
        from datasets import load_from_disk  # type: ignore
    except Exception:
        return None
    if not os.path.exists(path):
        return None

    loaded = load_from_disk(path)
    try:
        if not hasattr(loaded, "keys"):
            # Single split: use it directly.
            return loaded
        split_name = "train" if "train" in loaded else next(iter(loaded.keys()))
        return loaded[split_name]
    except Exception:
        return loaded


def _collect_groups(ds) -> set:
    """Return the distinct non-None ``group`` values of the rows in *ds*.

    Any error while scanning (e.g. rows without ``.get`` or unhashable group
    values) yields an empty set, which callers treat as "no groups".
    """
    try:
        found = {row.get("group") for row in ds}
    except Exception:
        return set()
    found.discard(None)
    return found


def _fit_group(rows_iter):
    """Fit ``loss_domain_d = a + b*L + c*L**2`` separately for every domain d.

    ``L = log(max(proportion_domain_d, _EPS))``. The 3x3 normal equations
    ``(X^T X + lam*I) [a, b, c]^T = X^T y`` are accumulated in pure Python
    (no extra dependencies) with a tiny ridge ``lam = 1e-8`` and solved by
    ``_solve_3x3``.

    For domain d, a row is skipped when ``loss_domain_d`` is missing/None or
    when the loss or the proportion is NaN/inf. A missing proportion key counts
    as 0.0.

    Args:
        rows_iter: iterable of row mappings; it is materialized once internally.

    Returns:
        ``(coeffs, r2s)``: ``coeffs[d] = (a, b, c)`` and ``r2s[d]`` is the
        in-sample R^2 of that fit. Domains with no usable rows get
        ``(1.0, 0.0, 0.0)`` and an R^2 of 0.0.
    """
    coeffs: Dict[int, tuple[float, float, float]] = {}
    r2s: Dict[int, float] = {}

    # Materialize rows for reuse: each domain makes its own pass.
    rows = list(rows_iter)
    lam = 1e-8  # ridge strength added to the diagonal of X^T X

    for d in _DOMAINS:
        p_key = f"proportion_domain_{d}"
        y_key = f"loss_domain_{d}"
        xtx = [[0.0, 0.0, 0.0],
               [0.0, 0.0, 0.0],
               [0.0, 0.0, 0.0]]
        xty = [0.0, 0.0, 0.0]
        samples: List[tuple[float, float, float]] = []  # (L, L^2, y) per used row

        for r in rows:
            y = r.get(y_key)
            if y is None:
                continue
            y = float(y)
            p = float(r.get(p_key, 0.0))
            # A single NaN/inf would propagate into every accumulated sum and
            # turn all three coefficients into NaN, so treat it as missing.
            if not (math.isfinite(y) and math.isfinite(p)):
                continue
            lp = math.log(max(p, _EPS))
            f = (1.0, lp, lp * lp)
            for i in range(3):
                xty[i] += f[i] * y
                for j in range(3):
                    xtx[i][j] += f[i] * f[j]
            samples.append((lp, lp * lp, y))

        n = len(samples)
        if n == 0:
            coeffs[d] = (1.0, 0.0, 0.0)
            r2s[d] = 0.0
            continue

        # The ridge term makes X^T X positive definite, so the system stays
        # solvable even with fewer than 3 rows or a single distinct proportion.
        for i in range(3):
            xtx[i][i] += lam

        a, b, c = _solve_3x3(xtx, xty)
        coeffs[d] = (a, b, c)

        # In-sample R^2; `or 1e-12` avoids division by zero when y is constant.
        y_mean = sum(y for _, _, y in samples) / n
        ss_tot = sum((y - y_mean) ** 2 for _, _, y in samples) or 1e-12
        ss_res = sum((y - (a + b * lp + c * lp2)) ** 2 for lp, lp2, y in samples)
        r2s[d] = 1.0 - (ss_res / ss_tot)

    return coeffs, r2s


def _solve_3x3(a: List[List[float]], b: List[float]) -> tuple[float, float, float]:
    """Solve the 3x3 system ``a @ x = b`` by Gaussian elimination with partial pivoting.

    ``a`` and ``b`` are copied and never modified.

    Returns:
        The solution as a 3-tuple. If, at some elimination step, every remaining
        entry of the pivot column has magnitude below 1e-18, the system is
        treated as singular and ``(0.0, 0.0, 0.0)`` is returned. During back
        substitution, a diagonal entry below 1e-18 yields 0.0 for that component.
    """
    size = 3
    tol = 1e-18
    mat = [row[:] for row in a]
    vec = b[:]

    # Forward elimination to upper-triangular form.
    for col in range(size):
        # Partial pivoting: the largest |entry| at or below the diagonal; on
        # ties the earliest row wins (builtin max keeps the first maximum).
        best = max(range(col, size), key=lambda r: abs(mat[r][col]))
        if abs(mat[best][col]) < tol:
            return (0.0, 0.0, 0.0)
        if best != col:
            mat[col], mat[best] = mat[best], mat[col]
            vec[col], vec[best] = vec[best], vec[col]

        pivot_row = mat[col]
        pivot = pivot_row[col]
        for r in range(col + 1, size):
            row = mat[r]
            if row[col] == 0.0:
                continue  # already zero below the pivot
            factor = row[col] / pivot
            vec[r] -= factor * vec[col]
            for k in range(col, size):
                row[k] -= factor * pivot_row[k]

    # Back substitution.
    x = [0.0] * size
    for i in reversed(range(size)):
        acc = vec[i]
        for k in range(i + 1, size):
            acc -= mat[i][k] * x[k]
        diag = mat[i][i]
        x[i] = 0.0 if abs(diag) < tol else acc / diag
    return tuple(x)


def _render_explain_markdown(coeffs, r2s, domains) -> str:
    """Render the explain.md document as one markdown string (pure, no I/O).

    Args:
        coeffs: mapping ``group -> domain -> (a, b, c)``.
        r2s: mapping ``group -> domain -> R^2``; missing entries render as nan.
        domains: domain indices to tabulate for every group.

    Returns:
        The complete markdown text.
    """
    domain_set = "{" + ",".join(str(d) for d in domains) + "}"
    lines: List[str] = [
        "# Discovered Scaling Law for Domain Mixture\n",
        "\n",
        "This document is auto-generated by /app/law.py the first time law() is called.\n",
        "\n",
        "## Formula\n",
        "\n",
        f"For each domain i in {domain_set}, and for any experimental group G, "
        "the validation loss is modeled as:\n",
        # A fenced block: an indented block cannot interrupt a paragraph in
        # markdown and would otherwise render as running text.
        "\n",
        "```\n",
        "loss_domain_i = a_{G,i} + b_{G,i} * L_i + c_{G,i} * L_i^2\n",
        "L_i = log(max(proportion_domain_i, 1e-12))\n",
        "```\n",
        "\n",
        "Proportions below 1e-12 (including 0) are clamped to 1e-12 before taking the log. "
        "This quadratic-in-log model approximates power-law behavior with a smooth curvature term "
        "and is fit via linear regression (normal equations with a small ridge regularizer).\n",
        "\n## Methodology\n",
        "\n",
        "- Loaded the dataset from /app/data using datasets.load_from_disk.\n"
        "- For each group and each domain, constructed features [1, L_i, L_i^2].\n"
        "- Solved for coefficients (a,b,c) with closed-form least squares per domain.\n"
        "- Report R² per fit to indicate goodness-of-fit. If a group is unknown at inference time, "
        "a global fit over all groups is used.\n",
        "\n## Fitted Coefficients by Group and Domain\n",
    ]

    if not coeffs:
        lines.append("\nNo coefficients available; using defaults (1.0, 0.0, 0.0).\n")
    else:
        # key=str keeps sorting valid when dataset group labels are not strings
        # (they are compared against the "__GLOBAL__" key).
        for g in sorted(coeffs, key=str):
            group_r2 = r2s.get(g, {})
            lines.append(f"\n### Group: {g}\n\n")
            lines.append("| Domain | a | b | c | R^2 |\n")
            lines.append("|---:|---:|---:|---:|---:|\n")
            for d in domains:
                a, b, c = coeffs[g][d]
                r2 = group_r2.get(d, float("nan"))
                lines.append(f"| {d} | {a:.6f} | {b:.6f} | {c:.6f} | {r2:.4f} |\n")

    return "".join(lines)


def _write_explain_file() -> None:
    """Write /app/explain.md describing the law and the current fitted coefficients.

    Best effort: filesystem errors are swallowed so documentation never breaks
    prediction.
    """
    content = _render_explain_markdown(_COEFFS, _R2, _DOMAINS)
    path = "/app/explain.md"
    try:
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
    except OSError:
        # Swallow IO errors to avoid breaking runtime
        pass
#2 Run 2 R² = 0.899200
#3 Run 3 R² = 0.857740
#4 Run 4 R² = 0.840301
#5 Run 5 R² = -1.000000