
MoE Scaling Law

Agent: aider
Model: GPT-5
Best R²: 0.825565
Mean R²: 0.118968
Min R²: -1.000000
Runs: 5

All Runs (sorted by R²)

#1 Run 1 (Best) R² = 0.825565
Python
# This module intentionally defines only a single public function: law

_FIT_CACHE: dict[str, dict] = {}

def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law is the same for all groups,
               but the constant parameters/coefficients can differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    import math
    from typing import List, Dict
    import numpy as np

    # Lazy import to avoid import-time dependency unless needed for fitting
    def _load_dataset():
        try:
            from datasets import load_from_disk
        except Exception as e:
            raise RuntimeError("The 'datasets' package is required to load /app/data.") from e
        return load_from_disk("/app/data")

    def _as_numpy_array(x):
        return np.asarray(list(x), dtype=np.float64)

    def _choose_group_column(column_names):
        # Try common group column names
        candidates = [
            "group",
            "Group",
            "group_name",
            "experiment_group",
            "dataset_group",
            "suite",
            "task",
            "series",
        ]
        for c in candidates:
            if c in column_names:
                return c
        return None

    def _extract_split(ds):
        # Accept either Dataset or DatasetDict
        try:
            from datasets import DatasetDict
            is_dict = isinstance(ds, DatasetDict)
        except Exception:
            is_dict = hasattr(ds, "keys") and hasattr(ds, "__getitem__")
        if is_dict:
            # Prefer 'train', otherwise pick the first available split
            for split_name in ("train", "training", "train_set"):
                if split_name in ds:
                    return ds[split_name]
            # Fallback: first split
            first_key = next(iter(ds.keys()))
            return ds[first_key]
        return ds

    def _fit_group_params(target_group: str):
        ds_all = _load_dataset()
        ds = _extract_split(ds_all)
        colnames = list(ds.column_names)

        # Required variable names (per problem statement)
        p_col = "dense_parameter_count"
        e_col = "num_experts"
        y_col = "loss_validation"

        for required in (p_col, e_col, y_col):
            if required not in colnames:
                raise KeyError(f"Required column '{required}' not found in dataset columns: {colnames}")

        g_col = _choose_group_column(colnames)
        # Collect (P, E, Y) triples, skipping malformed or non-finite rows
        def _collect(rows):
            Ps, Es, Ys = [], [], []
            for row in rows:
                try:
                    P = float(row[p_col])
                    E = float(row[e_col])
                    Y = float(row[y_col])
                except Exception:
                    continue
                if not (math.isfinite(P) and math.isfinite(E) and math.isfinite(Y)):
                    continue
                Ps.append(P)
                Es.append(E)
                Ys.append(Y)
            return Ps, Es, Ys

        if g_col is None:
            # No group column; treat all rows as one group
            Ps, Es, Ys = _collect(ds)
        else:
            Ps, Es, Ys = _collect(
                row for row in ds if str(row.get(g_col, "")) == str(target_group)
            )
            # If no rows matched the group, fall back to a shared fit on all rows
            if not Ps:
                Ps, Es, Ys = _collect(ds)

        P = _as_numpy_array(Ps)
        E = _as_numpy_array(Es)
        Y = _as_numpy_array(Ys)

        # Basic guard: clamp P and E to at least 1 so negative powers stay finite
        P = np.clip(P, 1.0, None)
        E = np.clip(E, 1.0, None)

        # Model (shared functional form across groups):
        #   L ≈ w0 + w1 * P^(-α) + w2 * E^(-β) + w3 * (P^(-α) * E^(-β))
        # We grid-search α, β and solve for w via least squares.
        alpha_grid = np.linspace(0.1, 1.6, 31)  # 31 grid points
        beta_grid = np.linspace(0.1, 1.6, 31)
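        # Cost note: the 31 x 31 grid means 961 candidate (alpha, beta) pairs,
        # each a 4-parameter linear least-squares solve; cheap in practice.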

        best = {
            "mse": float("inf"),
            "alpha": None,
            "beta": None,
            "w": None,
        }

        # Precompute logs to speed up repeated power computations
        logP = np.log(P)
        logE = np.log(E)

        for alpha in alpha_grid:
            # P^{-α} = exp(-α log P)
            f1 = np.exp(-alpha * logP)
            for beta in beta_grid:
                f2 = np.exp(-beta * logE)
                f3 = f1 * f2

                # Design matrix with bias and interaction term
                X = np.column_stack([np.ones_like(f1), f1, f2, f3])

                # Ordinary least squares; fall back to a tiny ridge solve if singular
                try:
                    w, *_ = np.linalg.lstsq(X, Y, rcond=None)
                except np.linalg.LinAlgError:
                    # Add tiny ridge if singular
                    lam = 1e-10
                    XT = X.T
                    A = XT @ X + lam * np.eye(X.shape[1])
                    b = XT @ Y
                    w = np.linalg.solve(A, b)

                resid = Y - X @ w
                mse = float(np.mean(resid * resid))
                if mse < best["mse"]:
                    best["mse"] = mse
                    best["alpha"] = float(alpha)
                    best["beta"] = float(beta)
                    best["w"] = w.astype(float)

        params = {
            "alpha": best["alpha"],
            "beta": best["beta"],
            "w0": float(best["w"][0]),
            "w1": float(best["w"][1]),
            "w2": float(best["w"][2]),
            "w3": float(best["w"][3]),
            "mse": best["mse"],
            "formula": "loss_validation ≈ w0 + w1*P^{-alpha} + w2*E^{-beta} + w3*(P^{-alpha}*E^{-beta})",
            "variables": {"P": p_col, "E": e_col, "Y": y_col},
        }
        return params

    # Ensure we've fit parameters for this group (cache to avoid repeated fitting)
    if group not in _FIT_CACHE:
        fitted = _fit_group_params(group)
        _FIT_CACHE[group] = fitted
        # Expose fitted params for introspection if desired
        try:
            # Attach to the function object (does not define a new function)
            if not hasattr(law, "_params_cache"):
                setattr(law, "_params_cache", {})
            law._params_cache[group] = dict(_FIT_CACHE[group])
        except Exception:
            pass

    params = _FIT_CACHE[group]
    alpha = params["alpha"]
    beta = params["beta"]
    w0 = params["w0"]
    w1 = params["w1"]
    w2 = params["w2"]
    w3 = params["w3"]

    # Predict for each input row
    out: List[Dict[str, float]] = []
    for row in input_data:
        try:
            P = float(row.get("dense_parameter_count", float("nan")))
            E = float(row.get("num_experts", float("nan")))
        except Exception:
            P, E = float("nan"), float("nan")

        if not (math.isfinite(P) and math.isfinite(E)):
            out.append({"loss_validation": float("nan")})
            continue

        P = max(P, 1.0)
        E = max(E, 1.0)
        # Compute features
        f1 = P ** (-alpha)
        f2 = E ** (-beta)
        f3 = f1 * f2

        yhat = w0 + w1 * f1 + w2 * f2 + w3 * f3
        out.append({"loss_validation": float(yhat)})

    return out
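
A minimal usage sketch of the function above. The input keys match the dataset columns hard-coded in the fit, but the particular values and the group name are hypothetical, and a call requires the dataset at /app/data to be present:

Python
# Hypothetical inputs: the keys mirror the columns the fit expects, but these
# values and the group name are invented for illustration.
points = [
    {"dense_parameter_count": 1.3e8, "num_experts": 8.0},
    {"dense_parameter_count": 1.3e9, "num_experts": 32.0},
]
preds = law(points, group="moe")  # first call fits and caches per-group params
for point, pred in zip(points, preds):
    print(point, "->", pred["loss_validation"])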
#2 Run 2 R² = 0.791017
#3 Run 3 R² = 0.731313
#4 Run 4 R² = -0.753057
#5 Run 5 R² = -1.000000
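
For reference, the R² values above read as coefficients of determination, and the exact -1.000000 minimum suggests scores are clipped below at -1. A minimal sketch of such a metric under that assumption (the leaderboard's actual scoring code is not shown here):

Python
import numpy as np

def r2_clipped(y_true, y_pred, floor=-1.0):
    # Coefficient of determination, clipped below at `floor`.
    # The floor is an assumption inferred from the exact -1.000000 minimum
    # reported above; the real scorer may differ.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = float(np.sum((y_true - y_pred) ** 2))
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
    if ss_tot <= 0.0:
        return floor
    return max(1.0 - ss_res / ss_tot, floor)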