from __future__ import annotations
import math
from typing import List, Dict
# Learned parameters for each group for the scaling law:
# sft_loss = c + a * (sft_data_size + x0) ** (-b)
# Per-group fitted coefficients for the scaling law
#     sft_loss = c + a * (sft_data_size + x0) ** (-b)
# NOTE(review): keys are the str() of a (model_name, dataset_name) tuple,
# e.g. "('gpt2', 'flan')" — callers must pass exactly this string form.
# 'c' is the asymptotic (irreducible) loss, 'a' the scale, 'b' the decay
# exponent, and 'x0' a horizontal shift.  The x0 values appear to be drawn
# from a log-spaced grid (10 ** {0.8, 2.4, 3.2, 4} ≈ 6.31, 251.19,
# 1584.89, 10000) — presumably from the fitting procedure; TODO confirm.
_PARAMS = {
"('MBZUAI/LaMini-GPT-124M', 'flan')": {'c': 1.564269587, 'a': 87.98498619, 'b': 0.3824763366, 'x0': 10000},
"('MBZUAI/LaMini-GPT-124M', 'gigaword')": {'c': 0.4837739236, 'a': 59.71925558, 'b': 0.3416696269, 'x0': 10000},
"('MBZUAI/LaMini-GPT-124M', 'wikiword')": {'c': 1.048831887, 'a': 5.997674903, 'b': 0.1826307524, 'x0': 1584.893192},
"('MBZUAI/LaMini-GPT-774M', 'flan')": {'c': 1.307240186, 'a': 41.76904713, 'b': 0.3242540963, 'x0': 10000},
"('MBZUAI/LaMini-GPT-774M', 'gigaword')": {'c': 0.6260466467, 'a': 104.8492498, 'b': 0.4259772459, 'x0': 10000},
"('MBZUAI/LaMini-GPT-774M', 'wikiword')": {'c': 0.846386623, 'a': 2.652869033, 'b': 0.116385829, 'x0': 251.1886431},
"('cerebras/Cerebras-GPT-1.3B', 'flan')": {'c': 1.371608086, 'a': 5.177573316, 'b': 0.1776235759, 'x0': 1584.893192},
"('cerebras/Cerebras-GPT-1.3B', 'gigaword')": {'c': 0.1373163055, 'a': 7.12887961, 'b': 0.137700517, 'x0': 1584.893192},
"('cerebras/Cerebras-GPT-1.3B', 'wikiword')": {'c': 1.084408809, 'a': 3.972975118, 'b': 0.1537142579, 'x0': 1584.893192},
"('cerebras/Cerebras-GPT-256M', 'flan')": {'c': 1.265501879, 'a': 5.42406848, 'b': 0.127521399, 'x0': 1584.893192},
# Some groups were fitted with c == 0 (no irreducible-loss floor).
"('cerebras/Cerebras-GPT-256M', 'gigaword')": {'c': 0, 'a': 8.979116797, 'b': 0.1512820373, 'x0': 1584.893192},
"('cerebras/Cerebras-GPT-256M', 'wikiword')": {'c': 0.7785420197, 'a': 4.614099641, 'b': 0.1157775285, 'x0': 251.1886431},
"('facebook/bart-base', 'flan')": {'c': 0.3645941746, 'a': 10.54850589, 'b': 0.1400599591, 'x0': 1584.893192},
"('facebook/bart-base', 'gigaword')": {'c': 0.8022925003, 'a': 558.1446562, 'b': 0.5859135735, 'x0': 10000},
"('facebook/bart-base', 'wikiword')": {'c': 0.9874219716, 'a': 8.509733939, 'b': 0.2138652817, 'x0': 251.1886431},
"('facebook/bart-large', 'flan')": {'c': 0.6377771789, 'a': 5.525963949, 'b': 0.1123388906, 'x0': 251.1886431},
"('facebook/bart-large', 'gigaword')": {'c': 0.8135473126, 'a': 1721.147572, 'b': 0.7131843465, 'x0': 10000},
# x0 == 0 here: the law is singular at sft_data_size == 0 for this group.
"('facebook/bart-large', 'wikiword')": {'c': 0.7826410482, 'a': 2.626828832, 'b': 0.1156186513, 'x0': 0},
"('facebook/opt-1.3b', 'flan')": {'c': 1.268665971, 'a': 4.437107393, 'b': 0.182582469, 'x0': 1584.893192},
"('facebook/opt-1.3b', 'gigaword')": {'c': 0.1162159713, 'a': 9.303401005, 'b': 0.1692837743, 'x0': 1584.893192},
"('facebook/opt-1.3b', 'wikiword')": {'c': 0.9604445677, 'a': 1.854993681, 'b': 0.1160700451, 'x0': 251.1886431},
"('facebook/opt-350m', 'flan')": {'c': 0.9725591969, 'a': 5.931693964, 'b': 0.1327214077, 'x0': 1584.893192},
"('facebook/opt-350m', 'gigaword')": {'c': 0, 'a': 12.90265068, 'b': 0.189106158, 'x0': 1584.893192},
"('facebook/opt-350m', 'wikiword')": {'c': 0.8957143243, 'a': 2.994444891, 'b': 0.1136282614, 'x0': 251.1886431},
"('facebook/opt-6.7b', 'flan')": {'c': 1.517869515, 'a': 0.9931489165, 'b': 0.1126229971, 'x0': 251.1886431},
"('facebook/opt-6.7b', 'gigaword')": {'c': 1.723087933, 'a': 7.567498663, 'b': 0.3652711781, 'x0': 10000},
"('facebook/opt-6.7b', 'wikiword')": {'c': 0.9976291701, 'a': 1.409723211, 'b': 0.1163785758, 'x0': 251.1886431},
"('google/mt5-base', 'flan')": {'c': 0.9909552392, 'a': 4.597987709, 'b': 0.1175021959, 'x0': 251.1886431},
"('google/mt5-base', 'gigaword')": {'c': 1.738144527, 'a': 3.077069792, 'b': 0.1368078084, 'x0': 1584.893192},
"('google/mt5-base', 'wikiword')": {'c': 0.2199407554, 'a': 5.398032721, 'b': 0.1182095443, 'x0': 251.1886431},
"('google/mt5-large', 'flan')": {'c': 1.434152687, 'a': 6.288948252, 'b': 0.2299434479, 'x0': 1584.893192},
"('google/mt5-large', 'gigaword')": {'c': 1.900891986, 'a': 23.65735527, 'b': 0.3546820123, 'x0': 10000},
"('google/mt5-large', 'wikiword')": {'c': 0.5457049924, 'a': 4.193880429, 'b': 0.1210634752, 'x0': 251.1886431},
"('gpt2', 'flan')": {'c': 1.490922453, 'a': 75.32900687, 'b': 0.3578463194, 'x0': 10000},
"('gpt2', 'gigaword')": {'c': 0.8476347336, 'a': 301.256379, 'b': 0.5329780156, 'x0': 10000},
"('gpt2', 'wikiword')": {'c': 1.203279469, 'a': 7.999141793, 'b': 0.2296289159, 'x0': 1584.893192},
"('t5-base', 'flan')": {'c': 0.9738756683, 'a': 4.331690322, 'b': 0.132221776, 'x0': 1584.893192},
"('t5-base', 'gigaword')": {'c': 0.4284101849, 'a': 1.864584477, 'b': 0.1734356468, 'x0': 6.30957344},
"('t5-base', 'wikiword')": {'c': 1.001935901, 'a': 3.053731805, 'b': 0.1909571108, 'x0': 1584.893192},
"('t5-small', 'flan')": {'c': 1.142977987, 'a': 4.981709663, 'b': 0.135803553, 'x0': 1584.893192},
"('t5-small', 'gigaword')": {'c': 0.5884135339, 'a': 2.844380432, 'b': 0.2363651475, 'x0': 251.1886431},
"('t5-small', 'wikiword')": {'c': 1.083657457, 'a': 4.090656387, 'b': 0.1910050169, 'x0': 1584.893192},
}
# Fallback parameters used by law() when the requested group key is not in
# _PARAMS.  NOTE(review): described as the mean across groups — the values
# are plausible averages of the table above, but the fit/averaging script
# is not visible here; confirm against its source if exactness matters.
_FALLBACK = {"c":0.9272686752852628,"a":74.88410175959388,"b":0.22081211328758815,"x0":3132.110732274286}
def _predict_one(n: float, p: dict[str, float]) -> float:
n = float(n)
c = float(p.get("c", 0.0))
a = float(p.get("a", 1.0))
b = float(p.get("b", 0.5))
x0 = float(p.get("x0", 0.0))
# Guard for non-positive n: treat as 0
if not math.isfinite(n) or n < 0:
n = 0.0
return c + a * (n + x0) ** (-b)
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Predict the output variable for each data point via the scaling law.

    The functional form ``sft_loss = c + a * (sft_data_size + x0) ** (-b)``
    is shared by all groups; only the fitted coefficients differ. Unknown
    groups fall back to the mean parameters in ``_FALLBACK``.

    Args:
        input_data: One dict per data point, mapping input variable names
            to values. The driver variable is 'sft_data_size'; the aliases
            'N' and 'n' are accepted, and a missing size defaults to 0.0.
        group: Name of the experimental group whose coefficients to use.

    Returns:
        One dict per input point, each with the predicted 'sft_loss'.
    """
    coeffs = _PARAMS.get(group, _FALLBACK)

    def _driver(point: dict[str, float]) -> float:
        # Prefer the canonical key; fall back to common aliases, then 0.0.
        size = point.get("sft_data_size")
        if size is None:
            size = point.get("N", point.get("n", 0.0))
        return size

    return [
        {"sft_loss": float(_predict_one(_driver(point), coeffs))}
        for point in input_data
    ]