
MoE Scaling Law

Agent: codex
Model: o4-mini
Best R²: 0.467622
Mean R²: 0.467622
Min R²: 0.467622
Runs: 5

All Runs (sorted by R²)

#1 Run 1 (best) R² = 0.467622
Python
"""
Scaling law prediction function for MoE validation loss.

Functional form (a power law in both inputs):
    loss_validation = A * num_experts**b * dense_parameter_count**c
"""

# Fitted coefficients per experimental group
_COEFFICIENTS = {
    'all_data': {
        'A': 10.069179203081301,
        'b': -0.027401922527553482,
        'c': -0.07051248607961777,
    },
}

def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts validation loss from the number of experts and the dense
    parameter count, using the fitted power law for the given group.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point and must contain the keys 'num_experts' and
                    'dense_parameter_count' with numeric values.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law is the same for all groups,
               but the fitted coefficients can differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s) (key 'loss_validation').
    """
    if group not in _COEFFICIENTS:
        raise ValueError(f"Unknown group '{group}'. Available groups: {list(_COEFFICIENTS)}")
    params = _COEFFICIENTS[group]
    A = params['A']
    b = params['b']
    c = params['c']
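    predictions = []
    for datum in input_data:
        # Cast to float so integer inputs work with the negative exponents.
        ne = float(datum['num_experts'])
        dp = float(datum['dense_parameter_count'])
        # Power law in both inputs: loss = A * ne^b * dp^c
        loss_pred = A * (ne ** b) * (dp ** c)
        predictions.append({'loss_validation': loss_pred})
    return predictions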
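
Both fitted exponents are negative, so the law above predicts lower loss as either input grows, and each input scales the loss by a constant factor. The following is a minimal sketch of that behaviour, reusing the 'all_data' coefficients. The helper name `predict_loss` and the three configurations are hypothetical and chosen only for illustration; they are not taken from the benchmark data.

```python
# Coefficients copied from the 'all_data' group above.
A = 10.069179203081301
B = -0.027401922527553482
C = -0.07051248607961777


def predict_loss(num_experts: float, dense_parameter_count: float) -> float:
    """Evaluate loss = A * num_experts**B * dense_parameter_count**C."""
    return A * (float(num_experts) ** B) * (float(dense_parameter_count) ** C)


# Hypothetical configurations (assumed values, for illustration only).
configs = [
    {'num_experts': 1, 'dense_parameter_count': 1e8},
    {'num_experts': 8, 'dense_parameter_count': 1e8},
    {'num_experts': 8, 'dense_parameter_count': 1e9},
]
losses = [
    predict_loss(cfg['num_experts'], cfg['dense_parameter_count'])
    for cfg in configs
]
for cfg, loss in zip(configs, losses):
    print(cfg, round(loss, 4))

# Multiplicative effect of each input under this law:
ratio_experts = 2 ** B   # factor applied to the loss when num_experts doubles
ratio_params = 10 ** C   # factor applied when dense_parameter_count grows 10x
print(f"2x experts -> loss x {ratio_experts:.4f}")
print(f"10x dense params -> loss x {ratio_params:.4f}")
```

Under these coefficients, doubling the number of experts lowers the predicted loss by roughly 2%, while a tenfold increase in dense parameters lowers it by roughly 15%.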
#2 Run 2 R² = 0.467622
#3 Run 3 R² = 0.467622
#4 Run 4 R² = 0.467622
#5 Run 5 R² = 0.467622
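
For reference, the R² values listed above can be read against the standard coefficient of determination, R² = 1 - SS_res / SS_tot. The sketch below assumes the leaderboard uses this standard definition, which the page does not state. The function name `r_squared` and the toy loss values are hypothetical.

```python
def r_squared(y_true: list[float], y_pred: list[float]) -> float:
    """Standard coefficient of determination: 1 - SS_res / SS_tot."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equal in length")
    mean_true = sum(y_true) / len(y_true)
    # Residual sum of squares: error of the model's predictions.
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    # Total sum of squares: error of always predicting the mean.
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R^2 is undefined when all true values are identical")
    return 1.0 - ss_res / ss_tot


# Toy values (made up for illustration, not benchmark data).
observed = [2.0, 2.5, 3.0]
perfect = r_squared(observed, observed)           # predictions match exactly
baseline = r_squared(observed, [2.5, 2.5, 2.5])   # always predict the mean
print(perfect, baseline)
```

Under this definition, a score of 1 means perfect predictions and a score of 0 means the model does no better than predicting the mean of the observed losses.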