Parallel Scaling Law

Agent: codex
Model: o4-mini
Best R²: 0.999572
Mean R²: 0.999572
Min R²: 0.999572
Runs: 5
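The R² figures above are presumably the standard coefficient of determination, 1 - SS_res / SS_tot. The page does not say which data points they are computed on, or whether loss is compared in raw or log space. A minimal sketch of the standard definition, assuming raw loss values (the helper name r_squared and the sample numbers are hypothetical):

```python
from typing import Sequence


def r_squared(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    if len(y_true) != len(y_pred) or len(y_true) < 2:
        raise ValueError("need two equal-length sequences with at least 2 points")
    mean = sum(y_true) / len(y_true)
    # Residual sum of squares: squared error of the predictions
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    # Total sum of squares: squared spread of the observations around their mean
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("y_true is constant; R^2 is undefined")
    return 1.0 - ss_res / ss_tot


# Hypothetical values for illustration only (not leaderboard data).
observed = [2.0, 2.5, 3.0]
predicted = [2.1, 2.4, 3.0]
print(round(r_squared(observed, predicted), 4))  # 0.96
```

An R² of 0.999572 therefore means the residual error is about 0.04% of the total variance in the evaluated losses.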

All Runs (sorted by R²)

Best Run 1 R² = 0.999572
Python
from typing import List, Dict

_coefficients = {
    'pile': {'a': 7.684256e+00, 'b1': -0.064515, 'b2': -0.018878},
    'stack': {'a': 4.430089e+00, 'b1': -0.066371, 'b2': -0.022109},
}

def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predicts 'loss' from 'num_params' and 'parallel_size' using a fitted power law.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point with the keys 'num_params' and 'parallel_size' and
                    their numeric values.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law is the same for all groups,
               but the constant parameters/coefficients differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable 'loss'.
    """
    if group not in _coefficients:
        raise ValueError(f"Unknown group: {group}")
    params = _coefficients[group]
    predictions = []
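    for d in input_data:
        # Both inputs are required: a missing key raises KeyError here instead of
        # a confusing TypeError from None ** float (as d.get() would cause).
        num_params = float(d['num_params'])
        parallel_size = float(d['parallel_size'])
        # scaling law: loss = a * num_params^{b1} * parallel_size^{b2}
        loss_pred = params['a'] * (num_params ** params['b1']) * (parallel_size ** params['b2'])
        predictions.append({'loss': loss_pred})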
    return predictions
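The law above is a pure power law, so taking logs turns it into a linear model, log(loss) = log(a) + b1 * log(num_params) + b2 * log(parallel_size). One way such coefficients could be obtained is an ordinary least-squares fit in that log space. The page does not say how the agent actually fitted them, so treat this as an assumed approach. The helper fit_power_law and the synthetic grid are hypothetical:

```python
import numpy as np


def fit_power_law(num_params, parallel_size, loss):
    """Fit loss = a * N**b1 * P**b2 by least squares on log(loss)."""
    N = np.asarray(num_params, dtype=float)
    P = np.asarray(parallel_size, dtype=float)
    y = np.log(np.asarray(loss, dtype=float))
    # Design matrix columns: intercept, log N, log P
    X = np.column_stack([np.ones_like(N), np.log(N), np.log(P)])
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    log_a, b1, b2 = coef
    return {'a': float(np.exp(log_a)), 'b1': float(b1), 'b2': float(b2)}


# Synthetic check: generate noise-free points from the 'pile' coefficients
# on a hypothetical grid and confirm that the fit recovers them.
true = {'a': 7.684256, 'b1': -0.064515, 'b2': -0.018878}
Ns, Ps = np.meshgrid([1e8, 3e8, 1e9, 3e9], [1, 2, 4, 8])
Ns, Ps = Ns.ravel(), Ps.ravel()
losses = true['a'] * Ns ** true['b1'] * Ps ** true['b2']
fit = fit_power_law(Ns, Ps, losses)
print(fit)
```

Because the law is multiplicative, each doubling of parallel_size scales the predicted loss by 2**b2, about 0.987 for pile and 0.985 for stack. Each doubling of num_params scales it by 2**b1, about 0.956 and 0.955 respectively.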
#2 Run 2 R² = 0.999572
#3 Run 3 R² = 0.999572
#4 Run 4 R² = 0.999572
#5 Run 5 R² = 0.999572