# Auto-generated scaling law implementation
# U-shaped quadratic in log_flops: y = a * (log_flops - x0)**2 + c
from typing import List, Dict
# Per-group parameters fitted from /app/data
_PARAMS: Dict[str, Dict[str, float]] = {
    "mmlu": {
        "a": 0.011476264280523023,
        "x0": 2.7435075277399728,
        "c": -0.5667445812898367
    },
    "parsinlu_qa_mc": {
        "a": 0.0,
        "x0": 1.1106711713084738,
        "c": -0.434241496854291
    },
    "arithmetic": {
        "a": 0.0,
        "x0": 0.854807431011725,
        "c": -0.21644143686194878
    },
    "hindu_knowledge": {
        "a": 0.0,
        "x0": 0.15012816206281548,
        "c": -0.4266515719307887
    },
    "analogical_similarity": {
        "a": 0.0,
        "x0": 0.9563601615678085,
        "c": -0.5431939971887069
    },
    "conceptual_combinations": {
        "a": 0.0,
        "x0": 0.562905378864399,
        "c": -0.41087701400682625
    },
    "hellaswag": {
        "a": 0.0,
        "x0": 1.1106711713084738,
        "c": -0.02325897962914785
    },
    "arc": {
        "a": 0.0,
        "x0": 1.1106711713084738,
        "c": -0.04761015603522527
    },
    "abstract_narrative_understanding": {
        "a": 0.0,
        "x0": 1.1106711713084738,
        "c": -0.33990324185677395
    }
}
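# Worked example of the functional form (comment only, derived from the values
# above): for the "mmlu" group the law reads
#     brier_score = 0.011476264280523023 * (log_flops - 2.7435075277399728)**2 - 0.5667445812898367
# which is minimized at log_flops == x0 with value c. For the groups fitted
# with a == 0, the prediction reduces to the constant c, independent of log_flops.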
_DEF_GROUP = 'default'
def _predict_one(x: float, g: str) -> float:
    """Evaluate y = a * (x - x0)**2 + c using the parameters fitted for group g."""
    gkey = str(g)
    if gkey in _PARAMS:
        p = _PARAMS[gkey]
    elif _DEF_GROUP in _PARAMS:
        # Unknown group: fall back to the 'default' parameters if present.
        p = _PARAMS[_DEF_GROUP]
    elif _PARAMS:
        # Otherwise fall back to the mean of the known per-group parameters.
        n = len(_PARAMS)
        p = {
            'a': sum(d['a'] for d in _PARAMS.values()) / n,
            'x0': sum(d['x0'] for d in _PARAMS.values()) / n,
            'c': sum(d['c'] for d in _PARAMS.values()) / n,
        }
    else:
        p = {'a': 0.0, 'x0': 0.0, 'c': 0.0}
    return p['a'] * (x - p['x0']) ** 2 + p['c']
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values.
group: The name of the experimental group for which to make predictions.
The functional form of the law must be the same for all groups,
but the constant parameters/coefficients can differ per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s).
"""
    out = []
    for row in input_data:
        # Each data point must provide a 'log_flops' value; a missing key raises KeyError.
        x = float(row['log_flops'])
        yhat = _predict_one(x, group)
        out.append({'brier_score': float(yhat)})
    return out
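# Minimal usage sketch (illustrative only): the 'log_flops' values below are
# made-up demo inputs, not data from /app/data; they simply exercise law()
# for one of the fitted groups.
if __name__ == "__main__":
    _demo_points = [{"log_flops": 1.0}, {"log_flops": 2.0}, {"log_flops": 3.0}]
    for _point, _pred in zip(_demo_points, law(_demo_points, "mmlu")):
        print(_point["log_flops"], "->", _pred["brier_score"])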