LR & Batch Size Scaling Law

Agent: aider
Model: GPT-5
Best R²: 0.025235
Mean R²: -0.659042
Min R²: -1.000000
Runs: 5

All Runs (sorted by R²)

#1 Run 1 (best) R² = 0.025235
Python
from __future__ import annotations

import math
from typing import Dict, List

import numpy as np

# Try to import `datasets`; if it is unavailable, predictions fall back to defaults.
try:
    from datasets import load_from_disk  # type: ignore
except Exception:  # pragma: no cover
    load_from_disk = None  # type: ignore


# Global container for fitted coefficients, keyed by group name. Populated by
# _ensure_fitted(), which runs on first prediction and once at import time.
COEFFS: Dict[str, Dict[str, float]] = {}
GLOBAL_GROUP_KEY = "__ALL__"

# Numerical safety epsilon for logs
_EPS = 1e-12

# Conservative default coefficients (weak dependence on every variable), used
# whenever fitting is impossible: `datasets` missing, loading failed, or no
# valid rows found.
_DEFAULT_COEFFS: Dict[str, float] = {
    "beta0": 1.0,
    "a_lr": 0.0,
    "a2_lr2": 0.1,
    "b_bsz": -0.02,
    "c_data": -0.1,
    "d_param": -0.1,
}

# Description of the functional form
FORMULA_DESC = (
    "log(lm_loss) = beta0_g + a_g*log(lr) + a2_g*(log(lr))^2 + "
    "b_g*log(bsz) + c_g*log(data_size) + d_g*log(non_embedding_param_size)\n"
    "=> lm_loss = exp(beta0_g + a_g*log(lr) + a2_g*(log(lr))^2 + "
    "b_g*log(bsz) + c_g*log(data_size) + d_g*log(non_embedding_param_size))"
)
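
# Exponentiating both sides gives an equivalent power-law view with an
# lr-dependent exponent (algebra sketch only; not used by the code):
#   lm_loss = C * lr**(a_g + a2_g*log(lr)) * bsz**b_g
#             * data_size**c_g * non_embedding_param_size**d_g
# with C = exp(beta0_g). The quadratic term makes log(lm_loss) a parabola in
# log(lr), which captures the usual U-shaped learning-rate sweep.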


def _safe_log(x: float) -> float:
    return math.log(max(float(x), _EPS))


def _design_row(lr: float, bsz: float, data_size: float, non_emb_params: float) -> np.ndarray:
    """
    Build a single feature row for the regression:
    [1, log(lr), (log(lr))^2, log(bsz), log(data_size), log(non_embedding_param_size)]
    """
    z_lr = _safe_log(lr)
    return np.array(
        [
            1.0,
            z_lr,
            z_lr * z_lr,
            _safe_log(bsz),
            _safe_log(data_size),
            _safe_log(non_emb_params),
        ],
        dtype=np.float64,
    )
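
# Example with hypothetical values: lr=1e-3, bsz=256, data_size=1e9,
# non_embedding_param_size=1.25e8 yields the feature row
#   [1.0, log(1e-3), log(1e-3)**2, log(256), log(1e9), log(1.25e8)]
# which is roughly [1.0, -6.9078, 47.717, 5.5452, 20.723, 18.644].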


def _fit_group(X: np.ndarray, y: np.ndarray, lam: float = 1e-6) -> np.ndarray:
    """
    Ridge-regularized least squares:
        (X^T X + lam I) w = X^T y
    """
    XT = X.T
    A = XT @ X
    # Ridge on all parameters including bias (small lam)
    A[np.diag_indices_from(A)] += lam
    b = XT @ y
    w = np.linalg.solve(A, b)
    return w
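
# Quick sanity check (sketch): with the tiny default lam the ridge solution is
# essentially ordinary least squares. For the noiseless line y = 2 + 3*x:
#   X = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])
#   y = np.array([2.0, 5.0, 8.0])
#   _fit_group(X, y)  # -> approximately array([2., 3.])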


def _extract_dataset_rows(ds_item: dict) -> tuple[float, float, float, float, float, str | None]:
    """
    Extract lr, bsz, data_size, non_embedding_param_size, lm_loss, group (if present) from a dataset item.
    Returns tuple: (lr, bsz, data_size, non_emb_params, lm_loss, group)
    """
    # Missing keys raise KeyError here; the caller catches and skips such rows.
    lr = float(ds_item["lr"])
    bsz = float(ds_item["bsz"])
    data_size = float(ds_item["data_size"])
    non_emb = float(ds_item["non_embedding_param_size"])
    lm_loss = float(ds_item["lm_loss"])
    group = ds_item.get("group")
    if group is not None:
        group = str(group)
    return lr, bsz, data_size, non_emb, lm_loss, group
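
# Example item (hypothetical values; the keys are exactly those read above):
#   {"lr": 3e-4, "bsz": 512, "data_size": 2e9,
#    "non_embedding_param_size": 1.1e8, "lm_loss": 3.05, "group": "G1"}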


def _load_and_fit(path: str = "/app/data") -> Dict[str, Dict[str, float]]:
    """
    Load dataset from disk and fit per-group coefficients according to FORMULA_DESC.
    If datasets API is not available or loading fails, return a robust default.
    """
    coeffs: Dict[str, Dict[str, float]] = {}

    if load_from_disk is None:
        # Fallback: the `datasets` library is unavailable
        coeffs[GLOBAL_GROUP_KEY] = dict(_DEFAULT_COEFFS)
        return coeffs

    # Load dataset (can be a Dataset or a DatasetDict)
    try:
        ds = load_from_disk(path)
    except Exception:
        # Fallback: loading from disk failed
        coeffs[GLOBAL_GROUP_KEY] = dict(_DEFAULT_COEFFS)
        return coeffs

    # Collect all rows across splits if needed
    rows = []
    if hasattr(ds, "values"):  # DatasetDict
        for split in ds.values():
            rows.extend(list(split))
    else:  # Single Dataset
        rows = list(ds)

    # Partition by group (or GLOBAL group if group missing)
    groups: Dict[str, list[tuple[float, float, float, float, float]]] = {}
    for it in rows:
        try:
            lr, bsz, data_size, non_emb, lm_loss, group = _extract_dataset_rows(it)
        except Exception:
            continue

        # All variables must be strictly positive for the log transforms
        if lr <= 0 or bsz <= 0 or data_size <= 0 or non_emb <= 0 or lm_loss <= 0:
            continue

        gname = group if group is not None else GLOBAL_GROUP_KEY
        groups.setdefault(gname, []).append((lr, bsz, data_size, non_emb, lm_loss))

    # If no usable rows were found, bail to the defaults
    if not groups:
        coeffs[GLOBAL_GROUP_KEY] = dict(_DEFAULT_COEFFS)
        return coeffs

    # Also fit a global group across all data to use as fallback for unknown groups
    all_data = [rec for glist in groups.values() for rec in glist]
    groups_with_global = dict(groups)
    groups_with_global[GLOBAL_GROUP_KEY] = all_data

    # Fit per group
    for gname, glist in groups_with_global.items():
        if len(glist) < 6:  # Need at least as many points as parameters for a good fit
            continue
        X = np.vstack([_design_row(*rec[:4]) for rec in glist])  # n x 6
        y = np.array([_safe_log(rec[4]) for rec in glist], dtype=np.float64)  # log(lm_loss)

        try:
            w = _fit_group(X, y, lam=1e-6)
        except np.linalg.LinAlgError:
            # Very small increase in regularization if ill-conditioned
            w = _fit_group(X, y, lam=1e-3)

        coeffs[gname] = {
            "beta0": float(w[0]),
            "a_lr": float(w[1]),
            "a2_lr2": float(w[2]),
            "b_bsz": float(w[3]),
            "c_data": float(w[4]),
            "d_param": float(w[5]),
        }

    # In the rare case that fitting failed for some groups, ensure a global fallback exists
    if GLOBAL_GROUP_KEY not in coeffs:
        # Fit a quick global from whatever we have (if any), else use defaults
        if all_data:
            X = np.vstack([_design_row(*rec[:4]) for rec in all_data])
            y = np.array([_safe_log(rec[4]) for rec in all_data], dtype=np.float64)
            try:
                w = _fit_group(X, y, lam=1e-6)
            except np.linalg.LinAlgError:
                w = _fit_group(X, y, lam=1e-3)
            coeffs[GLOBAL_GROUP_KEY] = {
                "beta0": float(w[0]),
                "a_lr": float(w[1]),
                "a2_lr2": float(w[2]),
                "b_bsz": float(w[3]),
                "c_data": float(w[4]),
                "d_param": float(w[5]),
            }
        else:
            coeffs[GLOBAL_GROUP_KEY] = dict(_DEFAULT_COEFFS)

    return coeffs


def _write_explain_md(coeffs: Dict[str, Dict[str, float]], path: str = "/app/explain.md") -> None:
    """
    Generate a detailed explanation file including the functional form and fitted coefficients.
    """
    lines: List[str] = []
    lines.append("# Scaling Law for Final Language Modeling Loss\n")
    lines.append("This document describes the discovered scaling law relating the final language modeling loss (lm_loss) to training hyperparameters.\n")
    lines.append("## Functional Form\n")
    lines.append("We fit a log-linear model with a quadratic term in log(learning rate) to capture the typical U-shaped dependence on learning rate:\n")
    lines.append("log(lm_loss) = beta0_g + a_g*log(lr) + a2_g*(log(lr))^2 + b_g*log(bsz) + c_g*log(data_size) + d_g*log(non_embedding_param_size)\n")
    lines.append("\nEquivalently:\n")
    lines.append("lm_loss = exp(beta0_g + a_g*log(lr) + a2_g*(log(lr))^2 + b_g*log(bsz) + c_g*log(data_size) + d_g*log(non_embedding_param_size))\n")
    lines.append("\n- g denotes the experimental group. The functional form is identical across groups, while coefficients vary per group.\n")
    lines.append("\n## Methodology\n")
    lines.append("- Loaded the dataset from `/app/data` using `datasets.load_from_disk()`.\n")
    lines.append("- Filtered rows to ensure all variables are positive (required for logarithms).\n")
    lines.append("- Regressed log(lm_loss) on [1, log(lr), (log(lr))^2, log(bsz), log(data_size), log(non_embedding_param_size)] using ridge-regularized least squares (λ = 1e-6).\n")
    lines.append("- Fitted the model per group and also a global model across all data as a fallback.\n")
    lines.append("\n## Fitted Coefficients by Group\n")
    lines.append("The following coefficients were fitted programmatically at import time of `law.py`:\n")
    lines.append("\n")
    # Nicely format coefficients per group
    # Sort groups, showing GLOBAL first if present
    keys = list(coeffs.keys())
    if GLOBAL_GROUP_KEY in keys:
        keys.remove(GLOBAL_GROUP_KEY)
        keys = [GLOBAL_GROUP_KEY] + sorted(keys)
    else:
        keys = sorted(keys)
    for g in keys:
        c = coeffs[g]
        lines.append(f"### Group: {g}\n")
        lines.append(f"- beta0: {c['beta0']:.8f}\n")
        lines.append(f"- a (log lr): {c['a_lr']:.8f}\n")
        lines.append(f"- a2 (log lr)^2: {c['a2_lr2']:.8f}\n")
        lines.append(f"- b (log bsz): {c['b_bsz']:.8f}\n")
        lines.append(f"- c (log data_size): {c['c_data']:.8f}\n")
        lines.append(f"- d (log non_embedding_param_size): {c['d_param']:.8f}\n")
        lines.append("\n")

    try:
        with open(path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
    except Exception:
        # If writing fails, silently ignore (not critical for predictions)
        pass


def _ensure_fitted() -> None:
    """
    Ensure that COEFFS is populated. If empty, attempt to load and fit.
    Also writes/updates /app/explain.md with the fitted coefficients.
    """
    global COEFFS
    if COEFFS:
        return
    coeffs = _load_and_fit("/app/data")
    COEFFS = coeffs
    # Best-effort write explain.md so the fitted numbers are visible
    _write_explain_md(COEFFS, "/app/explain.md")


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values. Required keys per item:
                    - 'lr'
                    - 'bsz'
                    - 'data_size'
                    - 'non_embedding_param_size'
        group: The name of the experimental group for which to make predictions.
               The functional form of the law is the same for all groups, but the
               coefficients can differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s): {'lm_loss': float}.
    """
    _ensure_fitted()

    # Select coefficients for the requested group, fallback to global, then to any available group
    c = COEFFS.get(group)
    if c is None:
        c = COEFFS.get(GLOBAL_GROUP_KEY)
    if c is None and COEFFS:
        # Fallback to any one group deterministically
        any_group = sorted(COEFFS.keys())[0]
        c = COEFFS[any_group]
    if c is None:
        # Last-resort defaults (should not happen if fitting succeeded)
        c = _DEFAULT_COEFFS

    beta0 = c["beta0"]
    a_lr = c["a_lr"]
    a2_lr2 = c["a2_lr2"]
    b_bsz = c["b_bsz"]
    c_data = c["c_data"]
    d_param = c["d_param"]

    outputs: list[dict[str, float]] = []
    for item in input_data:
        try:
            lr = float(item["lr"])
            bsz = float(item["bsz"])
            data_size = float(item["data_size"])
            non_emb_params = float(item["non_embedding_param_size"])
        except Exception as e:
            raise ValueError("Each input item must contain 'lr', 'bsz', 'data_size', and 'non_embedding_param_size' as numeric values.") from e

        z_lr = _safe_log(lr)
        z_bsz = _safe_log(bsz)
        z_data = _safe_log(data_size)
        z_param = _safe_log(non_emb_params)

        ln_loss = (
            beta0
            + a_lr * z_lr
            + a2_lr2 * (z_lr * z_lr)
            + b_bsz * z_bsz
            + c_data * z_data
            + d_param * z_param
        )
        loss = float(math.exp(ln_loss))
        outputs.append({"lm_loss": loss})

    return outputs


# Ensure coefficients are fitted when the module is imported, so explain.md is generated.
try:
    _ensure_fitted()
except Exception:
    # Don't fail import in case of transient dataset issues; prediction will use defaults if necessary.
    pass
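
# Usage sketch with hypothetical values: the group name "G1" is illustrative;
# unknown names fall back to the globally fitted coefficients.
if __name__ == "__main__":
    demo_inputs = [
        {"lr": 3e-4, "bsz": 256, "data_size": 1e9, "non_embedding_param_size": 1.25e8},
        {"lr": 1e-3, "bsz": 512, "data_size": 4e9, "non_embedding_param_size": 3.5e8},
    ]
    for inp, out in zip(demo_inputs, law(demo_inputs, group="G1")):
        print(inp, "->", out)
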
#2 Run 2 R² = -0.773483
#3 Run 3 R² = -0.773483
#4 Run 4 R² = -0.773483
#5 Run 5 R² = -1.000000