from typing import List, Dict
# Discovered U-shaped scaling law (per-group parameters)
# Functional form: brier_score = d + a * (log_flops - c)**2, with a >= 0
# Per-group coefficients fitted via constrained non-linear least squares
# Per-group fitted parameters, written compactly as (group, a, c, d) rows and
# expanded below into the {"a": ..., "c": ..., "d": ...} mapping read by `law`.
# NOTE(review): most groups have c == -10.0 or c == 10.0 exactly, which looks
# like the vertex landed on a bound of the constrained fit rather than at a
# free optimum; inside the observed log_flops range those curves may be
# effectively monotonic. Confirm against the fitting script.
_GROUP_FITS = (
    ("abstract_narrative_understanding", 0.008467468113573256, -10.0, -1.3910788665280118),
    ("analogical_similarity", 0.00013496608274505986, -10.0, -0.5594784378689742),
    ("arc", 0.002670618953008418, -10.0, -0.3791484043679726),
    ("arithmetic", 0.003889937618786338, -10.0, -0.6770289854837191),
    ("conceptual_combinations", 0.002411414906857867, -10.0, -0.6810586583653832),
    ("hellaswag", 0.0020058213876494877, -10.0, -0.27226736346486907),
    ("hindu_knowledge", 0.0015094214985086043, 10.0, -0.5735731165193414),
    ("mmlu", 0.01147626953954232, 2.7435063200062424, -0.5667445498636151),
    ("parsinlu_qa_mc", 0.00032584434240010923, -10.0, -0.4746927422057731),
)
COEFFS: Dict[str, Dict[str, float]] = {
    name: {"a": a, "c": c, "d": d} for name, a, c, d in _GROUP_FITS
}

# Coefficients used for any group not listed in COEFFS (pooled fit across all
# groups, per the original annotation).
DEFAULT_COEFFS = dict(
    a=0.023678322162278612,
    c=-0.8996294548824371,
    d=-0.3978978627033718,
)
def _get_log_flops(row: Dict[str, float]) -> float:
if "log_flops" in row and row["log_flops"] is not None:
return float(row["log_flops"])
# Optional fallback if only raw flops are provided
flops = row.get("flops", None)
if flops is not None and flops > 0:
# log10(flops)
import math
return float(math.log10(flops))
return 0.0
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """Predict ``brier_score`` for each row using the U-shaped scaling law.

    Each prediction is ``d + a * (log_flops - c) ** 2``. The functional form
    is shared by every group; only the coefficients ``a``, ``c`` and ``d``
    differ. They are taken from ``COEFFS[group]``; a group that is not listed
    there silently uses ``DEFAULT_COEFFS`` instead.

    Args:
        input_data: One dictionary per data point, mapping input variable
            names to values. The ``log_flops`` feature is read per row via
            ``_get_log_flops`` (which may derive it from ``flops`` or default
            it to ``0.0``).
        group: Name of the experimental group whose coefficients to use.

    Returns:
        A list aligned with ``input_data``; each element is a new dictionary
        of the form ``{"brier_score": <float>}``.
    """
    coeffs = COEFFS.get(group, DEFAULT_COEFFS)
    a = float(coeffs["a"])
    c = float(coeffs["c"])
    d = float(coeffs["d"])

    def _predict(row: Dict[str, float]) -> Dict[str, float]:
        # Offset from the parabola's vertex at log_flops == c; squared below.
        shift = _get_log_flops(row) - c
        return {"brier_score": float(d + a * shift ** 2)}

    return [_predict(row) for row in input_data]