← Back to Leaderboard

SFT Scaling Law

Agent: codex
Model: o4-mini
Best R²: 0.948364
Mean R²: 0.882920
Min R²: 0.787239
Runs: 5

All Runs (sorted by R²)

Best Run 1 R² = 0.948364
Python
import math

# Per-group fitted parameters for the log-linear scaling law
#   sft_loss = a + b * ln(sft_data_size)
# keyed by the group label "('<model_name>', '<dataset_name>')".
# a = intercept, b = slope; all slopes are negative (loss falls with data size).
PARAMS: dict[str, tuple[float, float]] = {
    "('MBZUAI/LaMini-GPT-124M', 'flan')": (5.988266, -0.285828),
    "('MBZUAI/LaMini-GPT-124M', 'gigaword')": (4.737756, -0.263382),
    "('MBZUAI/LaMini-GPT-124M', 'wikiword')": (3.391898, -0.138080),
    "('MBZUAI/LaMini-GPT-774M', 'flan')": (4.783598, -0.212231),
    "('MBZUAI/LaMini-GPT-774M', 'gigaword')": (4.204573, -0.238074),
    "('MBZUAI/LaMini-GPT-774M', 'wikiword')": (2.652870, -0.096028),
    "('cerebras/Cerebras-GPT-1.3B', 'flan')": (3.460631, -0.121686),
    "('cerebras/Cerebras-GPT-1.3B', 'gigaword')": (3.820356, -0.190610),
    "('cerebras/Cerebras-GPT-1.3B', 'wikiword')": (2.950335, -0.102045),
    "('cerebras/Cerebras-GPT-256M', 'flan')": (4.210632, -0.144671),
    "('cerebras/Cerebras-GPT-256M', 'gigaword')": (4.218448, -0.225641),
    "('cerebras/Cerebras-GPT-256M', 'wikiword')": (3.928450, -0.166944),
    "('facebook/bart-base', 'flan')": (5.722812, -0.278850),
    "('facebook/bart-base', 'gigaword')": (5.410756, -0.334725),
    "('facebook/bart-base', 'wikiword')": (4.492634, -0.241355),
    "('facebook/bart-large', 'flan')": (4.453519, -0.198143),
    "('facebook/bart-large', 'gigaword')": (5.242536, -0.333921),
    "('facebook/bart-large', 'wikiword')": (2.719349, -0.108131),
    "('facebook/opt-1.3b', 'flan')": (3.003527, -0.102269),
    "('facebook/opt-1.3b', 'gigaword')": (4.067384, -0.225039),
    "('facebook/opt-1.3b', 'wikiword')": (2.227610, -0.067383),
    "('facebook/opt-350m', 'flan')": (4.096427, -0.156777),
    "('facebook/opt-350m', 'gigaword')": (4.789217, -0.284161),
    "('facebook/opt-350m', 'wikiword')": (2.969051, -0.109319),
    "('facebook/opt-6.7b', 'flan')": (2.206064, -0.035974),
    "('facebook/opt-6.7b', 'gigaword')": (2.162636, -0.027858),
    "('facebook/opt-6.7b', 'wikiword')": (1.962319, -0.051533),
    "('google/mt5-base', 'flan')": (4.098808, -0.165621),
    "('google/mt5-base', 'gigaword')": (3.335554, -0.082349),
    "('google/mt5-base', 'wikiword')": (3.861648, -0.194937),
    "('google/mt5-large', 'flan')": (3.228232, -0.115502),
    "('google/mt5-large', 'gigaword')": (3.414668, -0.095339),
    "('google/mt5-large', 'wikiword')": (3.353100, -0.152914),
    "('gpt2', 'flan')": (6.179866, -0.296177),
    "('gpt2', 'gigaword')": (4.817200, -0.280875),
    "('gpt2', 'wikiword')": (3.493378, -0.147500),
    "('t5-base', 'flan')": (3.285517, -0.117172),
    "('t5-base', 'gigaword')": (1.480251, -0.069629),
    "('t5-base', 'wikiword')": (2.132301, -0.067850),
    "('t5-small', 'flan')": (3.752075, -0.134443),
    "('t5-small', 'gigaword')": (1.608933, -0.072299),
    "('t5-small', 'wikiword')": (2.594400, -0.090536),
}

def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Evaluate the discovered SFT scaling law for one experimental group.

    The functional form is shared across groups — a log-linear relation
    ``sft_loss = a + b * ln(sft_data_size)`` — while the (a, b) pair is
    looked up per group in the module-level PARAMS table.

    Args:
        input_data: One dict per data point; each must provide the input
                    variable ``sft_data_size``.
        group: Experimental group label selecting which fitted parameters
               to apply.

    Returns:
        A list parallel to ``input_data``; each element is a dict holding
        the predicted output variable ``sft_loss``.

    Raises:
        ValueError: If ``group`` has no entry in PARAMS.
    """
    try:
        intercept, slope = PARAMS[group]
    except KeyError:
        raise ValueError(f"Unknown group: {group}") from None
    # Hoist the attribute lookup out of the comprehension's hot path.
    ln = math.log
    return [
        {'sft_loss': intercept + slope * ln(point['sft_data_size'])}
        for point in input_data
    ]
#2 Run 2 R² = 0.893000
#3 Run 3 R² = 0.893000
#4 Run 4 R² = 0.892997
#5 Run 5 R² = 0.787239