← Back to Leaderboard

LR & Batch Size Scaling Law

Agent: openhands
Model: GPT-5
Best R²: -0.773483
Mean R²: -0.848988
Min R²: -1.000000
Runs: 3

All Runs (sorted by R²)

Best Run 1 R² = -0.773483
Python
import math
from typing import List, Dict

# Power-law exponents, one per input variable. A single set is shared by
# every group (fitted on the provided dataset); only the intercept is per-group.
EXPONENTS = dict(
    lr=0.008636919053849154,
    bsz=-0.0005162836622543873,
    data_size=-0.04700957690670233,
    non_embedding_param_size=-0.05174150134631459,
)

# Baseline group: 'all_data' is the only group observed in the fitting data,
# and any group without its own intercept falls back to it.
DEFAULT_GROUP = "all_data"

# Natural-log intercept for each group; the prediction is multiplied by
# exp(intercept).
LOG_INTERCEPTS = {
    DEFAULT_GROUP: 3.080501739652768,
}


def _safe_log(x: float) -> float:
    """Return the natural logarithm of *x* after validating it.

    Args:
        x: The value to take the logarithm of; must be a finite, strictly
            positive number.

    Returns:
        ``math.log(x)``.

    Raises:
        ValueError: If *x* is ``None``, NaN, infinite, or not strictly positive.
    """
    # ``x <= 0`` on its own is not enough: every comparison with NaN is False,
    # so NaN would slip through and math.log(nan) would quietly return nan.
    # Infinity is rejected too, since it is not a positive *real* number.
    if x is None or not math.isfinite(x) or x <= 0:
        raise ValueError("All inputs must be positive real numbers.")
    return math.log(x)


def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict ``lm_loss`` for every data point using a multiplicative power law.

    The prediction is computed in log space as

        log(lm_loss) = C_g + a*log(lr) + b*log(bsz)
                       + c*log(data_size) + d*log(non_embedding_param_size)

    and then exponentiated. The exponents (a, b, c, d) are read from EXPONENTS
    and are identical for all groups; only the intercept C_g comes from
    LOG_INTERCEPTS. A group with no entry there uses DEFAULT_GROUP's intercept.

    Args:
        input_data: Data points, each a dict holding the keys "lr", "bsz",
                    "data_size" and "non_embedding_param_size".
        group: Name of the experimental group to predict for. The functional
               form is shared; the constant coefficient can differ per group.

    Returns:
        One dict per input point, in the same order, of the form
        {"lm_loss": prediction}.

    Raises:
        KeyError: If a data point lacks one of the four input keys.
        ValueError: If an input value is rejected by _safe_log (e.g. not
                    positive).
    """
    feature_names = ("lr", "bsz", "data_size", "non_embedding_param_size")

    if group in LOG_INTERCEPTS:
        lookup_group = group
    else:
        lookup_group = DEFAULT_GROUP
    exponents = [EXPONENTS[name] for name in feature_names]
    log_intercept = LOG_INTERCEPTS[lookup_group]

    def predict(point: Dict[str, float]) -> float:
        # Convert every input up front, so a missing key or an unconvertible
        # value surfaces before any positivity check runs.
        values = [float(point[name]) for name in feature_names]
        # Accumulate strictly left to right, starting from the intercept.
        total = log_intercept
        for exponent, value in zip(exponents, values):
            total += exponent * _safe_log(value)
        return math.exp(total)

    return [{"lm_loss": predict(point)} for point in input_data]
#2 Run 2 R² = -0.773483
#3 Run 3 R² = -1.000000