# Autogenerated scaling law implementation
from __future__ import annotations
from typing import List, Dict
import math
import ast
# Parameters fitted from training data
# Per-group coefficients. Each key is the string form of a (model, dataset)
# tuple, e.g. "('gpt2', 'flan')"; _select_params() tries the `group` argument
# against these keys with a plain string match before using any fallback.
# Each entry supplies the coefficients of L(N) = L_inf + A * N^(-alpha) that
# _predict_loss() reads: "A", "L_inf" and "alpha".
# "sse" is not read by any code in this file; presumably it is the sum of
# squared errors of the fit, kept for reference -- TODO confirm with the
# generator.
# NOTE(review): a few 'gigaword' entries have L_inf == 0.0 exactly; this looks
# like a fit that stopped at a lower bound -- confirm.
_PARAMS_BY_GROUP: Dict[str, Dict[str, float]] = {
"('MBZUAI/LaMini-GPT-124M', 'flan')": {
"A": 6.551027034471107,
"L_inf": 1.0512777381812202,
"alpha": 0.11745762711864405,
"sse": 0.4808298153556973
},
"('MBZUAI/LaMini-GPT-124M', 'gigaword')": {
"A": 5.765933611487401,
"L_inf": 0.10943580893255611,
"alpha": 0.10745762711864404,
"sse": 0.6245798014382321
},
"('MBZUAI/LaMini-GPT-124M', 'wikiword')": {
"A": 3.376395921350376,
"L_inf": 0.5548627966954982,
"alpha": 0.08563389830508474,
"sse": 0.022893087373991938
},
"('MBZUAI/LaMini-GPT-774M', 'flan')": {
"A": 4.952993658538474,
"L_inf": 0.7870052565309618,
"alpha": 0.09908474576271184,
"sse": 0.2400542702584885
},
"('MBZUAI/LaMini-GPT-774M', 'gigaword')": {
"A": 5.263660623382425,
"L_inf": 0.0,
"alpha": 0.10745762711864404,
"sse": 0.44323280410802207
},
"('MBZUAI/LaMini-GPT-774M', 'wikiword')": {
"A": 2.5700770849364734,
"L_inf": 0.40185298559482835,
"alpha": 0.07011457627118642,
"sse": 0.002812062716955058
},
"('cerebras/Cerebras-GPT-1.3B', 'flan')": {
"A": 3.0671660328777723,
"L_inf": 0.8665890220619048,
"alpha": 0.08167457627118642,
"sse": 0.01832322318723353
},
"('cerebras/Cerebras-GPT-1.3B', 'gigaword')": {
"A": 4.476965914016248,
"L_inf": 0.21586500856117852,
"alpha": 0.09885593220338983,
"sse": 0.09055411163147822
},
"('cerebras/Cerebras-GPT-1.3B', 'wikiword')": {
"A": 2.6958514343065025,
"L_inf": 0.5775014790256352,
"alpha": 0.07011457627118642,
"sse": 0.011516760278950221
},
"('cerebras/Cerebras-GPT-256M', 'flan')": {
"A": 3.634706072945405,
"L_inf": 1.1111855314883685,
"alpha": 0.08029016949152541,
"sse": 0.07819663306349012
},
"('cerebras/Cerebras-GPT-256M', 'gigaword')": {
"A": 5.243225803366305,
"L_inf": 0.12929230405004832,
"alpha": 0.10745762711864404,
"sse": 0.22934069038976262
},
"('cerebras/Cerebras-GPT-256M', 'wikiword')": {
"A": 4.028614188892544,
"L_inf": 0.724877093134493,
"alpha": 0.09885593220338983,
"sse": 0.01833369729089268
},
"('facebook/bart-base', 'flan')": {
"A": 6.254726539846071,
"L_inf": 0.9510091039471076,
"alpha": 0.11694915254237281,
"sse": 0.2540788846983032
},
"('facebook/bart-base', 'gigaword')": {
"A": 7.809486898113445,
"L_inf": 0.0,
"alpha": 0.13694915254237283,
"sse": 0.6790616782057259
},
"('facebook/bart-base', 'wikiword')": {
"A": 5.848232850566735,
"L_inf": 0.5101285027743712,
"alpha": 0.13593220338983047,
"sse": 0.046485498566010953
},
"('facebook/bart-large', 'flan')": {
"A": 4.72749327185259,
"L_inf": 0.813263285581497,
"alpha": 0.10745762711864404,
"sse": 0.06443381176516555
},
"('facebook/bart-large', 'gigaword')": {
"A": 7.9111202924205815,
"L_inf": 0.0,
"alpha": 0.14694915254237284,
"sse": 0.5847511268238709
},
"('facebook/bart-large', 'wikiword')": {
"A": 2.6358531287238205,
"L_inf": 0.7976718622793141,
"alpha": 0.11796610169491523,
"sse": 0.0008035057119018064
},
"('facebook/opt-1.3b', 'flan')": {
"A": 2.690458418113715,
"L_inf": 0.6317492094011479,
"alpha": 0.07011457627118642,
"sse": 0.00896705369896697
},
"('facebook/opt-1.3b', 'gigaword')": {
"A": 5.055017104788688,
"L_inf": 0.06216071258826933,
"alpha": 0.10745762711864405,
"sse": 0.15896734156520373
},
"('facebook/opt-1.3b', 'wikiword')": {
"A": 1.9989082093584953,
"L_inf": 0.3883783962588203,
"alpha": 0.05463796610169492,
"sse": 0.0005966937577908683
},
"('facebook/opt-350m', 'flan')": {
"A": 3.7735300303987582,
"L_inf": 0.9038808121467481,
"alpha": 0.08563389830508474,
"sse": 0.08866213184912522
},
"('facebook/opt-350m', 'gigaword')": {
"A": 6.633179880504192,
"L_inf": 0.01860470970974132,
"alpha": 0.12694915254237282,
"sse": 0.349162962231544
},
"('facebook/opt-350m', 'wikiword')": {
"A": 2.7612543989445277,
"L_inf": 0.6360779078248159,
"alpha": 0.08167457627118642,
"sse": 0.0027850673476467027
},
"('facebook/opt-6.7b', 'flan')": {
"A": 1.365125617665658,
"L_inf": 0.8999059285884865,
"alpha": 0.03699364406779661,
"sse": 0.0004314444689843593
},
"('facebook/opt-6.7b', 'gigaword')": {
"A": 1.2787823367249933,
"L_inf": 0.9146819209297886,
"alpha": 0.027880000000000002,
"sse": 0.0034799487006857556
},
"('facebook/opt-6.7b', 'wikiword')": {
"A": 1.6558332978869559,
"L_inf": 0.408940745237023,
"alpha": 0.04719830508474577,
"sse": 0.0002784725035281276
},
"('google/mt5-base', 'flan')": {
"A": 4.023037948558858,
"L_inf": 0.9088725993825086,
"alpha": 0.09885593220338983,
"sse": 0.02482242593421794
},
"('google/mt5-base', 'gigaword')": {
"A": 2.287289667115906,
"L_inf": 1.2250341225298993,
"alpha": 0.058017288135593224,
"sse": 0.013476308002459835
},
"('google/mt5-base', 'wikiword')": {
"A": 4.673677233571926,
"L_inf": 0.2712846154423211,
"alpha": 0.10745762711864405,
"sse": 0.035593744683382535
},
"('google/mt5-large', 'flan')": {
"A": 2.932396661171257,
"L_inf": 0.7555303504825739,
"alpha": 0.08167457627118642,
"sse": 0.010985283372535848
},
"('google/mt5-large', 'gigaword')": {
"A": 2.5411857996356715,
"L_inf": 1.1479844538474107,
"alpha": 0.06697525423728813,
"sse": 0.039205759001219365
},
"('google/mt5-large', 'wikiword')": {
"A": 3.736579330087355,
"L_inf": 0.39863249003965173,
"alpha": 0.09885593220338983,
"sse": 0.005007817242997015
},
"('gpt2', 'flan')": {
"A": 6.712728732449793,
"L_inf": 1.0928341966172603,
"alpha": 0.11745762711864405,
"sse": 0.5785501985298458
},
"('gpt2', 'gigaword')": {
"A": 6.245576858034009,
"L_inf": 0.030180644744962015,
"alpha": 0.11694915254237281,
"sse": 0.4917946060647004
},
"('gpt2', 'wikiword')": {
"A": 3.5477422717951628,
"L_inf": 0.6678701240663837,
"alpha": 0.09885593220338983,
"sse": 0.023337582935558168
},
"('t5-base', 'flan')": {
"A": 2.9579816171965967,
"L_inf": 0.6552739441287245,
"alpha": 0.07150389830508473,
"sse": 0.024602104806572565
},
"('t5-base', 'gigaword')": {
"A": 1.8289527080465815,
"L_inf": 0.41334140987512596,
"alpha": 0.16694915254237283,
"sse": 0.002126327534880776
},
"('t5-base', 'wikiword')": {
"A": 2.020115701049203,
"L_inf": 0.27566422140104585,
"alpha": 0.05463796610169492,
"sse": 0.002761225698689312
},
"('t5-small', 'flan')": {
"A": 3.3885713791993206,
"L_inf": 0.8860820725336442,
"alpha": 0.08167457627118642,
"sse": 0.026660987451640007
},
"('t5-small', 'gigaword')": {
"A": 1.7756175381171875,
"L_inf": 0.41731462928597385,
"alpha": 0.13796610169491524,
"sse": 0.002043258408125567
},
"('t5-small', 'wikiword')": {
"A": 2.4527737444128093,
"L_inf": 0.41988512499871733,
"alpha": 0.06697525423728813,
"sse": 0.00958337967879773
}
}
# Dataset-level fallback coefficients, keyed by dataset name.
# _select_params() uses this table when the full group string has no entry in
# _PARAMS_BY_GROUP but parses to a tuple/list whose second element is one of
# these keys.
# Entries have the same fields as the per-group table ("A", "L_inf", "alpha",
# plus "sse", which no code in this file reads).
# NOTE(review): presumably each entry was fitted on data pooled over all
# models for that dataset -- confirm with the generator.
_PARAMS_BY_DATASET: Dict[str, Dict[str, float]] = {
"flan": {
"A": 4.215660115624637,
"L_inf": 0.571749209401148,
"alpha": 0.08029016949152541,
"sse": 43.25799096145912
},
"gigaword": {
"A": 4.64950666811947,
"L_inf": 0.0,
"alpha": 0.09333599576271186,
"sse": 54.41332124583563
},
"wikiword": {
"A": 3.219329078800361,
"L_inf": 0.20103710275697806,
"alpha": 0.07011457627118642,
"sse": 15.138361813331931
}
}
# Last-resort coefficients: _select_params() returns this dict when the group
# matches neither a per-group entry nor a dataset-level entry.
# Same fields as the other tables; "sse" is not read by any code in this file.
# NOTE(review): presumably a single fit over all groups -- confirm.
_GLOBAL_PARAMS: Dict[str, float] = {
"A": 4.1350908249146565,
"L_inf": 0.12178898498171782,
"alpha": 0.07745491525423727,
"sse": 164.34461735246174
}
def _select_params(group: str) -> Dict[str, float]:
    """Pick the coefficient set to use for ``group``.

    Resolution order:
      1. ``group`` used verbatim as a key of ``_PARAMS_BY_GROUP``.
      2. ``group`` parsed as a Python tuple/list literal and re-rendered in
         canonical ``str(tuple)`` form, so that spacing or quoting
         differences (``"('gpt2','flan')"``, ``'("gpt2", "flan")'``) still
         reach the per-group fit.
      3. The second element of the parsed literal used as a key of
         ``_PARAMS_BY_DATASET``.
      4. ``_GLOBAL_PARAMS``.

    Args:
        group: Name of the experimental group, normally the string form of a
            ``(model, dataset)`` tuple.

    Returns:
        The selected parameter dict (the stored object, not a copy).
    """
    # Exact key match is the common case and needs no parsing.
    if group in _PARAMS_BY_GROUP:
        return _PARAMS_BY_GROUP[group]

    # Keep the try body to the parse call only, and catch just the errors
    # ast.literal_eval documents for malformed input, so unrelated bugs are
    # not silently swallowed.
    try:
        parsed = ast.literal_eval(group)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return _GLOBAL_PARAMS

    if not isinstance(parsed, (list, tuple)) or len(parsed) < 2:
        return _GLOBAL_PARAMS

    # Same tuple written with different spacing/quotes: normalise and retry
    # the per-group lookup before degrading to a coarser fit.
    canonical = str(tuple(parsed))
    if canonical in _PARAMS_BY_GROUP:
        return _PARAMS_BY_GROUP[canonical]

    # Dataset-level fallback. The isinstance check also keeps unhashable
    # values (e.g. a nested list) away from the dict membership test.
    dataset = parsed[1]
    if isinstance(dataset, str) and dataset in _PARAMS_BY_DATASET:
        return _PARAMS_BY_DATASET[dataset]

    return _GLOBAL_PARAMS
def _predict_loss(n: float, params: Dict[str, float]) -> float:
    """Evaluate the scaling law L(N) = L_inf + A * N^(-alpha) at one point.

    Args:
        n: Value of N. Anything ``float()`` accepts is allowed.
        params: Mapping providing the 'L_inf', 'A' and 'alpha' coefficients.

    Returns:
        The predicted loss, or NaN when ``n`` is None, cannot be converted to
        a float, is not finite, or is not strictly positive.

    Raises:
        KeyError: If ``params`` lacks one of the three coefficient keys.
    """
    L_inf = float(params['L_inf'])
    A = float(params['A'])
    alpha = float(params['alpha'])
    # An unusable N (None from a missing key, a non-numeric string, an int too
    # large for a float) is reported as NaN, the same way as the non-positive
    # and non-finite cases below, instead of raising from float().
    try:
        n = float(n)
    except (TypeError, ValueError, OverflowError):
        return float('nan')
    # The power law is only defined for finite, strictly positive N.
    if not math.isfinite(n) or n <= 0:
        return float('nan')
    return L_inf + A * (n ** (-alpha))
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Apply the fitted scaling law to a batch of data points.

    One coefficient set is chosen for the whole batch via ``_select_params``;
    every point is then evaluated with ``_predict_loss``.

    Args:
        input_data: List of data points. Each point is a dict; its
                    'sft_data_size' entry (looked up with ``.get``) is the
                    input to the law.
        group: Name of the experimental group to predict for. The functional
               form is shared by all groups; only the coefficients differ.

    Returns:
        A list with one dict per input point, in the same order, each of the
        form ``{'sft_loss': <predicted value>}``.
    """
    coefficients = _select_params(group)
    return [
        {'sft_loss': float(_predict_loss(point.get('sft_data_size'), coefficients))}
        for point in input_data
    ]