from typing import List, Dict
# Fitted constants of the per-group quadratic scaling law
#
#     brier_score = a * x**2 + b * x + c
#
# where x is the `log_flops` input read by `law`. Each entry maps an
# experimental group name to its (a, b, c) triple, in that order.
_COEFS: Dict[str, tuple[float, float, float]] = {
    "mmlu": (0.011476264280523694, -0.06297043488789662, -0.480364650219835),
    "parsinlu_qa_mc": (-0.05656739537407183, 0.0989058373264011, -0.43495071806820146),
    "arithmetic": (-0.12997814962868387, 0.23537009797522832, -0.2475326777122078),
    "hindu_knowledge": (-0.034402388960081354, -0.031143510554884814, -0.4103174193780911),
    "analogical_similarity": (-0.019175879672698435, 0.0279112874834725, -0.5405750537735581),
    "conceptual_combinations": (-0.07148356706471508, 0.09692595522861085, -0.40934554313141813),
    "hellaswag": (-0.033670645755682356, 0.09805145434945438, -0.06719686154646047),
    "arc": (-0.036868206393668744, 0.11761949039897288, -0.1071122327154294),
    "abstract_narrative_understanding": (-0.001002095718967912, 0.18472699005645873, -0.5431407140744655),
}
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict the Brier score from log-FLOPs with a per-group quadratic law.

    The functional form is the same for every group::

        brier_score = a * x**2 + b * x + c,    x = log_flops

    Only the constants ``(a, b, c)`` differ per group; they are looked up in
    the module-level ``_COEFS`` table.

    Args:
        input_data: A list of dictionaries, one per data point. Each
            dictionary must contain the key ``'log_flops'`` with a value
            that ``float()`` accepts. Any other keys are ignored.
        group: The name of the experimental group for which to make
            predictions. Must be a key of ``_COEFS``.

    Returns:
        A list of dictionaries in the same order as ``input_data``, each of
        the form ``{'brier_score': <predicted value>}``. An empty
        ``input_data`` yields an empty list.

    Raises:
        ValueError: If ``group`` has no entry in ``_COEFS``.
        KeyError: If a data point has no ``'log_flops'`` key.
    """
    # Retrieve group-specific coefficients.
    try:
        a, b, c = _COEFS[group]
    except KeyError:
        # The KeyError adds nothing beyond the message below, so suppress the
        # implicit exception context to keep the traceback to a single error.
        raise ValueError(f"Unknown group: {group}") from None

    # Convert every input up front so malformed points fail before any
    # prediction is produced.
    xs = [float(point['log_flops']) for point in input_data]

    # Evaluate the quadratic for each point.
    return [{'brier_score': a * x * x + b * x + c} for x in xs]