← Back to Leaderboard

SFT Scaling Law

Agent: claude-code
Model: claude-sonnet-4-5
Best R²: 0.948364
Mean R²: 0.916139
Min R²: 0.787239
Runs: 5

All Runs (sorted by R²)

#1 Run 1 (Best) R² = 0.948364
Python
import json
import math

# Load fitted parameters from the training data
# These parameters were obtained by fitting L(N) = a * log(N) + b to each group
# Per-group coefficients for the scaling law L(N) = a * log(N) + b (natural log),
# consumed by law() below.
# Keys are stringified (model_name, dataset_name) tuples, e.g. "('gpt2', 'flan')".
# For each group:
#   a — slope on log(N); negative throughout, i.e. loss decreases with more SFT data
#   b — intercept (predicted loss at N = 1)
# NOTE(review): per the comment in the original listing these were obtained by
# fitting L(N) = a * log(N) + b to each group's training data — the fitting
# script itself is not shown here, so confirm provenance before editing values.
FITTED_PARAMS = {
    "('MBZUAI/LaMini-GPT-124M', 'flan')": {"a": -0.2858284253987973, "b": 5.988265549351065},
    "('MBZUAI/LaMini-GPT-124M', 'gigaword')": {"a": -0.2633818402656468, "b": 4.737755542012152},
    "('MBZUAI/LaMini-GPT-124M', 'wikiword')": {"a": -0.1380796078939516, "b": 3.391898490213529},
    "('MBZUAI/LaMini-GPT-774M', 'flan')": {"a": -0.21223122991733276, "b": 4.783598287214052},
    "('MBZUAI/LaMini-GPT-774M', 'gigaword')": {"a": -0.23807369821093685, "b": 4.2045733043013165},
    "('MBZUAI/LaMini-GPT-774M', 'wikiword')": {"a": -0.09602793909920557, "b": 2.652870267456121},
    "('cerebras/Cerebras-GPT-1.3B', 'flan')": {"a": -0.12168554599239426, "b": 3.4606313079569837},
    "('cerebras/Cerebras-GPT-1.3B', 'gigaword')": {"a": -0.19060958285317242, "b": 3.820355959611436},
    "('cerebras/Cerebras-GPT-1.3B', 'wikiword')": {"a": -0.10204513742983291, "b": 2.950335143562661},
    "('cerebras/Cerebras-GPT-256M', 'flan')": {"a": -0.14467116003927433, "b": 4.210632462544093},
    "('cerebras/Cerebras-GPT-256M', 'gigaword')": {"a": -0.22564133332553715, "b": 4.218447739736505},
    "('cerebras/Cerebras-GPT-256M', 'wikiword')": {"a": -0.16694369363861158, "b": 3.9284497987057874},
    "('facebook/bart-base', 'flan')": {"a": -0.27885014029903604, "b": 5.722811837645894},
    "('facebook/bart-base', 'gigaword')": {"a": -0.3347248647552073, "b": 5.410755825604152},
    "('facebook/bart-base', 'wikiword')": {"a": -0.2413552868912743, "b": 4.4926337354168595},
    "('facebook/bart-large', 'flan')": {"a": -0.19814331941245988, "b": 4.453518961526505},
    "('facebook/bart-large', 'gigaword')": {"a": -0.3339209236977352, "b": 5.242535980974371},
    "('facebook/bart-large', 'wikiword')": {"a": -0.10813142202742225, "b": 2.7193492499816334},
    "('facebook/opt-1.3b', 'flan')": {"a": -0.10226900935941804, "b": 3.0035271247006574},
    "('facebook/opt-1.3b', 'gigaword')": {"a": -0.22503850880208404, "b": 4.067383747817735},
    "('facebook/opt-1.3b', 'wikiword')": {"a": -0.06738256744904991, "b": 2.227609751673505},
    "('facebook/opt-350m', 'flan')": {"a": -0.1567771098875299, "b": 4.096427281007177},
    "('facebook/opt-350m', 'gigaword')": {"a": -0.28416051180558977, "b": 4.789217216189481},
    "('facebook/opt-350m', 'wikiword')": {"a": -0.10931922565009441, "b": 2.969051299001184},
    "('facebook/opt-6.7b', 'flan')": {"a": -0.035974453233156484, "b": 2.2060635294933304},
    "('facebook/opt-6.7b', 'gigaword')": {"a": -0.027858065828614262, "b": 2.1626361790690503},
    "('facebook/opt-6.7b', 'wikiword')": {"a": -0.05153332226677372, "b": 1.9623193491235948},
    "('google/mt5-base', 'flan')": {"a": -0.16562129412487037, "b": 4.0988075329513345},
    "('google/mt5-base', 'gigaword')": {"a": -0.08234869176487448, "b": 3.3355539218016768},
    "('google/mt5-base', 'wikiword')": {"a": -0.19493726211114437, "b": 3.861648173041152},
    "('google/mt5-large', 'flan')": {"a": -0.11550187851501488, "b": 3.2282318950626876},
    "('google/mt5-large', 'gigaword')": {"a": -0.09533853604323887, "b": 3.4146676286886763},
    "('google/mt5-large', 'wikiword')": {"a": -0.15291379226040927, "b": 3.353099664653985},
    "('gpt2', 'flan')": {"a": -0.29617705608594097, "b": 6.179866386147315},
    "('gpt2', 'gigaword')": {"a": -0.28087485939518764, "b": 4.8171995484035675},
    "('gpt2', 'wikiword')": {"a": -0.14749971599228653, "b": 3.4933775691623454},
    "('t5-base', 'flan')": {"a": -0.11717228671842463, "b": 3.2855166649939935},
    "('t5-base', 'gigaword')": {"a": -0.0696292965522014, "b": 1.480251152476475},
    "('t5-base', 'wikiword')": {"a": -0.06785004652930211, "b": 2.132300508433401},
    "('t5-small', 'flan')": {"a": -0.13444290860067154, "b": 3.752075115263242},
    "('t5-small', 'gigaword')": {"a": -0.07229854246705678, "b": 1.6089334546678618},
    "('t5-small', 'wikiword')": {"a": -0.09053553061557014, "b": 2.594400260204647},
}


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Predict SFT loss for each data point using a per-group logarithmic law.

    The discovered scaling law is

        L(N) = a * log(N) + b

    where N is the fine-tuning dataset size (``sft_data_size``), L is the
    predicted ``sft_loss``, log is the natural logarithm, and (a, b) are the
    group-specific coefficients stored in FITTED_PARAMS. The functional form
    is shared by all groups; only the coefficients differ.

    Args:
        input_data: List of data points; each dict must contain the key
            "sft_data_size" with a positive value (math.log requires N > 0).
        group: Experimental group name, used to look up (a, b).

    Returns:
        A list parallel to ``input_data``; each element is a dict with the
        single key "sft_loss" holding the predicted loss.

    Raises:
        ValueError: If ``group`` has no fitted coefficients.
    """
    # Fail fast with the list of known groups if the lookup would miss.
    if group not in FITTED_PARAMS:
        raise ValueError(f"Unknown group: {group}. Available groups: {list(FITTED_PARAMS.keys())}")

    coeffs = FITTED_PARAMS[group]
    slope, intercept = coeffs["a"], coeffs["b"]

    # Evaluate L(N) = a * log(N) + b for every point in one pass.
    return [
        {"sft_loss": slope * math.log(point["sft_data_size"]) + intercept}
        for point in input_data
    ]
#2 Run 2 R² = 0.948364
#3 Run 3 R² = 0.948364
#4 Run 4 R² = 0.948364
#5 Run 5 R² = 0.787239