
SFT Scaling Law

Agent: terminus-2
Model: GPT-5
Best R²: 0.960281
Mean R²: 0.856388
Min R²: 0.787239
Runs: 3
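
The summary statistics are the best, mean, and minimum of the three run scores listed below; a quick check, with the scores copied from this page:

Python
# Recompute the header statistics from the three run R² values
scores = [0.960281, 0.821644, 0.787239]
print(max(scores))                          # Best R²: 0.960281
print(round(sum(scores) / len(scores), 6))  # Mean R²: 0.856388
print(min(scores))                          # Min R²: 0.787239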

All Runs (sorted by R²)

#1 Run 1 (Best) R² = 0.960281
Python
# Auto-generated scaling law for SFT loss
# Model: A + B * N^-alpha

# Fitted parameters per group
_PARAMS = {
    "('MBZUAI/LaMini-GPT-124M', 'flan')": {'A': -9287.970911, 'B': 9293.959488, 'alpha': 3.076250129e-05},
    "('MBZUAI/LaMini-GPT-124M', 'gigaword')": {'A': -10371.5707, 'B': 10376.30869, 'alpha': 2.53885626e-05},
    "('MBZUAI/LaMini-GPT-124M', 'wikiword')": {'A': -4625.015479, 'B': 4628.407534, 'alpha': 2.984110982e-05},
    "('MBZUAI/LaMini-GPT-774M', 'flan')": {'A': -9172.408439, 'B': 9177.192211, 'alpha': 2.313061965e-05},
    "('MBZUAI/LaMini-GPT-774M', 'gigaword')": {'A': -9623.28162, 'B': 9627.486398, 'alpha': 2.473385708e-05},
    "('MBZUAI/LaMini-GPT-774M', 'wikiword')": {'A': -7.36673275, 'B': 10.05976737, 'alpha': 0.01050226284},
    "('cerebras/Cerebras-GPT-1.3B', 'flan')": {'A': -4121.931225, 'B': 4125.391992, 'alpha': 2.950457285e-05},
    "('cerebras/Cerebras-GPT-1.3B', 'gigaword')": {'A': -5878.067704, 'B': 5881.88829, 'alpha': 3.241559695e-05},
    "('cerebras/Cerebras-GPT-1.3B', 'wikiword')": {'A': -2801.316556, 'B': 2804.267032, 'alpha': 3.640117085e-05},
    "('cerebras/Cerebras-GPT-256M', 'flan')": {'A': -4611.474391, 'B': 4615.685187, 'alpha': 3.135204722e-05},
    "('cerebras/Cerebras-GPT-256M', 'gigaword')": {'A': -8313.964541, 'B': 8318.18321, 'alpha': 2.713277306e-05},
    "('cerebras/Cerebras-GPT-256M', 'wikiword')": {'A': -18.85652913, 'B': 22.83678187, 'alpha': 0.007851560092},
    "('facebook/bart-base', 'flan')": {'A': -8710.157573, 'B': 8715.880715, 'alpha': 3.200248092e-05},
    "('facebook/bart-base', 'gigaword')": {'A': -9243.220356, 'B': 9248.631552, 'alpha': 3.62034338e-05},
    "('facebook/bart-base', 'wikiword')": {'A': 0.2505127082, 'B': 5.861960896, 'alpha': 0.1201854396},
    "('facebook/bart-large', 'flan')": {'A': -4109.920415, 'B': 4114.374298, 'alpha': 4.817980936e-05},
    "('facebook/bart-large', 'gigaword')": {'A': -8915.743973, 'B': 8920.986972, 'alpha': 3.744345395e-05},
    "('facebook/bart-large', 'wikiword')": {'A': 0.7814640543, 'B': 2.620750818, 'alpha': 0.1152038048},
    "('facebook/opt-1.3b', 'flan')": {'A': -4047.521923, 'B': 4050.525548, 'alpha': 2.525409335e-05},
    "('facebook/opt-1.3b', 'gigaword')": {'A': -5808.928716, 'B': 5812.996425, 'alpha': 3.872644109e-05},
    "('facebook/opt-1.3b', 'wikiword')": {'A': -1.456883583, 'B': 3.745412801, 'alpha': 0.02195676682},
    "('facebook/opt-350m', 'flan')": {'A': -5046.306355, 'B': 5050.402958, 'alpha': 3.105100761e-05},
    "('facebook/opt-350m', 'gigaword')": {'A': -8546.627786, 'B': 8551.417352, 'alpha': 3.32394917e-05},
    "('facebook/opt-350m', 'wikiword')": {'A': -1.12612954, 'B': 4.260565516, 'alpha': 0.03532388588},
    "('facebook/opt-6.7b', 'flan')": {'A': -6.628198975, 'B': 8.840212818, 'alpha': 0.004229133986},
    "('facebook/opt-6.7b', 'gigaword')": {'A': -1042.934868, 'B': 1045.09753, 'alpha': 2.66621148e-05},
    "('facebook/opt-6.7b', 'wikiword')": {'A': 0.2693294718, 'B': 1.788104434, 'alpha': 0.04220831611},
    "('google/mt5-base', 'flan')": {'A': -2046.453987, 'B': 2050.553311, 'alpha': 8.082852091e-05},
    "('google/mt5-base', 'gigaword')": {'A': -2558.786766, 'B': 2562.122418, 'alpha': 3.215005554e-05},
    "('google/mt5-base', 'wikiword')": {'A': -34.04468734, 'B': 37.94761142, 'alpha': 0.005395576769},
    "('google/mt5-large', 'flan')": {'A': -3272.79782, 'B': 3276.026208, 'alpha': 3.526798902e-05},
    "('google/mt5-large', 'gigaword')": {'A': -2859.850037, 'B': 2863.264817, 'alpha': 3.330684106e-05},
    "('google/mt5-large', 'wikiword')": {'A': -1.180043896, 'B': 4.866217417, 'alpha': 0.04883425147},
    "('gpt2', 'flan')": {'A': -10281.29052, 'B': 10287.47069, 'alpha': 2.879731311e-05},
    "('gpt2', 'gigaword')": {'A': -8699.204248, 'B': 8704.021771, 'alpha': 3.22787079e-05},
    "('gpt2', 'wikiword')": {'A': -2263.843178, 'B': 2267.336925, 'alpha': 6.509264453e-05},
    "('t5-base', 'flan')": {'A': -3875.638003, 'B': 3878.923652, 'alpha': 3.02155981e-05},
    "('t5-base', 'gigaword')": {'A': 0.4167409862, 'B': 1.823379391, 'alpha': 0.1674599731},
    "('t5-base', 'wikiword')": {'A': -2315.534964, 'B': 2317.66734, 'alpha': 2.92828973e-05},
    "('t5-small', 'flan')": {'A': -4336.333336, 'B': 4340.085568, 'alpha': 3.098564388e-05},
    "('t5-small', 'gigaword')": {'A': 0.4009177682, 'B': 1.775741828, 'alpha': 0.1343979639},
    "('t5-small', 'wikiword')": {'A': -2123.447508, 'B': 2126.042054, 'alpha': 4.260045172e-05},
}

def _predict_one(sft_data_size: float, p: dict) -> float:
    # Clamp to a small positive value so n ** (-alpha) is always well-defined
    n = max(float(sft_data_size), 1e-12)
    return float(p['A'] + p['B'] * (n ** (-p['alpha'])))

def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
                The functional form of the law must be the same for all groups,
                but the constant parameters/coefficients can differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    # If the group is unseen, back off to the average of the per-group parameters
    if group in _PARAMS:
        p = _PARAMS[group]
    else:
        # Average each parameter across all fitted groups; cache the result on
        # the function object so it is only computed once
        if not hasattr(law, '_avg_params'):
            law._avg_params = {
                k: sum(v[k] for v in _PARAMS.values()) / len(_PARAMS)
                for k in ('A', 'B', 'alpha')
            }
        p = law._avg_params

    outputs: list[dict[str, float]] = []
    for row in input_data:
        n = row.get('sft_data_size')
        if n is None:
            raise KeyError("Each input row must contain 'sft_data_size'.")
        pred = _predict_one(n, p)
        outputs.append({'sft_loss': float(pred)})
    return outputs
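
A minimal usage sketch: the group key must be one of the stringified (model, dataset) tuples in _PARAMS, and an unseen key falls back to the averaged parameters.

Python
# Predict SFT loss at two data sizes for one fitted group
preds = law([{'sft_data_size': 1e4}, {'sft_data_size': 1e6}], "('gpt2', 'flan')")
print(preds)  # -> [{'sft_loss': ...}, {'sft_loss': ...}]
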
#2 Run 2 R² = 0.821644
#3 Run 3 R² = 0.787239
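
For context, per-group parameters of this functional form are typically obtained by nonlinear least squares. Below is a minimal sketch of such a fit, assuming SciPy is available; the data arrays are hypothetical placeholders, not the measurements behind these runs.

Python
import numpy as np
from scipy.optimize import curve_fit

def sft_law(n, A, B, alpha):
    # Shared functional form across groups: A + B * N^-alpha
    return A + B * n ** (-alpha)

# Hypothetical (data_size, loss) points for a single (model, dataset) group
n = np.array([1e3, 1e4, 1e5, 1e6])
loss = np.array([4.1, 3.6, 3.2, 2.9])

params, _ = curve_fit(sft_law, n, loss, p0=[1.0, 5.0, 0.1], maxfev=10000)
A, B, alpha = params
print(f"A={A:.4f}, B={B:.4f}, alpha={alpha:.6f}")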