U-shaped Scaling Law

Agent: aider
Model: GPT-5
Best R²: 0.380703
Mean R²: -0.474494
Min R²: -1.000000
Runs: 5

All Runs (sorted by R²)

#1 Run 1 (best) R² = 0.380703
Python
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
                The functional form of the law must be the same for all groups,
                but the constant parameters/coefficients can differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    # Cache learned coefficients on the function object to avoid repeated I/O/fits.
    if not hasattr(law, "_coeffs"):
        # Fit a U-shaped (convex) scaling law per group on first call:
        #   brier_score_hat = y0_g + k_g * (log_flops - m_g)^2
        # where k_g >= 0 ensures a U-shape. We determine m_g by 1D search and
        # solve y0_g, k_g by closed-form least squares for each candidate m_g.
        def _load_dataset():
            try:
                from datasets import load_from_disk  # type: ignore
            except Exception:
                return None
            try:
                return load_from_disk("/app/data")
            except Exception:
                return None

        def _iter_rows(ds):
            # Yield dicts with keys including 'log_flops', 'brier_score', and
            # 'group' (if present). A HF DatasetDict subclasses dict, so one
            # check covers both split containers and plain dicts.
            if isinstance(ds, dict):
                for split in ds.values():
                    for row in split:
                        yield dict(row)
            else:
                for row in ds:
                    yield dict(row)

        def _fit_group(points):
            # Fit y = y0 + k * (x - m)^2 with k >= 0 by grid-search over m and
            # closed-form LS for (y0, k) at each m.
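            # For a fixed vertex m the model is linear in (y0, k), so m can be
            # profiled out: scan candidate vertices and keep the triple with
            # the lowest mean squared error.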
            xs = [p[0] for p in points]
            ys = [p[1] for p in points]
            n = len(xs)
            if n == 0:
                return (0.2, 0.01, 10.0, float("nan"))  # y0, k, m, mse
            if n == 1:
                # With one point, place vertex at x and set k very small.
                return (ys[0], 1e-6, xs[0], 0.0)
            xmin, xmax = min(xs), max(xs)
            # Expand search range slightly to allow vertex just outside observed x.
            margin = max(1e-6, 0.05 * (xmax - xmin) if xmax > xmin else 0.5)
            lo, hi = xmin - margin, xmax + margin
            best = (float("inf"), 0.0, 0.0, 0.0)  # mse, y0, k, m
            # Build a small grid over m; denser if we have more data
            steps = max(21, min(101, 5 * n))
            for i in range(steps):
                m = lo + (hi - lo) * i / (steps - 1)
                # Features: z = (x - m)^2, model: y = y0 + k*z
                z = [(x - m) ** 2 for x in xs]
                Sz = sum(z)
                Sz2 = sum(zz * zz for zz in z)
                Sy = sum(ys)
                Szy = sum(zi * yi for zi, yi in zip(z, ys))
                lam = 1e-12  # tiny ridge for numerical stability
                a11 = n + lam
                a12 = Sz
                a22 = Sz2 + lam
                det = a11 * a22 - a12 * a12
                if det == 0.0:
                    continue
                # Solve 2x2 system:
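                # Normal equations (ridge-regularized):
                #   [a11 a12] [y0]   [Sy ]
                #   [a12 a22] [k ] = [Szy]
                # solved by Cramer's rule with det = a11*a22 - a12^2.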
                y0 = (Sy * a22 - a12 * Szy) / det
                k = (a11 * Szy - a12 * Sy) / det
                # Enforce convexity (U-shape). With k clamped to 0 the model
                # is a constant, whose least-squares intercept is the mean of y.
                if k < 0.0:
                    k = 0.0
                    y0 = Sy / n
                preds = [y0 + k * zi for zi in z]
                mse = sum((p - yv) ** 2 for p, yv in zip(preds, ys)) / n
                if mse < best[0]:
                    best = (mse, y0, k, m)
            _, y0b, kb, mb = best
            return (y0b, kb, mb, best[0])

        # Try to load and fit from dataset; if unavailable, fall back to a generic prior.
        ds = _load_dataset()
        coeffs = {}  # group -> (y0, k, m, mse, n)
        all_points = []
        group_key = "group"
        if ds is not None:
            # Materialize rows once so the dataset is iterated a single time;
            # peeking at the first row with one iterator and then re-iterating
            # from scratch would count that row twice in the fit.
            rows = list(_iter_rows(ds))
            # Detect a plausible group key if 'group' is not present.
            if rows and group_key not in rows[0]:
                for cand in ("dataset", "family", "arch", "setting"):
                    if cand in rows[0]:
                        group_key = cand
                        break
            # Collect (log_flops, brier_score) points per group.
            grouped = {}
            for row in rows:
                try:
                    x = float(row["log_flops"])
                    y = float(row["brier_score"])
                except Exception:
                    continue
                g = str(row.get(group_key, "ALL"))
                grouped.setdefault(g, []).append((x, y))
                all_points.append((x, y))
            # Fit per group
            for g, pts in grouped.items():
                y0, k, m, mse = _fit_group(pts)
                coeffs[g] = (y0, k, m, mse, len(pts))
            # Also fit a global fallback across all data
            if all_points:
                y0, k, m, mse = _fit_group(all_points)
                coeffs.setdefault("ALL", (y0, k, m, mse, len(all_points)))
        # Fallback if dataset couldn't be loaded
        if not coeffs:
            # Reasonable, convex U-shape prior in log_flops
            coeffs = {
                "ALL": (0.2, 0.01, 10.0, float("nan"), 0),
            }

        # Store cache
        law._coeffs = coeffs  # type: ignore[attr-defined]

        # Try to write a human-readable report to /app/explain.md
        try:
            lines = []
            lines.append("# U-shaped scaling law for Brier score vs. log_flops\n")
            lines.append("We model final performance (lower Brier is better) as a convex quadratic in log compute:\n")
            lines.append("\n")
            lines.append("brier_score_hat = y0_g + k_g * (log_flops - m_g)^2\n")
            lines.append("\n")
            lines.append("where the functional form is shared across groups g, and (y0_g, k_g, m_g) are group-specific parameters fit via least squares with a grid-search over the vertex location m_g, enforcing k_g >= 0.\n")
            lines.append("\n")
            lines.append("## Fitted coefficients by group\n")
            lines.append("\n")
            lines.append("| group | y0 | k | m | MSE (fit) | n |\n")
            lines.append("|---|---:|---:|---:|---:|---:|\n")
            def _fmt(v):
                # NaN is the only float that compares unequal to itself.
                return "NaN" if v != v else f"{v:.6g}"
            # Sort keys so the report is reproducible across runs.
            for g in sorted(law._coeffs.keys()):  # type: ignore[attr-defined]
                y0, k, m, mse, n = law._coeffs[g]  # type: ignore[index]
                lines.append(f"| {g} | {_fmt(y0)} | {_fmt(k)} | {_fmt(m)} | {_fmt(mse)} | {n} |\n")
            with open("/app/explain.md", "w", encoding="utf-8") as f:
                f.writelines(lines)
        except Exception:
            # Silently ignore if we cannot write the report (read-only FS, etc.)
            pass

    # Do predictions using cached coefficients.
    coeffs = law._coeffs  # type: ignore[attr-defined]
    results: list[dict[str, float]] = []
    # Choose coeffs: exact group -> fallback to "ALL" -> last resort prior
    cg = coeffs.get(group)
    if cg is None:
        cg = coeffs.get("ALL", (0.2, 0.01, 10.0, float("nan"), 0))
    y0, k, m = cg[0], cg[1], cg[2]
    for row in (input_data or []):
        try:
            x = float(row["log_flops"])
        except Exception:
            # If missing, predict baseline y0
            results.append({"brier_score": float(y0)})
            continue
        yhat = y0 + k * (x - m) ** 2
        results.append({"brier_score": float(yhat)})
    return results
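
For context, a minimal usage sketch (not part of the submitted run; the sample log_flops values are invented for illustration, and "ALL" is the code's own global fallback group):

Python
if __name__ == "__main__":
    # Hypothetical query points; real inputs come from the evaluation harness.
    sample = [{"log_flops": 18.0}, {"log_flops": 22.5}]
    # The first call fits per-group coefficients from /app/data (or falls back
    # to the built-in prior) and caches them on the function object.
    preds = law(sample, group="ALL")
    for row, pred in zip(sample, preds):
        print(f"log_flops={row['log_flops']}: brier_score={pred['brier_score']:.4f}")
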
#2 Run 2 R² = 0.246827
#3 Run 3 R² = -1.000000
#4 Run 4 R² = -1.000000
#5 Run 5 R² = -1.000000