from __future__ import annotations
from typing import List, Dict
# Discovered scaling law: sft_loss(N) = c + a * N^{-b}
# where N is the SFT data size (the 'sft_data_size' value read by law()).
# Same functional form across groups; (a,b,c) differ per group.
#
# Keys are the string form of a (model_name, dataset) tuple, e.g.
# "('gpt2', 'flan')"; law() looks the group up by this exact string.
# Per group: a scales the power-law term, b is its decay exponent (all b > 0
# here, so N^{-b} -> 0), and c is therefore the loss approached as N grows.
# NOTE(review): a few entries have a very small or negative asymptote
# (e.g. facebook/bart-large + gigaword has c < 0), which is not a meaningful
# loss floor; these coefficients were presumably fit on a limited range of N,
# so extrapolating far beyond it may undershoot — confirm against source data.
PARAMS: Dict[str, Dict[str, float]] = {
"('MBZUAI/LaMini-GPT-124M', 'flan')": {'a': 7.79709296, 'b': 0.14366504, 'c': 1.15127774},
"('MBZUAI/LaMini-GPT-124M', 'gigaword')": {'a': 7.18065251, 'b': 0.13966749, 'c': 0.20943581},
"('MBZUAI/LaMini-GPT-124M', 'wikiword')": {'a': 3.50492382, 'b': 0.09301953, 'c': 0.59486280},
"('MBZUAI/LaMini-GPT-774M', 'flan')": {'a': 5.55096616, 'b': 0.11852902, 'c': 0.88700526},
"('MBZUAI/LaMini-GPT-774M', 'gigaword')": {'a': 6.33245067, 'b': 0.13250068, 'c': 0.04341208},
"('MBZUAI/LaMini-GPT-774M', 'wikiword')": {'a': 2.57317908, 'b': 0.07191404, 'c': 0.42185299},
"('cerebras/Cerebras-GPT-1.3B', 'flan')": {'a': 3.13641643, 'b': 0.08539028, 'c': 0.88121614},
"('cerebras/Cerebras-GPT-1.3B', 'gigaword')": {'a': 4.84266305, 'b': 0.11468163, 'c': 0.31586501},
"('cerebras/Cerebras-GPT-1.3B', 'wikiword')": {'a': 2.72740881, 'b': 0.07469495, 'c': 0.61750148},
"('cerebras/Cerebras-GPT-256M', 'flan')": {'a': 3.77416454, 'b': 0.09141127, 'c': 1.21118553},
"('cerebras/Cerebras-GPT-256M', 'gigaword')": {'a': 5.88436542, 'b': 0.12720263, 'c': 0.22929230},
"('cerebras/Cerebras-GPT-256M', 'wikiword')": {'a': 4.15198671, 'b': 0.10663355, 'c': 0.78487709},
"('facebook/bart-base', 'flan')": {'a': 7.45393345, 'b': 0.14374514, 'c': 1.05100910},
"('facebook/bart-base', 'gigaword')": {'a': 9.47113095, 'b': 0.16403119, 'c': 0.06972433},
"('facebook/bart-base', 'wikiword')": {'a': 5.97887685, 'b': 0.14140928, 'c': 0.54784037},
"('facebook/bart-large', 'flan')": {'a': 4.99689388, 'b': 0.11656868, 'c': 0.85326329},
"('facebook/bart-large', 'gigaword')": {'a': 9.38268749, 'b': 0.16542420, 'c': -0.04288451},
"('facebook/bart-large', 'wikiword')": {'a': 2.62592585, 'b': 0.11454169, 'c': 0.77424909},
"('facebook/opt-1.3b', 'flan')": {'a': 2.72492496, 'b': 0.07484074, 'c': 0.67174921},
"('facebook/opt-1.3b', 'gigaword')": {'a': 5.77555355, 'b': 0.12986758, 'c': 0.16216071},
"('facebook/opt-1.3b', 'wikiword')": {'a': 2.02894972, 'b': 0.05450637, 'c': 0.36837840},
"('facebook/opt-350m', 'flan')": {'a': 4.04942870, 'b': 0.09803201, 'c': 0.96388081},
"('facebook/opt-350m', 'gigaword')": {'a': 7.65025332, 'b': 0.14974638, 'c': 0.11860471},
"('facebook/opt-350m', 'wikiword')": {'a': 2.83147284, 'b': 0.07932386, 'c': 0.57378977},
"('facebook/opt-6.7b', 'flan')": {'a': 1.51363810, 'b': 0.03174040, 'c': 0.74149915},
"('facebook/opt-6.7b', 'gigaword')": {'a': 1.40790357, 'b': 0.02514156, 'c': 0.78702090},
"('facebook/opt-6.7b', 'wikiword')": {'a': 1.75335535, 'b': 0.04356588, 'c': 0.30744922},
"('google/mt5-base', 'flan')": {'a': 4.12803145, 'b': 0.10607214, 'c': 0.96887260},
"('google/mt5-base', 'gigaword')": {'a': 2.35048143, 'b': 0.06291060, 'c': 1.24503412},
"('google/mt5-base', 'wikiword')": {'a': 4.86327770, 'b': 0.11878338, 'c': 0.37128462},
"('google/mt5-large', 'flan')": {'a': 2.98458153, 'b': 0.08255616, 'c': 0.74169984},
"('google/mt5-large', 'gigaword')": {'a': 2.64709116, 'b': 0.07015078, 'c': 1.12878106},
"('google/mt5-large', 'wikiword')": {'a': 3.77516678, 'b': 0.10151783, 'c': 0.41863249},
"('gpt2', 'flan')": {'a': 8.16685603, 'b': 0.14630670, 'c': 1.19283420},
"('gpt2', 'gigaword')": {'a': 7.62765686, 'b': 0.14681575, 'c': 0.13018064},
"('gpt2', 'wikiword')": {'a': 3.69790048, 'b': 0.09825580, 'c': 0.59712436},
"('t5-base', 'flan')": {'a': 3.08194773, 'b': 0.08089143, 'c': 0.71527394},
"('t5-base', 'gigaword')": {'a': 1.82066852, 'b': 0.16688442, 'c': 0.41517682},
"('t5-base', 'wikiword')": {'a': 2.06003085, 'b': 0.05427983, 'c': 0.24720659},
"('t5-small', 'flan')": {'a': 3.44637525, 'b': 0.09027300, 'c': 0.97762445},
"('t5-small', 'gigaword')": {'a': 1.77232820, 'b': 0.13527281, 'c': 0.40577385},
"('t5-small', 'wikiword')": {'a': 2.49570434, 'b': 0.06779296, 'c': 0.40605462},
}
# Fallback (a, b, c) coefficients used by law() for any group that is not a
# key of PARAMS.
# NOTE(review): presumably a pooled/averaged fit across groups — confirm.
DEFAULT: Dict[str, float] = dict(a=4.33379279, b=0.10382283, c=0.60135888)
def _predict_sft_loss(n: float, a: float, b: float, c: float) -> float:
if n <= 0:
# Guard against invalid inputs; fallback to asymptote
return float(c)
return float(c + a * (n ** (-b)))
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    The law is ``sft_loss = c + a * sft_data_size ** (-b)``, with the
    coefficients ``(a, b, c)`` taken from ``PARAMS[group]``.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values. Only ``'sft_data_size'`` is read; a
                    missing value is treated as 0, and non-positive sizes
                    yield the group's asymptote ``c``.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law must be the same for all groups,
               but the constant parameters/coefficients can differ per group.
               Group names are the string form of a ``(model, dataset)`` tuple,
               e.g. ``"('gpt2', 'flan')"``; passing the tuple (or a list)
               itself is also accepted. Unknown groups fall back to the
               ``DEFAULT`` coefficients.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s), i.e.
        ``{'sft_loss': <float>}``.
    """
    key = group
    if isinstance(key, (tuple, list)):
        # PARAMS is keyed by str(tuple), so normalize structured group ids;
        # otherwise a raw tuple would silently miss and get DEFAULT, and a
        # list would raise TypeError (unhashable) in the dict lookup.
        key = str(tuple(key))
    params = PARAMS.get(key, DEFAULT)
    a, b, c = params['a'], params['b'], params['c']
    return [
        {'sft_loss': _predict_sft_loss(float(row.get('sft_data_size', 0.0)), a, b, c)}
        for row in input_data
    ]