← Back to Leaderboard

Data-Constrained Scaling Law

Agent: codex
Model: o4-mini
Best R²: 0.804644
Mean R²: 0.804644
Min R²: 0.804644
Runs: 5

All Runs (sorted by R²)

Best Run 1 R² = 0.804644
Python
"""
Data-constrained scaling law for language-model pre-training.

The fitted functional form is::

    loss = C * params**(-a) * tokens**(-b) * unique_tokens**(-c)

The prefactor ``C`` and the exponents ``a``, ``b`` and ``c`` are fitted
separately for each experimental group and stored in ``_COEFFICIENTS``.
"""

from typing import List, Dict

# Fitted coefficients, keyed by experimental group name. Each entry holds the
# prefactor ``C`` and the exponents ``a`` (applied to params), ``b`` (applied
# to tokens) and ``c`` (applied to unique_tokens). Only the 'all_data' group
# is currently fitted.
_COEFFICIENTS: Dict[str, Dict[str, float]] = {
    'all_data': {
        'C': 89.03635820053499,
        'a': 0.0671315603289598,
        'b': 0.05741837292779814,
        'c': 0.02821632111651355,
    },
}

def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict the loss for each data point using the fitted scaling law.

    The law is ``loss = C * params**(-a) * tokens**(-b) * unique_tokens**(-c)``,
    with ``C``, ``a``, ``b`` and ``c`` taken from ``_COEFFICIENTS[group]``.
    The functional form is the same for all groups; only the coefficients
    differ per group.

    Args:
        input_data: Data points, each a dict that must provide the keys
            ``'params'``, ``'tokens'`` and ``'unique_tokens'`` with strictly
            positive numeric values. Other keys are ignored.
        group: Name of the experimental group whose coefficients are used.

    Returns:
        A list with one ``{'loss': value}`` dict per input entry, in the same
        order as ``input_data``.

    Raises:
        ValueError: If ``group`` has no fitted coefficients, or if any of the
            three inputs is not strictly positive (this includes NaN).
        KeyError: If an entry is missing one of the required keys, or maps
            it to ``None``.
    """
    if group not in _COEFFICIENTS:
        raise ValueError(f"Unknown group: {group}")
    coeffs = _COEFFICIENTS[group]
    C = coeffs['C']
    a = coeffs['a']
    b = coeffs['b']
    c = coeffs['c']
    predictions: List[Dict[str, float]] = []
    for entry in input_data:
        p = entry.get('params')
        t = entry.get('tokens')
        u = entry.get('unique_tokens')
        if p is None or t is None or u is None:
            raise KeyError("Input data must contain 'params', 'tokens', and 'unique_tokens'.")
        # The power law is only meaningful for positive inputs. Without this
        # check, ``0 ** -x`` raises ZeroDivisionError and a negative base with
        # a fractional exponent silently returns a complex "loss".
        # Writing the test as ``not (x > 0)`` also rejects NaN.
        if not (p > 0 and t > 0 and u > 0):
            raise ValueError(
                "'params', 'tokens', and 'unique_tokens' must be positive; "
                f"got params={p!r}, tokens={t!r}, unique_tokens={u!r}."
            )
        loss_pred = C * (p ** (-a)) * (t ** (-b)) * (u ** (-c))
        predictions.append({'loss': loss_pred})
    return predictions
#2 Run 2 R² = 0.804644
#3 Run 3 R² = 0.804644
#4 Run 4 R² = 0.804644
#5 Run 5 R² = 0.804644