
Parallel Scaling Law

Agent: claude-code
Model: claude-sonnet-4-5
Best R²: 0.999664
Mean R²: 0.997805
Min R²: 0.990473
Runs: 5
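The page reports R² without saying how it is evaluated, for example on which data points or whether it is computed per group. The sketch below assumes the standard coefficient of determination, 1 - SS_res / SS_tot, applied to observed versus predicted `loss` values. The function name `r_squared` and the observations used with it are hypothetical:

```python
from statistics import fmean


def r_squared(y_true: list[float], y_pred: list[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot (standard definition, assumed)."""
    if len(y_true) != len(y_pred) or len(y_true) < 2:
        raise ValueError("need two or more paired observations")
    mean_true = fmean(y_true)
    # Residual sum of squares: squared errors of the predictions.
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    # Total sum of squares: squared spread of the observations around their mean.
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R² is undefined when all observed values are equal")
    return 1.0 - ss_res / ss_tot


# Hypothetical observed losses, used only to show the two reference points.
observed = [2.10, 2.05, 1.98, 1.93]
print(r_squared(observed, observed))                         # perfect fit
print(r_squared(observed, [fmean(observed)] * len(observed)))  # mean-only baseline
```

A perfect fit gives 1.0 and predicting the mean of the observations gives 0.0, so values such as 0.9997 mean the fitted law leaves only a tiny fraction of the variance in loss unexplained.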

All Runs (sorted by R²)

Best Run 1 R² = 0.999664
Python
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts loss from model size and parallel size using the fitted scaling law
    loss = a * N^(-b) + c * P^(-d).

    Args:
        input_data: A list of dictionaries, one per data point. Each dictionary
                    must contain the input variables 'num_params' (N, the number
                    of model parameters) and 'parallel_size' (P).
        group: The name of the experimental group to predict for ('stack' or
               'pile'). The functional form of the law is the same for all
               groups; only the fitted coefficients a, b, c, d differ.

    Returns:
        A list of dictionaries, in the same order as input_data, each of the
        form {'loss': predicted_loss}.
    """
    # Scaling law parameters for each group
    # Model: loss = a * N^(-b) + c * P^(-d)
    # where N = num_params, P = parallel_size

    params = {
        'stack': {
            'a': 77.1529985547,
            'b': 0.2687228347,
            'c': 0.8221758142,
            'd': 0.0296937458
        },
        'pile': {
            'a': 111.9689899826,
            'b': 0.2576609995,
            'c': 1.4757763183,
            'd': 0.0254075589
        }
    }

    # Get parameters for the specified group
    if group not in params:
        raise ValueError(f"Unknown group: {group}. Available groups: {list(params.keys())}")

    a = params[group]['a']
    b = params[group]['b']
    c = params[group]['c']
    d = params[group]['d']

    # Compute predictions for each input data point
    predictions = []
    for data_point in input_data:
        N = data_point['num_params']
        P = data_point['parallel_size']

        # Apply the scaling law: loss = a * N^(-b) + c * P^(-d)
        loss = a * (N ** (-b)) + c * (P ** (-d))

        predictions.append({'loss': loss})

    return predictions
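To make the two terms of the fitted law concrete, the sketch below evaluates them separately for the 'stack' coefficients copied from the code above. The operating point (N = 1e9, P = 1) is hypothetical, chosen only for illustration, and the helper names `loss_terms` and `predicted_loss` are not part of the original solution. Doubling N rescales only the first term, by 2^(-b); doubling P rescales only the second, by 2^(-d):

```python
# Coefficients for the 'stack' group, copied from the fitted law above.
A, B, C, D = 77.1529985547, 0.2687228347, 0.8221758142, 0.0296937458


def loss_terms(n: float, p: float) -> tuple[float, float]:
    """Return (parameter term, parallel term) of loss = a*N^(-b) + c*P^(-d)."""
    return A * n ** (-B), C * p ** (-D)


def predicted_loss(n: float, p: float) -> float:
    param_term, parallel_term = loss_terms(n, p)
    return param_term + parallel_term


# Hypothetical operating point.
n0, p0 = 1e9, 1.0
base_param, base_parallel = loss_terms(n0, p0)

# Doubling N changes only the parameter term; doubling P only the parallel term.
doubled_n_param, _ = loss_terms(2 * n0, p0)
_, doubled_p_parallel = loss_terms(n0, 2 * p0)

print(f"loss at N={n0:.0e}, P={p0:g}: {predicted_loss(n0, p0):.4f}")
print(f"parameter-term factor when N doubles: {doubled_n_param / base_param:.4f}")
print(f"parallel-term factor when P doubles:  {doubled_p_parallel / base_parallel:.4f}")
```

With d of about 0.03 for this group, doubling P shrinks the parallel term by only about 2%, while doubling N shrinks the parameter term by about 17% (2^(-0.2687) is roughly 0.83).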
#2 Run 2 R² = 0.999658
#3 Run 3 R² = 0.999658
#4 Run 4 R² = 0.999572
#5 Run 5 R² = 0.990473
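The summary lines at the top appear to be simple aggregates over the five runs listed here. Assuming the mean is an arithmetic mean, they can be reproduced from the per-run values; the dictionary name `run_r2` is illustrative:

```python
from statistics import fmean

# R² of the five runs, as listed on this page.
run_r2 = {
    "Run 1": 0.999664,
    "Run 2": 0.999658,
    "Run 3": 0.999658,
    "Run 4": 0.999572,
    "Run 5": 0.990473,
}

best = max(run_r2.values())
mean = fmean(run_r2.values())
worst = min(run_r2.values())

print(f"Best R²: {best:.6f}")   # → Best R²: 0.999664
print(f"Mean R²: {mean:.6f}")   # → Mean R²: 0.997805
print(f"Min R²:  {worst:.6f}")  # → Min R²:  0.990473
```

The single weaker run (Run 5) pulls the mean noticeably below the other four, which all sit above 0.9995.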