MoE Scaling Law

Agent: codex
Model: GPT-5
Best R²: 0.832737
Mean R²: 0.649375
Min R²: -0.007324
Runs: 5

All Runs (sorted by R²)

#1 Run 1 (best) R² = 0.832737
Python
from __future__ import annotations

import math


# Discovered scaling law (shared functional form across groups):
#   loss = L + K * (P**alpha * E**beta) ** (-gamma)
# where:
#   P = dense_parameter_count (float, > 0)
#   E = num_experts (float, > 0)
# Parameters (L, K, gamma, alpha, beta) are group-specific constants.
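#
# Worked example (illustrative inputs, not taken from the dataset): with the
# "all_data" constants below, P = 1e8 and E = 8 give
#   alpha*ln(P) + beta*ln(E) ≈ 19.998
#   loss ≈ 1.616974 + 43.469602 * exp(-0.190978 * 19.998) ≈ 2.571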


# Fitted parameters per group from the provided dataset.
# Values are rounded to six decimal places for readability.
_PARAMS_BY_GROUP: dict[str, tuple[float, float, float, float, float]] = {
    # group: (L, K, gamma, alpha, beta)
    "all_data": (
        1.616974,  # L
        43.469602, # K
        0.190978,  # gamma
        1.041879,  # alpha
        0.387373,  # beta
    ),
}


def _predict_loss(P: float, E: float, params: tuple[float, float, float, float, float]) -> float:
    L, K, gamma, alpha, beta = params
    # Guard against non-positive inputs; fall back to returning L if invalid.
    if P <= 0 or E <= 0:
        return float(L)
    # Compute effective scale and apply the power-law decay.
    # Use logs for numerical stability: (P**alpha * E**beta)**(-gamma) = exp(-gamma * (alpha*ln P + beta*ln E))
    s_log = alpha * math.log(P) + beta * math.log(E)
    decay = math.exp(-gamma * s_log)
    return float(L + K * decay)


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law must be the same for all groups,
               but the constant parameters/coefficients can differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    # Select parameters for the provided group. If unseen, fall back to a reasonable default.
    # Default: use the parameters fitted on the aggregate group if available, otherwise a safe baseline.
    if group in _PARAMS_BY_GROUP:
        params = _PARAMS_BY_GROUP[group]
    elif "all_data" in _PARAMS_BY_GROUP:
        params = _PARAMS_BY_GROUP["all_data"]
    else:
        # Conservative fallback (keeps loss near a plausible constant if no params are known)
        params = (2.0, 1.0, 0.2, 1.0, 0.5)

    outputs: list[dict[str, float]] = []
    for row in input_data:
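        # Missing keys default to 0.0, which trips the non-positive guard in
        # _predict_loss and yields the asymptote L.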
        P = float(row.get("dense_parameter_count", 0.0))
        E = float(row.get("num_experts", 0.0))
        pred = _predict_loss(P, E, params)
        outputs.append({"loss_validation": pred})

    return outputs
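
A minimal usage sketch of law() (the rows below are hypothetical
configurations, not rows from the original dataset):

Python
rows = [
    {"dense_parameter_count": 1e8, "num_experts": 8.0},
    {"dense_parameter_count": 5e8, "num_experts": 32.0},
]
print(law(rows, "all_data"))
# With the "all_data" constants this prints loss_validation values of
# roughly 2.571 and 2.242.

A sketch of how constants like these could be refit on new data, assuming
scipy is available (the source does not show the agent's actual fitting
procedure, and the data below are synthetic placeholders). Note that the
form is over-parameterized: only the products gamma*alpha and gamma*beta are
identifiable, so the sketch fits the reduced form
loss = L + K * exp(-(a*ln(P) + b*ln(E))) with a = gamma*alpha and
b = gamma*beta.

Python
import numpy as np
from scipy.optimize import curve_fit


def reduced_form(X, L, K, a, b):
    P, E = X
    return L + K * np.exp(-(a * np.log(P) + b * np.log(E)))


# Synthetic placeholder observations (not the original dataset).
P_obs = np.array([1e7, 1e7, 1e8, 1e8, 1e9, 1e9])
E_obs = np.array([1.0, 8.0, 1.0, 8.0, 1.0, 8.0])
rng = np.random.default_rng(0)
loss_obs = reduced_form((P_obs, E_obs), 1.617, 43.470, 0.199, 0.074)
loss_obs = loss_obs + rng.normal(0.0, 0.01, size=loss_obs.shape)

# An initial guess near the reported constants helps convergence.
popt, _ = curve_fit(reduced_form, (P_obs, E_obs), loss_obs,
                    p0=(1.6, 40.0, 0.2, 0.07), maxfev=20000)
L_fit, K_fit, a_fit, b_fit = popt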

Other runs (code not shown):

#2 Run 2 R² = 0.832695
#3 Run 3 R² = 0.808867
#4 Run 4 R² = 0.779898
#5 Run 5 R² = -0.007324