"""
Module implementing the discovered scaling law for supervised fine-tuning loss.
"""
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values.
group: The name of the experimental group for which to make predictions.
The functional form of the law is the same for all groups,
but the constant parameters/coefficients differ per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s) under key 'sft_loss'.
"""
    # Fitted parameters per group for the law: loss = L_inf + a * N^(-alpha)
    params = {
"('MBZUAI/LaMini-GPT-124M', 'flan')": {'L_inf': 4.88891522852415e-13, 'a': 6.973889417936588, 'alpha': 0.08134160664016567},
"('MBZUAI/LaMini-GPT-124M', 'gigaword')": {'L_inf': 1.2549475345759155e-15, 'a': 5.923742908062525, 'alpha': 0.10520859443461071},
"('MBZUAI/LaMini-GPT-124M', 'wikiword')": {'L_inf': 6.73534819503915e-22, 'a': 3.7815959443424148, 'alpha': 0.06407311292143066},
"('MBZUAI/LaMini-GPT-774M', 'flan')": {'L_inf': 3.0986536858492428e-15, 'a': 5.416194447728044, 'alpha': 0.07203582303793636},
"('MBZUAI/LaMini-GPT-774M', 'gigaword')": {'L_inf': 1.5568069794903192e-15, 'a': 5.350693782433369, 'alpha': 0.10952387398189671},
"('MBZUAI/LaMini-GPT-774M', 'wikiword')": {'L_inf': 1.5817591620080455e-18, 'a': 2.8815067646791404, 'alpha': 0.05390255281176658},
"('cerebras/Cerebras-GPT-1.3B', 'flan')": {'L_inf': 4.660261124868628e-19, 'a': 3.72668747103668, 'alpha': 0.05133622852901864},
"('cerebras/Cerebras-GPT-1.3B', 'gigaword')": {'L_inf': 1.4572929400202255e-21, 'a': 4.601802095259647, 'alpha': 0.08937942187812241},
"('cerebras/Cerebras-GPT-1.3B', 'wikiword')": {'L_inf': 5.166756221644035e-20, 'a': 3.1670522055347803, 'alpha': 0.05010482255812582},
"('cerebras/Cerebras-GPT-256M', 'flan')": {'L_inf': 5.681581108258014e-16, 'a': 4.497775771102446, 'alpha': 0.049197297920243487},
"('cerebras/Cerebras-GPT-256M', 'gigaword')": {'L_inf': 5.792914647146667e-14, 'a': 5.237428498166979, 'alpha': 0.10004126991906952},
"('cerebras/Cerebras-GPT-256M', 'wikiword')": {'L_inf': 1.3546553850129447e-24, 'a': 4.456019626802189, 'alpha': 0.06905741523821077},
"('facebook/bart-base', 'flan')": {'L_inf': 4.485165927406458e-21, 'a': 6.796594796854344, 'alpha': 0.08550170315119676},
"('facebook/bart-base', 'gigaword')": {'L_inf': 2.8011504261522693e-13, 'a': 7.594455935790706, 'alpha': 0.13330962892884293},
"('facebook/bart-base', 'wikiword')": {'L_inf': 0.2505157704866666, 'a': 5.861961168558081, 'alpha': 0.12018562279904178},
"('facebook/bart-large', 'flan')": {'L_inf': 8.90097148338901e-13, 'a': 5.121314574164879, 'alpha': 0.07407830522184079},
"('facebook/bart-large', 'gigaword')": {'L_inf': 7.348996681476381e-17, 'a': 7.683922260497716, 'alpha': 0.14310663430953338},
"('facebook/bart-large', 'wikiword')": {'L_inf': 0.7814633219457406, 'a': 2.620750838514362, 'alpha': 0.11520371495174195},
"('facebook/opt-1.3b', 'flan')": {'L_inf': 9.451184733271898e-25, 'a': 3.2171035320467642, 'alpha': 0.049050713966180944},
"('facebook/opt-1.3b', 'gigaword')": {'L_inf': 4.3610823050173863e-20, 'a': 5.229994584305667, 'alpha': 0.10805297945869073},
"('facebook/opt-1.3b', 'wikiword')": {'L_inf': 2.9031538655738537e-16, 'a': 2.3491291175103908, 'alpha': 0.041762877333812815},
"('facebook/opt-350m', 'flan')": {'L_inf': 7.937175363375485e-15, 'a': 4.468174390456382, 'alpha': 0.057611663183111395},
"('facebook/opt-350m', 'gigaword')": {'L_inf': 2.294599340497756e-16, 'a': 6.499202040847061, 'alpha': 0.12331848582705905},
"('facebook/opt-350m', 'wikiword')": {'L_inf': 9.338143776217542e-10, 'a': 3.2418388198175885, 'alpha': 0.055463997801027023},
"('facebook/opt-6.7b', 'flan')": {'L_inf': 8.14082711544948e-10, 'a': 2.2339677911390265, 'alpha': 0.019148365255527804},
"('facebook/opt-6.7b', 'gigaword')": {'L_inf': 6.4850563005254914e-21, 'a': 2.1770799489607238, 'alpha': 0.014498526959777614},
"('facebook/opt-6.7b', 'wikiword')": {'L_inf': 0.26936689227561367, 'a': 1.788070219787539, 'alpha': 0.04220960709323751},
"('google/mt5-base', 'flan')": {'L_inf': 2.526590900644073e-12, 'a': 4.571149010287784, 'alpha': 0.06361718985883816},
"('google/mt5-base', 'gigaword')": {'L_inf': 3.114373774345398e-18, 'a': 3.439603746623132, 'alpha': 0.03165062217213554},
"('google/mt5-base', 'wikiword')": {'L_inf': 4.10877531839249e-14, 'a': 4.734453353182449, 'alpha': 0.09261262630404997},
"('google/mt5-large', 'flan')": {'L_inf': 7.672745382307123e-19, 'a': 3.492798565293053, 'alpha': 0.05281181441669174},
"('google/mt5-large', 'gigaword')": {'L_inf': 5.873110659782422e-19, 'a': 3.549257797168276, 'alpha': 0.03688126478621837},
"('google/mt5-large', 'wikiword')": {'L_inf': 6.632774773255506e-16, 'a': 3.926297389250783, 'alpha': 0.07816268985536855},
"('gpt2', 'flan')": {'L_inf': 5.700579218298801e-13, 'a': 7.1983961879890614, 'alpha': 0.08172041645085604},
"('gpt2', 'gigaword')": {'L_inf': 5.449791195262194e-16, 'a': 6.339048122442959, 'alpha': 0.11725592134070295},
"('gpt2', 'wikiword')": {'L_inf': 3.929261588051236e-25, 'a': 3.946267218928482, 'alpha': 0.06811657162760475},
"('t5-base', 'flan')": {'L_inf': 4.841398974031212e-21, 'a': 3.543052166590739, 'alpha': 0.052255117939287866},
"('t5-base', 'gigaword')": {'L_inf': 0.41674098898230477, 'a': 1.8233793996455916, 'alpha': 0.16745997442401883},
"('t5-base', 'wikiword')": {'L_inf': 2.264923914746759e-22, 'a': 2.2598720880828873, 'alpha': 0.04457386963843646},
"('t5-small', 'flan')": {'L_inf': 5.731082823765475e-18, 'a': 4.0524929725844165, 'alpha': 0.05269102972550974},
"('t5-small', 'gigaword')": {'L_inf': 0.40091815004312686, 'a': 1.7757420094751855, 'alpha': 0.13439805928221435},
"('t5-small', 'wikiword')": {'L_inf': 1.5765442089801293e-20, 'a': 2.7905737195023983, 'alpha': 0.0507995309961754},
}
    if group not in params:
        raise ValueError(f"Unknown group: {group!r}")
    p = params[group]
    results = []
    for d in input_data:
        N = d.get('sft_data_size')
        if N is None or N <= 0:
            raise ValueError("Each data point must include a positive 'sft_data_size'.")
        # Apply the scaling law: loss = L_inf + a * N^(-alpha)
        loss = p['L_inf'] + p['a'] * (N ** (-p['alpha']))
        results.append({'sft_loss': loss})
    return results
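

# Example usage (a minimal sketch; the group key and data sizes below are
# illustrative, chosen only to show the expected call and return shapes):
if __name__ == "__main__":
    example_points = [{'sft_data_size': 1_000.0}, {'sft_data_size': 100_000.0}]
    predictions = law(example_points, "('gpt2', 'flan')")
    for point, pred in zip(example_points, predictions):
        print(f"N = {point['sft_data_size']:>9.0f}  ->  predicted sft_loss = {pred['sft_loss']:.4f}")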