SFT Scaling Law

Agent: aider
Model: GPT-5
Best R²: 0.892971
Mean R²: 0.131226
Min R²: -1.000000
Runs: 5

All Runs (sorted by R²)

#1 Run 1 R² = 0.892971 (Best)
Python
from __future__ import annotations

import math
from typing import Dict, Iterable, List, Tuple

# Try to import datasets; fall back gracefully if unavailable.
try:
    from datasets import load_from_disk, Dataset, DatasetDict  # type: ignore
except Exception:  # pragma: no cover
    load_from_disk = None  # type: ignore
    Dataset = object  # type: ignore
    DatasetDict = dict  # type: ignore


# Parameters: group -> (L_inf, A, alpha)
_PARAMS_BY_GROUP: Dict[str, Tuple[float, float, float]] = {}
_GLOBAL_PARAMS: Tuple[float, float, float] = (0.0, 1.0, 0.5)  # sensible default fallback


def _linear_fit(x: List[float], y: List[float]) -> Tuple[float, float]:
    """
    Simple unweighted least squares fit for y = m*x + b
    Returns (m, b)
    """
    n = len(x)
    if n == 0:
        return (0.0, 0.0)
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    sxx = sum((xi - mean_x) ** 2 for xi in x)
    if sxx <= 0.0:
        return (0.0, mean_y)
    sxy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    m = sxy / sxx
    b = mean_y - m * mean_x
    return (m, b)
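
# Worked example (hand-checkable, not part of the original run): for the
# exactly-linear points (0, 1), (1, 3), (2, 5) the closed form gives
# slope m = 2 and intercept b = 1, so
#     _linear_fit([0.0, 1.0, 2.0], [1.0, 3.0, 5.0]) == (2.0, 1.0)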


def _fit_power_law_with_asymptote(xs: List[float], ys: List[float]) -> Tuple[float, float, float]:
    """
    Fit the three-parameter scaling law:
        loss(N) = L_inf + A * N^(-alpha)
    via a coarse grid search over L_inf and linear regression in log space for A and alpha.

    Returns (L_inf, A, alpha)
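
    Why this works: for a fixed L_inf the model log-linearizes exactly,
        log(loss(N) - L_inf) = log(A) - alpha * log(N),
    so A and alpha fall out of a single least-squares fit per candidate
    L_inf, leaving only L_inf for the outer one-dimensional grid search.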
    """
    # Sanitize and filter data
    data = [(float(x), float(y)) for x, y in zip(xs, ys) if x is not None and y is not None]
    data = [(x, y) for x, y in data if x > 0 and math.isfinite(x) and math.isfinite(y)]
    if not data:
        return (0.0, 1.0, 0.5)

    xs = [x for x, _ in data]
    ys = [y for _, y in data]
    y_min = min(ys)
    y_max = max(ys)

    # If no variation, fall back to a simpler 2-parameter power law with L_inf=0
    if not math.isfinite(y_min) or not math.isfinite(y_max) or abs(y_max - y_min) < 1e-12:
        # Fit y = A * N^(-alpha) in log space
        t = [math.log(x) for x in xs]
        z = [math.log(max(y, 1e-12)) for y in ys]
        m, b = _linear_fit(t, z)
        alpha = -m
        A = math.exp(b)
        if not (math.isfinite(alpha) and alpha > 0 and math.isfinite(A) and A > 0):
            alpha, A = 0.5, max(y_min, 1e-6)
        return (0.0, A, alpha)

    # Define a grid for L_inf below the minimum observed loss
    span = max(y_max - y_min, 1e-6)
    upper = y_min - 1e-9  # must be strictly below min(y)
    lower = max(0.0, y_min - 0.25 * span)
    if lower >= upper:
        lower = max(0.0, 0.5 * upper)

    candidates: List[float] = []
    steps = 50
    for i in range(steps):
        frac = (i + 0.5) / steps
        L = lower + frac * (upper - lower)
        if L < upper:
            candidates.append(L)
    # Also try L_inf = 0 explicitly
    if 0.0 < upper:
        candidates.append(0.0)

    best_err = float("inf")
    best_params = (0.0, 1.0, 0.5)

    t_vals = [math.log(x) for x in xs]

    for L in candidates:
        # Compute transformed targets z = log(y - L)
        # Safe because L < min(y) by construction
        z_vals = [math.log(y - L) for y in ys]
        m, b = _linear_fit(t_vals, z_vals)
        alpha = -m
        A = math.exp(b)

        # Discard invalid fits
        if not (math.isfinite(alpha) and alpha > 0 and math.isfinite(A) and A > 0 and math.isfinite(L) and L >= 0):
            continue

        # Evaluate SSE in natural space
        err = 0.0
        for x, y in zip(xs, ys):
            y_hat = L + A * (x ** (-alpha))
            if not math.isfinite(y_hat):
                err = float("inf")
                break
            diff = y_hat - y
            err += diff * diff

        if err < best_err:
            best_err = err
            best_params = (L, A, alpha)

    return best_params
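
# Sanity sketch (synthetic points, an illustrative check rather than part of
# the original submission): data drawn exactly from loss(N) = 0.5 + 2*N^(-0.3)
# should be recovered up to the coarseness of the L_inf grid:
# >>> ns = [1e2, 1e3, 1e4, 1e5]
# >>> ys = [0.5 + 2.0 * n ** (-0.3) for n in ns]
# >>> _fit_power_law_with_asymptote(ns, ys)  # approximately (0.5, 2.0, 0.3)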


def _load_all_records(path: str = "/app/data") -> List[dict]:
    """
    Load all rows from a HuggingFace dataset or dataset dict located at path.
    Returns a list of Python dict records.
    """
    records: List[dict] = []
    if load_from_disk is None:
        return records
    try:
        ds = load_from_disk(path)  # type: ignore
    except Exception:
        return records

    def _iter_rows(d) -> Iterable[dict]:
        try:
            return iter(d)  # HuggingFace Datasets are iterable
        except Exception:
            return iter([])

    # DatasetDict: combine splits
    try:
        if isinstance(ds, DatasetDict):  # type: ignore
            for split_name in ds.keys():  # type: ignore
                split_ds = ds[split_name]  # type: ignore
                for row in _iter_rows(split_ds):
                    records.append(row)
        elif isinstance(ds, Dataset):  # type: ignore
            for row in _iter_rows(ds):
                records.append(row)
        else:
            # Fallback: try dict-like
            if hasattr(ds, "values"):
                for part in ds.values():  # type: ignore
                    for row in _iter_rows(part):
                        records.append(row)
    except Exception:
        # As a last resort, attempt to iterate ds directly
        try:
            for row in _iter_rows(ds):
                records.append(row)
        except Exception:
            pass

    return records
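
# Assumed record schema (inferred from the column accesses below, not from
# benchmark documentation): each row carries a numeric "sft_data_size" (N),
# a numeric "sft_loss" (the regression target), and optionally a group
# column named "group", "sft_group", "family", or "model_group".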


def _fit_all_groups() -> None:
    """
    Fit parameters per group and globally, storing them in module-level caches.
    Also writes/updates /app/explain.md with the discovered parameters if possible.
    """
    global _PARAMS_BY_GROUP, _GLOBAL_PARAMS

    records = _load_all_records("/app/data")
    # Extract columns robustly
    def get_val(rec: dict, key: str, default=None):
        return rec.get(key, default)

    # Determine group field
    group_field_candidates = ["group", "sft_group", "family", "model_group"]
    group_field = None
    if records:
        sample = records[0]
        for k in group_field_candidates:
            if k in sample:
                group_field = k
                break
    if group_field is None:
        group_field = "group"  # default name; treat all as one group

    # Partition data by group
    by_group: Dict[str, Tuple[List[float], List[float]]] = {}
    xs_all: List[float] = []
    ys_all: List[float] = []

    for rec in records:
        x = get_val(rec, "sft_data_size")
        y = get_val(rec, "sft_loss")
        g = get_val(rec, group_field, "default")
        try:
            xf = float(x)
            yf = float(y)
        except Exception:
            continue
        if not (math.isfinite(xf) and math.isfinite(yf) and xf > 0):
            continue

        xs_all.append(xf)
        ys_all.append(yf)
        if g not in by_group:
            by_group[g] = ([], [])
        by_group[g][0].append(xf)
        by_group[g][1].append(yf)

    # Global fit (pooled)
    if xs_all and ys_all:
        _GLOBAL_PARAMS = _fit_power_law_with_asymptote(xs_all, ys_all)
    # Otherwise keep the default fallback parameters already in _GLOBAL_PARAMS.

    # Per-group fit
    params_by_group: Dict[str, Tuple[float, float, float]] = {}
    if by_group:
        for g, (xs, ys) in by_group.items():
            params_by_group[g] = _fit_power_law_with_asymptote(xs, ys)
    else:
        # No groups available; use a single default group
        params_by_group["default"] = _GLOBAL_PARAMS

    _PARAMS_BY_GROUP = params_by_group

    # Attempt to write an explain file with discovered parameters
    try:
        lines: List[str] = []
        lines.append("# SFT Scaling Law\n")
        lines.append("We model the supervised fine-tuning loss as a function of the number of fine-tuning examples N using a three-parameter power law with an asymptote:\n")
        lines.append("L(N) = L_inf + A * N^(-alpha)\n")
        lines.append("\nMethodology:\n")
        lines.append("- For each group, we sweep a grid of candidate L_inf values below the minimum observed loss.\n")
        lines.append("- For each candidate L_inf, we fit log(L - L_inf) = log A - alpha * log N via linear least squares to estimate A and alpha.\n")
        lines.append("- We pick the parameters (L_inf, A, alpha) that minimize squared error in the original loss space.\n")
        lines.append("\nFitted parameters by group:\n")
        for g, (L_inf, A, alpha) in sorted(_PARAMS_BY_GROUP.items(), key=lambda kv: str(kv[0])):
            lines.append(f"- {g}: L_inf={L_inf:.6g}, A={A:.6g}, alpha={alpha:.6g}\n")
        lines.append("\nGlobal pooled fit (used as fallback for unknown groups):\n")
        L_inf, A, alpha = _GLOBAL_PARAMS
        lines.append(f"- GLOBAL: L_inf={L_inf:.6g}, A={A:.6g}, alpha={alpha:.6g}\n")

        with open("/app/explain.md", "w", encoding="utf-8") as f:
            # Entries already end in "\n"; join on "" to avoid doubled blank lines.
            f.write("".join(lines))
    except Exception:
        # Non-fatal if we cannot write the explanation.
        pass


# Fit once at import time (best-effort; safe no-op if dataset unavailable)
_fit_all_groups()


def _params_for_group(group: str) -> Tuple[float, float, float]:
    if group in _PARAMS_BY_GROUP:
        return _PARAMS_BY_GROUP[group]
    # Try case-insensitive key match
    group_lower = group.lower()
    for g in _PARAMS_BY_GROUP.keys():
        if str(g).lower() == group_lower:
            return _PARAMS_BY_GROUP[g]
    return _GLOBAL_PARAMS


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
                The functional form of the law must be the same for all groups,
                but the constant parameters/coefficients can differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    # Ensure parameters are available (import-time fit may have been skipped in some environments)
    if not _PARAMS_BY_GROUP:
        _fit_all_groups()

    L_inf, A, alpha = _params_for_group(group)

    outputs: List[Dict[str, float]] = []
    for row in input_data:
        try:
            n = float(row.get("sft_data_size", 0.0))
        except (TypeError, ValueError):
            n = 0.0
        if not (math.isfinite(n) and n > 0):
            # Graceful handling for invalid N: predict at N=1, the smallest
            # meaningful dataset size.
            n = 1.0
        y_hat = L_inf + A * (n ** (-alpha))
        # Safety: ensure finite
        if not math.isfinite(y_hat):
            y_hat = float("nan")
        outputs.append({"sft_loss": float(y_hat)})
    return outputs
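
# Minimal usage sketch ("my-group" is a hypothetical group name; unknown
# groups fall back to the global pooled fit):
# >>> law([{"sft_data_size": 1e3}, {"sft_data_size": 1e5}], "my-group")
# [{'sft_loss': ...}, {'sft_loss': ...}]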
#2 Run 2 R² = 0.887761
#3 Run 3 R² = 0.875399
#4 Run 4 R² = -1.000000
#5 Run 5 R² = -1.000000