# Auto-generated scaling law implementation
# Model form: brier_score = a + b * log_flops + c * (log_flops ** 2)
from typing import List, Dict
# Quadratic coefficients (a, b, c) fitted separately for each group, plugged
# into: brier_score = a + b * log_flops + c * (log_flops ** 2).
COEFFS: dict[str, tuple[float, float, float]] = {
    'abstract_narrative_understanding': (
        -0.543140714074,
        0.184726990056,
        -0.00100209571897,
    ),
    'analogical_similarity': (
        -0.540575053774,
        0.0279112874835,
        -0.0191758796727,
    ),
    'arc': (
        -0.107112232715,
        0.117619490399,
        -0.0368682063937,
    ),
    'arithmetic': (
        -0.247532677712,
        0.235370097975,
        -0.129978149629,
    ),
    'conceptual_combinations': (
        -0.409345543131,
        0.0969259552286,
        -0.0714835670647,
    ),
    'hellaswag': (
        -0.0671968615465,
        0.0980514543495,
        -0.0336706457557,
    ),
    'hindu_knowledge': (
        -0.410317419378,
        -0.0311435105549,
        -0.0344023889601,
    ),
    'mmlu': (
        -0.48036465022,
        -0.0629704348879,
        0.0114762642805,
    ),
    'parsinlu_qa_mc': (
        -0.434950718068,
        0.0989058373264,
        -0.0565673953741,
    ),
}

# Coefficients (a, b, c) used when a requested group has no entry in COEFFS.
GLOBAL_COEFFS: tuple[float, float, float] = (
    -0.378439693837,
    0.0773755683686,
    0.00264467324727,
)
def _predict_single(x: float, coeffs: tuple[float, float, float]) -> float:
a, b, c = coeffs
return a + b * x + c * (x ** 2)
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Apply the fitted quadratic scaling law to a batch of data points.

    Each prediction is computed from the row's ``'log_flops'`` value using
    the coefficients registered for ``group``. A group with no entry in
    ``COEFFS`` is evaluated with ``GLOBAL_COEFFS`` instead.

    Args:
        input_data: The data points to score, one dictionary per point. Every
            dictionary must have a ``'log_flops'`` key; its value is
            converted with ``float()``.
        group: Name of the experimental group whose coefficients are used.

    Returns:
        One dictionary per input row, in input order, each of the form
        ``{'brier_score': <prediction>}``.

    Raises:
        ValueError: If a row has no ``'log_flops'`` key.
    """
    # Resolve the coefficient triple once, before any row is examined.
    params = COEFFS.get(group, GLOBAL_COEFFS)

    def _score(point: dict[str, float]) -> dict[str, float]:
        # Validate a single row, then evaluate the law on it.
        if 'log_flops' not in point:
            raise ValueError("Each input row must contain 'log_flops'.")
        log_flops = float(point['log_flops'])
        return {'brier_score': float(_predict_single(log_flops, params))}

    return [_score(point) for point in input_data]