# Auto-generated scaling law for SFT loss vs data size
from __future__ import annotations
from math import pow
# Per-group parameters for L(N) = c + a * N^{-b}
# Keys are str((model_name, dataset_name)) — i.e. the repr of the 2-tuple —
# which is why law() also tries str(tuple(group)) when resolving a group.
# Fields per entry:
#   c:   loss floor — the asymptote of L(N) as N grows (b > 0 for all fits)
#   a:   scale coefficient of the power-law term
#   b:   decay exponent of the power-law term
#   mse: mean squared error of the fit (diagnostic only; not read by law())
#   n:   number of data points used in the fit (diagnostic only; not read by law())
_PARAMS = {
"('MBZUAI/LaMini-GPT-124M', 'flan')": {
"c": 1.4827993951275662,
"a": 8.681233620552247,
"b": 0.17615205560077954,
"mse": 0.07572095110708876,
"n": 12
},
"('MBZUAI/LaMini-GPT-124M', 'gigaword')": {
"c": 0.5778526375054923,
"a": 8.132271163266724,
"b": 0.17780779263551502,
"mse": 0.09155548707084517,
"n": 12
},
"('MBZUAI/LaMini-GPT-124M', 'wikiword')": {
"c": 1.2749505667194287,
"a": 4.096753821370224,
"b": 0.18127272693009286,
"mse": 0.007576015148309798,
"n": 12
},
"('MBZUAI/LaMini-GPT-774M', 'flan')": {
"c": 1.3969925089074628,
"a": 6.40984847550739,
"b": 0.17199807099464345,
"mse": 0.04227240896528709,
"n": 12
},
"('MBZUAI/LaMini-GPT-774M', 'gigaword')": {
"c": 0.4749899427965494,
"a": 7.316034511850357,
"b": 0.17951634849788112,
"mse": 0.06685341699938115,
"n": 12
},
"('MBZUAI/LaMini-GPT-774M', 'wikiword')": {
"c": 1.1984324837502354,
"a": 2.8225175602427934,
"b": 0.18381610159857534,
"mse": 0.002350416955268903,
"n": 12
},
"('cerebras/Cerebras-GPT-1.3B', 'flan')": {
"c": 1.5983096713993608,
"a": 3.625465436911771,
"b": 0.18232212054329097,
"mse": 0.006276539391354531,
"n": 12
},
"('cerebras/Cerebras-GPT-1.3B', 'gigaword')": {
"c": 0.873264704567136,
"a": 5.702649719858272,
"b": 0.17983169580523745,
"mse": 0.01956385652134642,
"n": 12
},
"('cerebras/Cerebras-GPT-1.3B', 'wikiword')": {
"c": 1.380007205548204,
"a": 3.0382321918197386,
"b": 0.18059759917861193,
"mse": 0.004495478358400602,
"n": 12
},
"('cerebras/Cerebras-GPT-256M', 'flan')": {
"c": 1.8824405987821344,
"a": 4.3105173431135455,
"b": 0.16798665859019904,
"mse": 0.016099566418512953,
"n": 12
},
"('cerebras/Cerebras-GPT-256M', 'gigaword')": {
"c": 0.7013636723381862,
"a": 6.83969648610096,
"b": 0.17912089400734008,
"mse": 0.040609922074825354,
"n": 12
},
"('cerebras/Cerebras-GPT-256M', 'wikiword')": {
"c": 1.3966611005609038,
"a": 4.908488983299831,
"b": 0.18345726653962038,
"mse": 0.00725928806266427,
"n": 12
},
"('facebook/bart-base', 'flan')": {
"c": 1.402782652932343,
"a": 8.376545770083204,
"b": 0.17972746603444834,
"mse": 0.044879946929858,
"n": 12
},
"('facebook/bart-base', 'gigaword')": {
"c": 0.2764984487592691,
"a": 10.213711557771816,
"b": 0.1847929973487988,
"mse": 0.09375044998301889,
"n": 12
},
"('facebook/bart-base', 'wikiword')": {
"c": 0.9710400825952036,
"a": 6.95185109881479,
"b": 0.19363752005641285,
"mse": 0.006138814059146752,
"n": 12
},
"('facebook/bart-large', 'flan')": {
"c": 1.3945609121941243,
"a": 5.826039279936208,
"b": 0.1781279206165522,
"mse": 0.014543204325175108,
"n": 12
},
"('facebook/bart-large', 'gigaword')": {
"c": 0.16518805426565897,
"a": 10.136677046755027,
"b": 0.18675428613221712,
"mse": 0.07801476549607378,
"n": 12
},
"('facebook/bart-large', 'wikiword')": {
"c": 1.1065786999405036,
"a": 3.072980337237203,
"b": 0.18460659464051513,
"mse": 0.0005493395674912767,
"n": 12
},
"('facebook/opt-1.3b', 'flan')": {
"c": 1.4345452517670585,
"a": 3.028618372301197,
"b": 0.1808380053964771,
"mse": 0.004029167599996951,
"n": 12
},
"('facebook/opt-1.3b', 'gigaword')": {
"c": 0.634715119130116,
"a": 6.7608938134760015,
"b": 0.1844253259847985,
"mse": 0.028949249451991333,
"n": 12
},
"('facebook/opt-1.3b', 'wikiword')": {
"c": 1.2108912294465461,
"a": 1.981892390506441,
"b": 0.18507486951364943,
"mse": 0.0010254072451437843,
"n": 12
},
"('facebook/opt-350m', 'flan')": {
"c": 1.6032673720312067,
"a": 4.698389125373991,
"b": 0.17197678324723842,
"mse": 0.018425011918658536,
"n": 12
},
"('facebook/opt-350m', 'gigaword')": {
"c": 0.4487475721624149,
"a": 8.612529868566897,
"b": 0.18513602069713941,
"mse": 0.05488526088852023,
"n": 12
},
"('facebook/opt-350m', 'wikiword')": {
"c": 1.318657032305419,
"a": 3.1918344343120673,
"b": 0.1840127387732732,
"mse": 0.0022594662870228117,
"n": 12
},
"('facebook/opt-6.7b', 'flan')": {
"c": 1.6583758241931825,
"a": 1.0527934808943984,
"b": 0.18172580896242374,
"mse": 0.0003657389558883888,
"n": 12
},
"('facebook/opt-6.7b', 'gigaword')": {
"c": 1.7214619335229422,
"a": 0.8501645122933167,
"b": 0.17553433626221163,
"mse": 0.0009104202480365186,
"n": 12
},
"('facebook/opt-6.7b', 'wikiword')": {
"c": 1.187363539503907,
"a": 1.4987769857971536,
"b": 0.18477038158174838,
"mse": 0.00047565093841222937,
"n": 12
},
"('google/mt5-base', 'flan')": {
"c": 1.583323676667614,
"a": 4.8827336950093425,
"b": 0.18333559931283497,
"mse": 0.008313217244721836,
"n": 12
},
"('google/mt5-base', 'gigaword')": {
"c": 2.0551210337798635,
"a": 2.4427896526776167,
"b": 0.17715640026992513,
"mse": 0.0038503896974422433,
"n": 12
},
"('google/mt5-base', 'wikiword')": {
"c": 0.9166373518594958,
"a": 5.753481901950828,
"b": 0.1851212657710932,
"mse": 0.010790900408354289,
"n": 12
},
"('google/mt5-large', 'flan')": {
"c": 1.4726763734534132,
"a": 3.4206192214207842,
"b": 0.18366684343063902,
"mse": 0.0045625286793028385,
"n": 12
},
"('google/mt5-large', 'gigaword')": {
"c": 1.905904106139712,
"a": 2.9000173722337452,
"b": 0.17523941384449754,
"mse": 0.008772571224160345,
"n": 12
},
"('google/mt5-large', 'wikiword')": {
"c": 1.0603143890912172,
"a": 4.462514518829667,
"b": 0.1860974601465054,
"mse": 0.003732898104365934,
"n": 12
},
"('gpt2', 'flan')": {
"c": 1.499554116407968,
"a": 9.023130844642328,
"b": 0.17585201323649569,
"mse": 0.08812713809218636,
"n": 12
},
"('gpt2', 'gigaword')": {
"c": 0.46457851812916895,
"a": 8.5707377758404,
"b": 0.1817309876468499,
"mse": 0.0740879352217463,
"n": 12
},
"('gpt2', 'wikiword')": {
"c": 1.2540722578538999,
"a": 4.36437086595112,
"b": 0.18391632922421508,
"mse": 0.007258446809941537,
"n": 12
},
"('t5-base', 'flan')": {
"c": 1.4471135954239664,
"a": 3.4628046853454775,
"b": 0.17403832928764254,
"mse": 0.006941254099221423,
"n": 12
},
"('t5-base', 'gigaword')": {
"c": 0.45810999439596933,
"a": 1.955554864376512,
"b": 0.18850642412886137,
"mse": 0.0001945310979496464,
"n": 12
},
"('t5-base', 'wikiword')": {
"c": 1.0907259526778699,
"a": 2.000945824397046,
"b": 0.18014107601933568,
"mse": 0.001727331273734328,
"n": 12
},
"('t5-small', 'flan')": {
"c": 1.6669969283388792,
"a": 3.9878042798753763,
"b": 0.1778290229582325,
"mse": 0.008247131168222439,
"n": 12
},
"('t5-small', 'gigaword')": {
"c": 0.5448930641881546,
"a": 2.0501491852544573,
"b": 0.18876257345373862,
"mse": 0.00029128786119094033,
"n": 12
},
"('t5-small', 'wikiword')": {
"c": 1.1972462755987916,
"a": 2.661511171116643,
"b": 0.17821433575400167,
"mse": 0.003410986481257676,
"n": 12
}
}
def _predict_loss(N: float, params: dict[str, float]) -> float:
# ensure positive N
if N <= 0 or not (N == N):
return float('nan')
c = float(params.get('c', 0.0))
a = float(params.get('a', 0.0))
b = float(params.get('b', 0.5))
return c + a * (N ** (-b))
# Recognized input keys for the SFT data size, in lookup-priority order.
_SIZE_KEYS = ("sft_data_size", "data_size", "N")


def _resolve_params(group) -> dict[str, float]:
    """Resolve fitted coefficients for *group* (string, tuple, or list key).

    _PARAMS keys are the repr of (model, dataset) tuples, so str(group) and
    str(tuple(group)) are tried after a direct lookup. When the group is
    unknown, falls back to the mean coefficients across all fitted groups,
    and to neutral defaults if the table is empty.
    """
    params = _PARAMS.get(group)
    if params is not None:
        return params
    candidates = [str(group)]
    try:
        if not isinstance(group, (str, bytes)):
            candidates.append(str(tuple(group)))
    except Exception:
        # group is not iterable; the plain str(group) form is still tried.
        pass
    for key in candidates:
        params = _PARAMS.get(key)
        if params is not None:
            return params
    if _PARAMS:
        import statistics
        return {
            'c': statistics.fmean(v['c'] for v in _PARAMS.values()),
            'a': statistics.fmean(v['a'] for v in _PARAMS.values()),
            'b': statistics.fmean(v['b'] for v in _PARAMS.values()),
        }
    return {'c': 0.0, 'a': 0.0, 'b': 0.5}


def _extract_size(row: dict) -> float:
    """Pull the data-size value N out of one input row.

    Tries the known key names first, then falls back to the first numeric
    value in the row. Returns NaN when no numeric value is present.
    """
    for key in _SIZE_KEYS:
        if key in row:
            return float(row[key])
    for value in row.values():
        # bool is a subclass of int; a True/False flag is not a data size.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return float(value)
    return float('nan')


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.
    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law must be the same for all groups,
               but the constant parameters/coefficients can differ per group.
    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    params = _resolve_params(group)
    return [
        {'sft_loss': float(_predict_loss(_extract_size(row), params))}
        for row in input_data
    ]