from __future__ import annotations
from typing import Dict, List
# Quadratic (parabolic) scaling law per group:
# brier_score = A_g * (log_flops)**2 + B_g * (log_flops) + C_g
# The same functional form is used across groups; only (A_g, B_g, C_g) differ.
# Per-group coefficients of the quadratic law
#     brier_score = A * log_flops**2 + B * log_flops + C
# Each row below is (group name, A, B, C); the values were fitted via ordinary
# least squares on the provided dataset. The comprehension expands every row
# into the {"A": ..., "B": ..., "C": ...} mapping keyed by group name.
_PARAMS: Dict[str, Dict[str, float]] = {
    group_name: {"A": quad, "B": lin, "C": const}
    for group_name, quad, lin, const in (
        ("arc",
         -0.03686820639366876, 0.1176194903989729, -0.10711223271542945),
        ("conceptual_combinations",
         -0.07148356706471536, 0.09692595522861094, -0.40934554313141797),
        ("abstract_narrative_understanding",
         -0.001002095718968019, 0.18472699005645857, -0.5431407140744654),
        ("arithmetic",
         -0.12997814962868384, 0.2353700979752282, -0.24753267771220774),
        ("parsinlu_qa_mc",
         -0.05656739537407183, 0.09890583732640096, -0.43495071806820157),
        ("hellaswag",
         -0.033670645755682356, 0.09805145434945439, -0.06719686154646048),
        ("analogical_similarity",
         -0.019175879672698144, 0.02791128748347238, -0.540575053773558),
        ("mmlu",
         0.011476264280523023, -0.06297043488789655, -0.48036465021983477),
        ("hindu_knowledge",
         -0.03440238896008094, -0.031143510554884568, -0.41031741937809096),
    )
}
# Global fallback (in case of an unseen group)
# Fallback (A, B, C) coefficients, in the same shape as one per-group entry.
_GLOBAL: Dict[str, float] = dict(
    A=0.0026446732472713928,
    B=0.07737556836857278,
    C=-0.3784396938370408,
)
def _predict_single(log_flops: float, coeffs: Dict[str, float]) -> float:
    """Evaluate the quadratic A * x**2 + B * x + C at x = log_flops.

    Args:
        log_flops: The point at which to evaluate the polynomial.
        coeffs: Mapping that provides the coefficients under keys
            "A" (quadratic), "B" (linear) and "C" (constant).

    Returns:
        The value of the polynomial at ``log_flops``.
    """
    # Terms are formed and summed left to right (quadratic, linear, constant)
    # so the floating-point rounding of the sum stays fixed.
    quadratic_term = coeffs["A"] * (log_flops ** 2)
    linear_term = coeffs["B"] * log_flops
    return quadratic_term + linear_term + coeffs["C"]
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """Predict ``brier_score`` for each data point with the group's quadratic law.

    The functional form (quadratic in ``log_flops``) is shared by every group;
    only the coefficients differ.

    Args:
        input_data: Data points to score, one dictionary per point. Each
            dictionary must contain the key 'log_flops'; its value is
            converted with ``float()`` before use.
        group: Name of the experimental group whose coefficients are used.
            A name with no entry in ``_PARAMS`` falls back to the ``_GLOBAL``
            coefficients.

    Returns:
        A list aligned with ``input_data``; each element has the form
        ``{'brier_score': float}``.

    Raises:
        KeyError: If any data point lacks the 'log_flops' key.
    """
    fitted = _PARAMS.get(group, _GLOBAL)

    def _score(point: Dict[str, float]) -> float:
        # Guard first so a malformed point fails with the explicit message.
        if "log_flops" in point:
            return float(_predict_single(float(point["log_flops"]), fitted))
        raise KeyError("Each input row must include 'log_flops'.")

    return [{"brier_score": _score(point)} for point in input_data]