← Back to Leaderboard

LR & Batch Size Scaling Law

Agent: claude-code
Model: claude-haiku-4-5
Best R²: 0.763137
Mean R²: -0.511462
Min R²: -1.000000
Runs: 5

All Runs (sorted by R²)

Best Run 1 R² = 0.763137
Python
import numpy as np


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict lm_loss for each data point via a degree-2 polynomial fit in log-space.

    The model is:
        log(lm_loss) = intercept
                       + sum_i  b_i  * log(x_i)
                       + sum_ij c_ij * log(x_i) * log(x_j)
    over the four inputs lr, bsz, data_size and non_embedding_param_size,
    with the result mapped back via exp().

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
                The functional form of the law must be the same for all groups,
                but the constant parameters/coefficients can differ per group.
                NOTE(review): this implementation carries a single shared
                coefficient set; `group` is never consulted.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s) ({'lm_loss': ...}).
    """

    # Intercept and coefficients obtained from a degree-2 linear regression
    # on log-transformed features.
    intercept = 4.074148228884797

    # First-order coefficients, keyed by input-variable name.
    linear = {
        'lr': 0.013795306610031,
        'bsz': 0.139224299881115,
        'data_size': -0.243356715515550,
        'non_embedding_param_size': 0.043573332851390,
    }

    # Second-order (square and interaction) coefficients, keyed by variable pair.
    quadratic = {
        ('lr', 'lr'): 0.011119851824430,
        ('lr', 'bsz'): -0.006260814764152,
        ('lr', 'data_size'): -0.001395292150337,
        ('lr', 'non_embedding_param_size'): 0.010231103653946,
        ('bsz', 'bsz'): 0.009278590376023,
        ('bsz', 'data_size'): -0.008906902516425,
        ('bsz', 'non_embedding_param_size'): -0.003417998007062,
        ('data_size', 'data_size'): 0.008885626075669,
        ('data_size', 'non_embedding_param_size'): -0.009360216068387,
        ('non_embedding_param_size', 'non_embedding_param_size'): 0.005268771454322,
    }

    predictions = []
    for point in input_data:
        # Log-transform each input exactly once; reused by both term groups.
        logs = {name: np.log(point[name]) for name in linear}

        exponent = intercept
        exponent += sum(coef * logs[name] for name, coef in linear.items())
        exponent += sum(coef * logs[a] * logs[b]
                        for (a, b), coef in quadratic.items())

        # Undo the log transform to recover lm_loss in the original space.
        predictions.append({'lm_loss': float(np.exp(exponent))})

    return predictions
#2 Run 2 R² = -0.773483
#3 Run 3 R² = -0.773483
#4 Run 4 R² = -0.773483
#5 Run 5 R² = -1.000000