
Vocabulary Scaling Law

Agent: aider
Model: GPT-5
Best R²: 0.904636
Mean R²: 0.132345
Min R²: -1.000000
Runs: 5

All Runs (sorted by R²)

#1 Run 1 (Best) R² = 0.904636
Python
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law must be the same for all groups,
               but the constant parameters/coefficients can differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    # Lazy-fit coefficients from /app/data on first call, cache for reuse.
    # The discovered functional form is:
    #   L = L_inf + K * N^a * D^b * V^c
    # where:
    #   L = unigram_normalized_loss
    #   N = non_vocab_parameters
    #   D = num_characters
    #   V = vocab_size
    # and (L_inf, K, a, b, c) depend on the group but the form is shared.
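    # As in DEFAULT_PARAMS below, the exponents a, b, c are expected to be
    # negative, so the predicted loss decays toward the floor L_inf as N, D,
    # and V grow.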
    import math

    # A minimal, sane default in case fitting can't run (e.g., datasets/numpy missing).
    DEFAULT_PARAMS = {"L_inf": 0.6, "K": 0.4, "a": -0.1, "b": -0.1, "c": -0.1}

    # Initialize caches on the function object
    if not hasattr(law, "_params_by_group"):
        law._params_by_group = {}  # type: ignore[attr-defined]
    if not hasattr(law, "_fitted"):
        law._fitted = False  # type: ignore[attr-defined]

    def _safe_float(x, default=1.0):
        try:
            return float(x)
        except Exception:
            return float(default)

    def _predict_with_params(params, rows):
        L_inf = params["L_inf"]
        K = params["K"]
        a = params["a"]
        b = params["b"]
        c = params["c"]
        preds = []
        for row in rows:
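            # Clamp each input to a tiny positive floor so the fractional
            # powers below are always well-defined.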
            V = max(_safe_float(row.get("vocab_size", 0.0)), 1e-12)
            N = max(_safe_float(row.get("non_vocab_parameters", 0.0)), 1e-12)
            D = max(_safe_float(row.get("num_characters", 0.0)), 1e-12)
            pred = L_inf + K * (N ** a) * (D ** b) * (V ** c)
            preds.append({"unigram_normalized_loss": float(pred)})
        return preds

    def _write_explain_md(params_by_group):
        # Best-effort write; ignore any filesystem errors.
        try:
            lines = []
            lines.append("# Scaling law for unigram-normalized loss")
            lines.append("")
            lines.append("We model the unigram-normalized loss (L) as a sum of an irreducible floor and a separable power-law over compute, data, and vocabulary:")
            lines.append("")
            lines.append("L = L_inf + K * N^a * D^b * V^c")
            lines.append("")
            lines.append("where:")
            lines.append("- L: unigram_normalized_loss")
            lines.append("- N: non_vocabulary parameters (non_vocab_parameters)")
            lines.append("- D: total training characters (num_characters)")
            lines.append("- V: vocabulary size (vocab_size)")
            lines.append("")
            lines.append("Methodology summary:")
            lines.append("- For each group, we choose L_inf via a grid search below the minimum observed loss.")
            lines.append("- Given a candidate L_inf, we fit ln(L - L_inf) = ln K + a ln N + b ln D + c ln V via least squares.")
            lines.append("- We select the L_inf that minimizes the squared residuals in log-space.")
            lines.append("")
            lines.append("## Fitted parameters by group")
            lines.append("")
            if not params_by_group:
                lines.append("_No dataset found during fitting; defaults in use._")
            else:
                # Show GLOBAL first if present
                ordered = []
                if "GLOBAL" in params_by_group:
                    ordered.append(("GLOBAL", params_by_group["GLOBAL"]))
                ordered.extend([(g, p) for g, p in params_by_group.items() if g != "GLOBAL"])
                for g, p in ordered:
                    lines.append(f"### {g}")
                    lines.append(f"- L_inf: {p['L_inf']:.6g}")
                    lines.append(f"- K: {p['K']:.6g}")
                    lines.append(f"- a (non_vocab_parameters exponent): {p['a']:.6g}")
                    lines.append(f"- b (num_characters exponent): {p['b']:.6g}")
                    lines.append(f"- c (vocab_size exponent): {p['c']:.6g}")
                    lines.append("")
            with open("/app/explain.md", "w", encoding="utf-8") as f:
                f.write("\n".join(lines) + "\n")
        except Exception:
            pass

    def _fit_if_needed():
        if law._fitted:  # type: ignore[attr-defined]
            return
        # Attempt to fit from /app/data
        params_by_group = {}

        # Small helper to set defaults when fit fails
        def _set_defaults(groups):
            if not groups:
                params_by_group["GLOBAL"] = DEFAULT_PARAMS.copy()
            for g in groups:
                params_by_group[g] = DEFAULT_PARAMS.copy()

        try:
            # Import locally to keep the file limited to a single public function.
            try:
                import numpy as np  # type: ignore
            except Exception:
                # Can't fit without numpy
                _set_defaults(groups=[])
                law._params_by_group = params_by_group  # type: ignore[attr-defined]
                law._fitted = True  # type: ignore[attr-defined]
                _write_explain_md(params_by_group)
                return

            try:
                from datasets import load_from_disk  # type: ignore
            except Exception:
                # Can't load dataset, fall back to defaults
                _set_defaults(groups=[])
                law._params_by_group = params_by_group  # type: ignore[attr-defined]
                law._fitted = True  # type: ignore[attr-defined]
                _write_explain_md(params_by_group)
                return

            ds = load_from_disk("/app/data")

            # Flatten to a list of Python dicts
            rows = []
            try:
                # Dataset or DatasetDict
                if hasattr(ds, "keys") and callable(ds.keys):
                    for k in ds.keys():
                        split = ds[k]
                        for r in split:
                            rows.append(dict(r))
                else:
                    for r in ds:
                        rows.append(dict(r))
            except Exception:
                # As a fallback, try to access .to_list()
                try:
                    rows = list(ds.to_list())
                except Exception:
                    rows = []

            # Identify group column
            group_col = None
            if rows:
                candidate_cols = ["group", "Group", "group_name", "experiment_group", "family"]
                sample_keys = rows[0].keys()
                for c in candidate_cols:
                    if c in sample_keys:
                        group_col = c
                        break

            if not rows:
                _set_defaults(groups=[])
                law._params_by_group = params_by_group  # type: ignore[attr-defined]
                law._fitted = True  # type: ignore[attr-defined]
                _write_explain_md(params_by_group)
                return

            # Build groups
            if group_col is None:
                groups = {"GLOBAL": rows}
            else:
                groups = {}
                for r in rows:
                    g = r.get(group_col, "GLOBAL")
                    if g is None:
                        g = "GLOBAL"
                    g = str(g)
                    groups.setdefault(g, []).append(r)

            # Always include GLOBAL as an aggregate fit across all
            if group_col is not None:
                groups["GLOBAL"] = rows

            # Fit each group
            for gname, grows in groups.items():
                # Extract and validate
                N_list = []
                D_list = []
                V_list = []
                Y_list = []
                for r in grows:
                    try:
                        V = float(r.get("vocab_size", float("nan")))
                        N = float(r.get("non_vocab_parameters", float("nan")))
                        D = float(r.get("num_characters", float("nan")))
                        Y = float(r.get("unigram_normalized_loss", float("nan")))
                    except Exception:
                        continue
                    if not (V > 0 and N > 0 and D > 0 and math.isfinite(V) and math.isfinite(N) and math.isfinite(D)):
                        continue
                    if not math.isfinite(Y):
                        continue
                    N_list.append(N)
                    D_list.append(D)
                    V_list.append(V)
                    Y_list.append(Y)

                if len(Y_list) < 8:
                    # Not enough data to fit robustly
                    params_by_group[gname] = DEFAULT_PARAMS.copy()
                    continue

                N_arr = np.array(N_list, dtype=np.float64)
                D_arr = np.array(D_list, dtype=np.float64)
                V_arr = np.array(V_list, dtype=np.float64)
                Y_arr = np.array(Y_list, dtype=np.float64)

                # Grid search for L_inf below min(Y)
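                # The floor must stay strictly below every observed loss so
                # that ln(Y - L_inf) in the log-space fit below is defined.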
                y_min = float(np.min(Y_arr))
                y_max = float(np.max(Y_arr))
                y_range = max(y_max - y_min, 1e-6)
                L_low = y_min - 0.5 * y_range
                L_high = y_min - 1e-8
                # Ensure strictly less than min(Y)
                if L_low >= L_high:
                    L_low = y_min - 0.5 * max(y_range, 1.0)
                    L_high = y_min - 1e-8

                L_candidates = np.linspace(L_low, L_high, num=80, dtype=np.float64)

                lnN = np.log(N_arr)
                lnD = np.log(D_arr)
                lnV = np.log(V_arr)
                X = np.column_stack([lnN, lnD, lnV, np.ones_like(lnN)])

                best_sse = float("inf")
                best = None

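                # Profiled least squares: for a fixed floor L_inf the model
                # ln(Y - L_inf) = ln K + a ln N + b ln D + c ln V is linear in
                # (a, b, c, ln K), so each candidate needs only one lstsq
                # solve; keep the candidate with the smallest log-space SSE.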
                for L_inf in L_candidates:
                    Z = Y_arr - L_inf
                    if np.any(Z <= 0):
                        continue
                    lnZ = np.log(Z)
                    # Solve for theta: [a, b, c, lnK]
                    try:
                        theta, _, _, _ = np.linalg.lstsq(X, lnZ, rcond=None)
                    except Exception:
                        continue
                    residuals = lnZ - X.dot(theta)
                    sse = float(np.dot(residuals, residuals))
                    if sse < best_sse and np.isfinite(sse):
                        best_sse = sse
                        a, b, c, lnK = [float(t) for t in theta]
                        K = float(math.exp(lnK))
                        best = {"L_inf": float(L_inf), "K": K, "a": a, "b": b, "c": c}

                if best is None:
                    params_by_group[gname] = DEFAULT_PARAMS.copy()
                else:
                    params_by_group[gname] = best

            law._params_by_group = params_by_group  # type: ignore[attr-defined]
            law._fitted = True  # type: ignore[attr-defined]
            _write_explain_md(params_by_group)
            return
        except Exception:
            _set_defaults(groups=[])
            law._params_by_group = params_by_group  # type: ignore[attr-defined]
            law._fitted = True  # type: ignore[attr-defined]
            _write_explain_md(params_by_group)
            return

    # Ensure parameters are available
    _fit_if_needed()

    # Prepare predictions
    params_for_group = getattr(law, "_params_by_group", {}).get(group)  # type: ignore[attr-defined]
    if params_for_group is None:
        params_for_group = getattr(law, "_params_by_group", {}).get("GLOBAL")  # type: ignore[attr-defined]
    if params_for_group is None:
        params_for_group = DEFAULT_PARAMS

    # If input_data is empty, this call can still be used to trigger fitting and explain.md generation.
    if not input_data:
        return []

    return _predict_with_params(params_for_group, input_data)
#2 Run 2 R² = 0.895969
#3 Run 3 R² = 0.861121
#4 Run 4 R² = -1.000000
#5 Run 5 R² = -1.000000
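
Usage sketch

For reference, a minimal, hypothetical way to call the law function from the best run above. The input rows and their values are made up for illustration; on the first call the function tries to fit its coefficients from /app/data and silently falls back to DEFAULT_PARAMS if the dataset or numpy is unavailable.

Python
# Hypothetical example rows; field values are illustrative only.
example_rows = [
    {"non_vocab_parameters": 85e6, "num_characters": 5e9, "vocab_size": 32_000},
    {"non_vocab_parameters": 1.2e9, "num_characters": 50e9, "vocab_size": 64_000},
]

# The first call triggers the lazy fit (or the default fallback) and caches
# the per-group coefficients; "GLOBAL" is the aggregate fit across all groups.
predictions = law(example_rows, group="GLOBAL")
for row, pred in zip(example_rows, predictions):
    print(row["vocab_size"], "->", pred["unigram_normalized_loss"])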