from __future__ import annotations
from typing import List, Dict
# Parameters fitted per experimental group for the scaling law
# L(N) = L_inf + A * (N + N0)^(-alpha)
# Keys are str(tuple) of (model_name, dataset_name), e.g. "('t5-base', 'flan')".
# Values are the fitted constants (L_inf, A, alpha, N0) for the law
# L(N) = L_inf + A * (N + N0)^(-alpha), where N is the SFT data size.
GROUP_PARAMS: Dict[str, tuple[float, float, float, float]] = {
"('MBZUAI/LaMini-GPT-124M', 'flan')": (1.74334372845e-14, 12.6376966889, 0.135642527036, 3172.84813899),
"('MBZUAI/LaMini-GPT-124M', 'gigaword')": (0.693703651144, 138.472285153, 0.431968782432, 12511.858243),
"('MBZUAI/LaMini-GPT-124M', 'wikiword')": (1.01586399934e-16, 4.2334921402, 0.0746041722245, 436.68969118),
"('MBZUAI/LaMini-GPT-774M', 'flan')": (1.19409314659e-18, 8.92224180358, 0.117395961813, 3069.40919683),
"('MBZUAI/LaMini-GPT-774M', 'gigaword')": (0.491595039277, 53.7238611705, 0.353849253703, 8208.08072935),
"('MBZUAI/LaMini-GPT-774M', 'wikiword')": (5.27990777079e-16, 2.98964855216, 0.0573530926488, 140.710189355),
"('cerebras/Cerebras-GPT-1.3B', 'flan')": (3.76889628326e-16, 4.06287843965, 0.0593450071786, 426.034125852),
"('cerebras/Cerebras-GPT-1.3B', 'gigaword')": (6.18853897564e-19, 6.33621255116, 0.119201667319, 1084.13661129),
"('cerebras/Cerebras-GPT-1.3B', 'wikiword')": (8.67298183717e-15, 3.41013139136, 0.0569598662977, 363.706821923),
"('cerebras/Cerebras-GPT-256M', 'flan')": (5.61099380572e-23, 4.09129228086, 0.0655956807586, 381.324214495),
"('cerebras/Cerebras-GPT-256M', 'gigaword')": (2.33071354039e-21, 6.20496986049, 0.0898142124765, 421.971295556),
"('cerebras/Cerebras-GPT-256M', 'wikiword')": (7.17491357681e-12, 2.97543402284, 0.10948596302, 1540.59235749),
"('facebook/bart-base', 'flan')": (4.59171206171e-23, 3.92826799844, 0.0572921215978, 513.889289197),
"('facebook/bart-base', 'gigaword')": (0.331145263898, 1.80249328365, 0.161681097935, 1.38996740228e-08),
"('facebook/bart-base', 'wikiword')": (7.54240118853e-08, 2.40134088564, 0.046208695099, 368.565361701),
"('facebook/bart-large', 'flan')": (5.08695784056e-19, 3.85220096633, 0.053905137389, 522.232192014),
"('facebook/bart-large', 'gigaword')": (0.333615694269, 2.10767061513, 0.152700192794, 2.67294535273e-08),
"('facebook/bart-large', 'wikiword')": (0.0102249525968, 2.66899639933, 0.0733424237895, 125.118536989),
"('google/flan-t5-base', 'flan')": (4.91247757583e-23, 3.60117304024, 0.0572465810795, 417.875313569),
"('google/flan-t5-base', 'gigaword')": (0.405005274224, 1.61373265073, 0.164013157084, 3.1548076111e-09),
"('google/flan-t5-base', 'wikiword')": (9.15667282415e-06, 2.32680485136, 0.0503996647358, 283.080905921),
"('google/flan-t5-small', 'flan')": (1.00714948887e-18, 3.83772842519, 0.0591848781206, 455.282652859),
"('google/flan-t5-small', 'gigaword')": (0.535288121688, 2.08203051743, 0.178595557119, 22.4200612776),
"('google/flan-t5-small', 'wikiword')": (1.03689287158e-14, 2.85083455163, 0.056333131327, 301.738465789),
"('google/gemma-1.1-2b-it', 'flan')": (6.8915143895e-21, 3.00848271538, 0.0480777398665, 390.795415165),
"('google/gemma-1.1-2b-it', 'gigaword')": (0.625782070821, 1.82234052821, 0.175295164625, 1.42379766592e-08),
"('google/gemma-1.1-2b-it', 'wikiword')": (2.05488645576e-08, 2.04820492761, 0.0343357892473, 304.231733022),
"('HuggingFaceH4/zephyr-7b-alpha', 'flan')": (4.9049381785e-22, 2.43999566801, 0.0515771748292, 215.32550852),
"('HuggingFaceH4/zephyr-7b-alpha', 'gigaword')": (0.663107576466, 2.05155751097, 0.187721428568, 26.8042322803),
"('HuggingFaceH4/zephyr-7b-alpha', 'wikiword')": (4.12506384686e-15, 1.75811857156, 0.04720378107, 278.035272188),
"('HuggingFaceH4/zephyr-7b-beta', 'flan')": (3.09249230834e-20, 1.99774222836, 0.0462512339271, 212.068394081),
"('HuggingFaceH4/zephyr-7b-beta', 'gigaword')": (0.645108253463, 1.31823579619, 0.185613813767, 1.64484516925e-07),
"('HuggingFaceH4/zephyr-7b-beta', 'wikiword')": (3.41599139834e-18, 1.50808886401, 0.0431953430781, 270.828819886),
"('MBZUAI/LaMini-GPT-124M', 'flan+synthetic')": (5.48508537431e-09, 3.44038957741, 0.10143885973, 522.878784361),
"('MBZUAI/LaMini-GPT-124M', 'gigaword+synthetic')": (0.6503504776, 43.3219835398, 0.388166622966, 6800.69906577),
"('MBZUAI/LaMini-GPT-124M', 'wikiword+synthetic')": (7.2807427126e-09, 2.37643665333, 0.0904549693201, 583.1756373),
"('MBZUAI/LaMini-GPT-774M', 'flan+synthetic')": (3.14075461673e-10, 3.34192480141, 0.0956885199672, 644.350510751),
"('MBZUAI/LaMini-GPT-774M', 'gigaword+synthetic')": (0.487120603132, 25.5290146884, 0.295662400974, 4101.34778253),
"('MBZUAI/LaMini-GPT-774M', 'wikiword+synthetic')": (1.61549400971e-08, 2.19369617312, 0.0864794164689, 589.042812831),
"('meta-llama/Llama-2-7b-chat-hf', 'flan')": (3.82222214223e-10, 2.17654619575, 0.0736122258328, 316.310451114),
"('meta-llama/Llama-2-7b-chat-hf', 'gigaword')": (0.65265599579, 1.61221334699, 0.203963866387, 7.77256658209e-08),
"('meta-llama/Llama-2-7b-chat-hf', 'wikiword')": (1.10820500056e-16, 1.42413350177, 0.0494370037209, 181.81689969),
"('openchat/openchat_3.5', 'flan')": (7.48101007959e-13, 2.14041757993, 0.0610714316899, 291.439271718),
"('openchat/openchat_3.5', 'gigaword')": (0.442685808441, 1.59285815147, 0.190752780851, 62.6334333168),
"('openchat/openchat_3.5', 'wikiword')": (1.608493018e-14, 1.44339217538, 0.0506461652539, 231.637812601),
"('Qwen/Qwen1.5-1.8B-Chat', 'flan')": (6.27129779661e-22, 2.42847553091, 0.057062250832, 313.402069889),
"('Qwen/Qwen1.5-1.8B-Chat', 'gigaword')": (0.672573369247, 1.68982128645, 0.195871473078, 1.58312141503e-07),
"('Qwen/Qwen1.5-1.8B-Chat', 'wikiword')": (1.15387868584e-16, 1.78180769947, 0.0558031745887, 246.147766024),
"('Qwen/Qwen1.5-7B-Chat', 'flan')": (2.41403848954e-21, 1.7337576363, 0.0522215109255, 287.46828319),
"('Qwen/Qwen1.5-7B-Chat', 'gigaword')": (0.701829784636, 1.54187766402, 0.206394932847, 3.91240722118e-08),
"('Qwen/Qwen1.5-7B-Chat', 'wikiword')": (5.58834583656e-19, 1.35838705992, 0.0471114172229, 255.717687844),
"('tiiuae/falcon-1b', 'flan')": (2.93496793134e-17, 4.38619638046, 0.0803432225412, 635.63611150),
"('tiiuae/falcon-1b', 'gigaword')": (0.642779525838, 7.68985835857, 0.246590363741, 1906.41074333),
"('tiiuae/falcon-1b', 'wikiword')": (3.18649499919e-12, 3.10653686584, 0.103057831142, 1573.84972758),
"('tiiuae/falcon-7b-instruct', 'flan')": (2.60111608485e-16, 2.22355106974, 0.0722324800102, 397.155408139),
"('tiiuae/falcon-7b-instruct', 'gigaword')": (0.669375089848, 1.61392768813, 0.203131910042, 2.41640431048e-08),
"('tiiuae/falcon-7b-instruct', 'wikiword')": (1.6813344452e-18, 1.29755740218, 0.0449371852927, 230.006648709),
"('t5-base', 'flan')": (1.20686838657e-20, 3.57216973557, 0.0528291592115, 454.455171846),
"('t5-base', 'gigaword')": (0.416740988169, 1.82337939819, 0.167459974094, 5.3249802882e-09),
"('t5-base', 'wikiword')": (8.88575837592e-06, 2.39174846722, 0.0498315754901, 304.000915976),
"('t5-small', 'flan')": (1.04548316353e-21, 4.42886795655, 0.0609221599197, 428.387583993),
"('t5-small', 'gigaword')": (0.558546921255, 2.42478473544, 0.209095628426, 173.820139152),
"('t5-small', 'wikiword')": (1.89469511102e-15, 3.00546805873, 0.0576975262125, 352.660073561),
}
# Robust fallback parameters (median across groups), used if an unknown group is requested.
# Same layout as GROUP_PARAMS values: (L_inf, A, alpha, N0).
FALLBACK_PARAMS: tuple[float, float, float, float] = (4.73642516861e-15, 4.40890174984, 0.0821440711414, 432.538637586)
def _predict_loss(N: float, params: tuple[float, float, float, float]) -> float:
L0, A, alpha, N0 = params
x = float(N) + float(N0)
if x < 1e-9:
x = 1e-9
return float(L0) + float(A) * (x ** (-float(alpha)))
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    The functional form L(N) = L_inf + A * (N + N0)^(-alpha) is shared by all
    groups; only the fitted constants differ per group. Unknown groups use the
    robust FALLBACK_PARAMS. A data point missing "sft_data_size" is treated as
    having size 0.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single
            data point containing input variable names as keys and their
            corresponding values.
        group: The name of the experimental group for which to make predictions.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s) ("sft_loss").
    """
    fitted = GROUP_PARAMS.get(group, FALLBACK_PARAMS)
    predictions: List[Dict[str, float]] = []
    for point in input_data:
        data_size = point.get("sft_data_size", 0.0)
        predictions.append({"sft_loss": _predict_loss(data_size, fitted)})
    return predictions