from __future__ import annotations
import math
# Coefficients fitted on the provided dataset at /app/data
# We use a single functional form for all groups (vertex form of a quadratic):
# brier_score = alpha[g] + beta[g] * (log_flops - x0[g])**2
# where (alpha, beta, x0) are constants that depend on the group.
# These were obtained via least-squares fitting; equivalently, they can be
# derived from the polynomial coefficients c2, c1, c0 via beta = c2,
# x0 = -c1/(2*c2), alpha = y(x0) (see the poly_to_vertex helper below).
PARAMS: dict[str, dict[str, float]] = {
# alpha, beta, x0 per group
"abstract_narrative_understanding": {
"alpha": 7.970033279539732,
"beta": -0.001002095718967912,
"x0": 92.1703319153556,
},
"analogical_similarity": {
"alpha": -0.530418544889976,
"beta": -0.019175879672698435,
"x0": 0.7277707192544356,
},
"arc": {
"alpha": -0.013302783633511228,
"beta": -0.036868206393668744,
"x0": 1.5951344248085157,
},
"arithmetic": {
"alpha": -0.1409780699201859,
"beta": -0.12997814962868387,
"x0": 0.9054217906918345,
},
"conceptual_combinations": {
"alpha": -0.3764895974128952,
"beta": -0.07148356706471508,
"x0": 0.6779596990512687,
},
"hellaswag": {
"alpha": 0.0041864419532996605,
"beta": -0.033670645755682356,
"x0": 1.4560376278632394,
},
"hindu_knowledge": {
"alpha": -0.4032690843714171,
"beta": -0.034402388960081354,
"x0": -0.4526358705934933,
},
"mmlu": {
"alpha": -0.5667445812898321,
"beta": 0.011476264280523694,
"x0": 2.7435075277398155,
},
"parsinlu_qa_mc": {
"alpha": -0.39171748895915737,
"beta": -0.05656739537407183,
"x0": 0.8742300814130773,
},
}
# Fallback global fit (across all groups), used if an unknown group name is passed.
GLOBAL_FALLBACK: dict[str, float] = {
    "c2": 0.002644673247271387,
    "c1": 0.07737556836857276,
    "c0": -0.3784396938370407,
}
def _predict_vertex(log_flops: float, params: dict[str, float]) -> float:
"""Evaluate vertex-form quadratic with given parameters."""
return params["alpha"] + params["beta"] * (log_flops - params["x0"]) ** 2
def _predict_poly(log_flops: float, c2: float, c1: float, c0: float) -> float:
"""Evaluate standard polynomial form y = c2*x^2 + c1*x + c0."""
return c2 * log_flops * log_flops + c1 * log_flops + c0
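
# Illustrative helper, not used by law() itself: it just spells out the
# polynomial-to-vertex conversion described in the comment at the top of this
# module. Assumes c2 != 0. For example, poly_to_vertex(**GLOBAL_FALLBACK)
# yields a vertex-form equivalent of the global fallback fit.
def poly_to_vertex(c2: float, c1: float, c0: float) -> dict[str, float]:
    """Convert y = c2*x**2 + c1*x + c0 into y = alpha + beta*(x - x0)**2."""
    x0 = -c1 / (2.0 * c2)                # vertex location
    alpha = c2 * x0 * x0 + c1 * x0 + c0  # polynomial value at the vertex
    return {"alpha": alpha, "beta": c2, "x0": x0}
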
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values.
group: The name of the experimental group for which to make predictions.
The functional form of the law must be the same for all groups,
but the constant parameters/coefficients can differ per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s).
"""
    results: list[dict[str, float]] = []
# Resolve coefficients for the requested group (or use global fallback)
params = PARAMS.get(group)
use_fallback = params is None
for row in input_data:
if "log_flops" in row and isinstance(row["log_flops"], (int, float)):
x = float(row["log_flops"]) # primary expected input
elif "flops" in row and isinstance(row["flops"], (int, float)):
# Best-effort fallback: approximate log_flops as base-10 logarithm of flops.
# (This branch is only used if downstream callers provide 'flops' instead
# of 'log_flops'. The training data already supplies 'log_flops'.)
x = math.log10(float(row["flops"])) if row["flops"] > 0 else float("nan")
else:
x = float("nan")
        if not math.isfinite(x):
            y = float("nan")
        elif use_fallback:
            y = _predict_poly(x, GLOBAL_FALLBACK["c2"], GLOBAL_FALLBACK["c1"], GLOBAL_FALLBACK["c0"])
        else:
            y = _predict_vertex(x, params)
results.append({"brier_score": float(y)})
return results
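
if __name__ == "__main__":
    # Minimal usage sketch; the input values below are invented for
    # illustration and do not come from the dataset at /app/data.
    sample = [{"log_flops": 1.0}, {"log_flops": 2.5}, {"flops": 1e12}]
    print(law(sample, "arithmetic"))     # known group: per-group vertex-form fit
    print(law(sample, "unknown_group"))  # unknown group: global polynomial fallback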