from __future__ import annotations
from typing import Dict, List
# Common functional form across all groups (a saturating power law):
#   sft_loss(N) = L_inf + A * (N + N0) ** (-alpha)
# where N is `sft_data_size` and the fitted constants (L_inf, A, alpha, N0)
# vary by `group`.
# Keys are stringified (model_name, dataset_name) tuples, e.g.
# "('gpt2', 'flan')" — look up with the exact string, not a tuple.
# NOTE(review): A values pinned at ~100 (e.g. LaMini-GPT-124M/gigaword) look
# like they hit an optimizer bound during fitting — presumably an upper
# constraint of 100 on A; verify against the fitting script.
PARAMS: Dict[str, Dict[str, float]] = {
"('MBZUAI/LaMini-GPT-124M', 'flan')": {'L_inf': 1.865671228941949e-17, 'A': 12.637678561446139, 'alpha': 0.13564240302792172, 'N0': 3172.8349547867774},
"('MBZUAI/LaMini-GPT-124M', 'gigaword')": {'L_inf': 0.6238498800785734, 'A': 99.99999999999999, 'alpha': 0.3976078224298724, 'N0': 11558.491067541954},
"('MBZUAI/LaMini-GPT-124M', 'wikiword')": {'L_inf': 5.063925056378967e-15, 'A': 4.23348975814935, 'alpha': 0.07460412111294643, 'N0': 436.6866783168706},
"('MBZUAI/LaMini-GPT-774M', 'flan')": {'L_inf': 9.10578307107794e-19, 'A': 8.92224005340684, 'alpha': 0.11739594489898576, 'N0': 3069.407127001592},
"('MBZUAI/LaMini-GPT-774M', 'gigaword')": {'L_inf': 0.49159473249601304, 'A': 53.72380342046997, 'alpha': 0.35384913142132957, 'N0': 8208.078352425737},
"('MBZUAI/LaMini-GPT-774M', 'wikiword')": {'L_inf': 1.2593622265895715e-07, 'A': 2.9896484714861424, 'alpha': 0.05735309742066091, 'N0': 140.71022280680234},
"('cerebras/Cerebras-GPT-1.3B', 'flan')": {'L_inf': 1.1318279234338576e-18, 'A': 4.062878459420533, 'alpha': 0.05934500761506854, 'N0': 426.03416018746805},
"('cerebras/Cerebras-GPT-1.3B', 'gigaword')": {'L_inf': 2.3927299190951424e-21, 'A': 6.336196488869017, 'alpha': 0.1192014399950389, 'N0': 1084.1232942926576},
"('cerebras/Cerebras-GPT-1.3B', 'wikiword')": {'L_inf': 9.220157697788088e-20, 'A': 3.4101310251993695, 'alpha': 0.056959891171103054, 'N0': 363.70595860114133},
"('cerebras/Cerebras-GPT-256M', 'flan')": {'L_inf': 2.315142285734426e-22, 'A': 5.319365673488941, 'alpha': 0.06450038022371594, 'N0': 1162.8602806462848},
"('cerebras/Cerebras-GPT-256M', 'gigaword')": {'L_inf': 1.01404483163551e-14, 'A': 10.7925493382822, 'alpha': 0.16678612735396567, 'N0': 2909.744466994437},
"('cerebras/Cerebras-GPT-256M', 'wikiword')": {'L_inf': 1.6675446178776843e-19, 'A': 4.756329316561344, 'alpha': 0.07520633987592405, 'N0': 197.06918542242636},
"('facebook/bart-base', 'flan')": {'L_inf': 3.4024201572921337e-19, 'A': 9.466998707085137, 'alpha': 0.1163303938056367, 'N0': 1218.0973728427302},
"('facebook/bart-base', 'gigaword')": {'L_inf': 0.569762177448528, 'A': 99.9999998931003, 'alpha': 0.4093787678794461, 'N0': 6244.642597957972},
"('facebook/bart-base', 'wikiword')": {'L_inf': 1.2241220338350347, 'A': 14.5986109269959, 'alpha': 0.29682572736757595, 'N0': 550.5089063545902},
"('facebook/bart-large', 'flan')": {'L_inf': 1.1510770614726884e-17, 'A': 5.611410375270038, 'alpha': 0.0826931330009914, 'N0': 269.42264002978004},
"('facebook/bart-large', 'gigaword')": {'L_inf': 0.4362137935788755, 'A': 61.030730385354516, 'alpha': 0.36845048401898794, 'N0': 4178.040636016911},
"('facebook/bart-large', 'wikiword')": {'L_inf': 0.7814640933847434, 'A': 2.6207508142290563, 'alpha': 0.11520380951312024, 'N0': 7.81642673397928e-11},
"('facebook/opt-1.3b', 'flan')": {'L_inf': 3.0801677443604495e-12, 'A': 3.437167448516286, 'alpha': 0.05519316053319131, 'N0': 323.52209754302413},
"('facebook/opt-1.3b', 'gigaword')": {'L_inf': 0.303325908352575, 'A': 10.781976113943252, 'alpha': 0.19556325386045303, 'N0': 1844.5375817073113},
"('facebook/opt-1.3b', 'wikiword')": {'L_inf': 6.849950839473154e-13, 'A': 2.371033103373906, 'alpha': 0.0426291470638209, 'N0': 42.52989450771065},
"('facebook/opt-350m', 'flan')": {'L_inf': 1.0439604500936763e-21, 'A': 5.627459766341364, 'alpha': 0.07868013844546193, 'N0': 1427.3103646514269},
"('facebook/opt-350m', 'gigaword')": {'L_inf': 0.3170636936481373, 'A': 21.45165635532035, 'alpha': 0.2548387620240815, 'N0': 2967.5199471791925},
"('facebook/opt-350m', 'wikiword')": {'L_inf': 7.435230193328557e-18, 'A': 3.257815231163734, 'alpha': 0.055927008001466834, 'N0': 15.872263908578827},
"('facebook/opt-6.7b', 'flan')": {'L_inf': 1.7464718748726132e-09, 'A': 2.239883130424448, 'alpha': 0.019392187633072766, 'N0': 27.45074276948627},
"('facebook/opt-6.7b', 'gigaword')": {'L_inf': 1.6339485409447412, 'A': 1.852648048421951, 'alpha': 0.19215155224167946, 'N0': 5578.387896916012},
"('facebook/opt-6.7b', 'wikiword')": {'L_inf': 0.8797137795917916, 'A': 1.3801606877059753, 'alpha': 0.09031133705708351, 'N0': 150.715299869533},
"('google/mt5-base', 'flan')": {'L_inf': 1.0369199206942231e-16, 'A': 4.936124715180229, 'alpha': 0.07082556462170639, 'N0': 268.26577139087215},
"('google/mt5-base', 'gigaword')": {'L_inf': 2.993452899495181e-20, 'A': 3.657207489347324, 'alpha': 0.037261242285363386, 'N0': 549.53675753708},
"('google/mt5-base', 'wikiword')": {'L_inf': 8.178379339595827e-17, 'A': 5.558671475357777, 'alpha': 0.10787587390263807, 'N0': 388.26365277913226},
"('google/mt5-large', 'flan')": {'L_inf': 2.0351395227477922e-12, 'A': 3.7361033461210726, 'alpha': 0.059085413639518115, 'N0': 296.79195384894666},
"('google/mt5-large', 'gigaword')": {'L_inf': 1.7285593853941235e-20, 'A': 4.301700747580993, 'alpha': 0.054175354177230746, 'N0': 2255.1405984927396},
"('google/mt5-large', 'wikiword')": {'L_inf': 3.385679168082505e-20, 'A': 4.070364998240034, 'alpha': 0.08159499527141832, 'N0': 84.73574253453266},
"('gpt2', 'flan')": {'L_inf': 4.015340729288829e-15, 'A': 14.34124705639774, 'alpha': 0.14433103359274688, 'N0': 3987.9597869657864},
"('gpt2', 'gigaword')": {'L_inf': 0.47259432236704396, 'A': 41.02611997599669, 'alpha': 0.3190891985857585, 'N0': 5570.9100030079235},
"('gpt2', 'wikiword')": {'L_inf': 1.8048515222229452e-22, 'A': 4.388935478788763, 'alpha': 0.0780881235127983, 'N0': 365.9997120466946},
"('t5-base', 'flan')": {'L_inf': 8.028632468328126e-18, 'A': 3.8842431380507225, 'alpha': 0.06076610065498959, 'N0': 454.6947810694002},
"('t5-base', 'gigaword')": {'L_inf': 0.416740982867647, 'A': 1.8233793878342568, 'alpha': 0.16745997189052603, 'N0': 7.833754721840204e-09},
"('t5-base', 'wikiword')": {'L_inf': 1.7554230404581307e-14, 'A': 2.3917550734396555, 'alpha': 0.049831240630992694, 'N0': 303.9980112153414},
"('t5-small', 'flan')": {'L_inf': 2.006136285290001e-21, 'A': 4.428866599546551, 'alpha': 0.06092213223812954, 'N0': 428.3855452155827},
"('t5-small', 'gigaword')": {'L_inf': 0.5585503700141525, 'A': 2.424820616633481, 'alpha': 0.20909831661111153, 'N0': 173.82806592094695},
"('t5-small', 'wikiword')": {'L_inf': 1.0419309334128787e-12, 'A': 3.0054681324705808, 'alpha': 0.057697528448694614, 'N0': 352.6602281915071},
}
# Fallback parameters used by `law` when an unknown group string is requested;
# roughly median-typical behavior across the fitted groups.
FALLBACK = {'L_inf': 0.5, 'A': 5.0, 'alpha': 0.10, 'N0': 300.0}
def _predict_loss(n: float, p: Dict[str, float]) -> float:
# Guard against negative or pathological inputs
n = max(0.0, float(n))
L_inf = float(p['L_inf'])
A = float(p['A'])
alpha = float(p['alpha'])
N0 = float(p['N0'])
return L_inf + A * (n + N0) ** (-alpha)
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.
    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law must be the same for all groups,
               but the constant parameters/coefficients can differ per group.
    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    Raises:
        KeyError: If any input row is missing 'sft_data_size'.
    """
    # Unknown groups fall back to generic "typical" coefficients.
    coeffs = PARAMS.get(group, FALLBACK)
    # Validate every row up front so no partial results are built on bad input.
    sizes = []
    for record in input_data:
        size = record.get('sft_data_size')
        if size is None:
            raise KeyError("Each input row must include 'sft_data_size'.")
        sizes.append(size)
    return [{'sft_loss': float(_predict_loss(size, coeffs))} for size in sizes]