from typing import List, Dict
# Quadratic scaling law in log_flops:
#   brier_score_hat = a_g * x^2 + b_g * x + c_g
# where x = log_flops and (a_g, b_g, c_g) depend on the group g.
# The curve is concave (inverted U) when a_g < 0 and convex (U-shaped) when a_g > 0.
# Coefficients were fitted by least squares (numpy.polyfit with deg=2)
# on the provided dataset at /app/data.
# Per-group coefficients of the quadratic in log_flops.
# Each row is (group, a, b, c) and is expanded into {"a": a, "b": b, "c": c};
# the resulting mapping keeps the row order below.
COEFFS: Dict[str, Dict[str, float]] = {
    group_name: {"a": a, "b": b, "c": c}
    for group_name, a, b, c in (
        ("abstract_narrative_understanding",
         -0.001002095718967912, 0.18472699005645873, -0.5431407140744655),
        ("analogical_similarity",
         -0.019175879672698435, 0.0279112874834725, -0.5405750537735581),
        ("arc",
         -0.036868206393668744, 0.11761949039897288, -0.1071122327154294),
        ("arithmetic",
         -0.12997814962868387, 0.23537009797522832, -0.2475326777122078),
        ("conceptual_combinations",
         -0.07148356706471508, 0.09692595522861085, -0.40934554313141813),
        ("hellaswag",
         -0.033670645755682356, 0.09805145434945438, -0.06719686154646047),
        ("hindu_knowledge",
         -0.034402388960081354, -0.031143510554884814, -0.4103174193780911),
        ("mmlu",
         0.011476264280523694, -0.06297043488789662, -0.480364650219835),
        ("parsinlu_qa_mc",
         -0.05656739537407183, 0.0989058373264011, -0.43495071806820146),
    )
}
# Fallback coefficients, used for any group that is not a key of COEFFS.
GLOBAL_COEFFS = dict(
    a=0.002644673247271387,
    b=0.07737556836857276,
    c=-0.3784396938370407,
)
def _predict(log_flops: float, coeffs: Dict[str, float]) -> float:
a = coeffs["a"]
b = coeffs["b"]
c = coeffs["c"]
return a * (log_flops ** 2) + b * log_flops + c
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict the Brier score of each data point from its log-FLOPs budget.

    The prediction is the quadratic ``a * x**2 + b * x + c`` evaluated at
    ``x = float(point["log_flops"])``, with ``(a, b, c)`` taken from
    ``COEFFS[group]``. A group that is not a key of ``COEFFS`` silently
    falls back to ``GLOBAL_COEFFS``.

    Args:
        input_data: Data points, one dict per point. Every dict must contain
            the key ``"log_flops"``; other keys are ignored.
        group: Name of the experimental group whose coefficients are used.
            The functional form is shared by all groups; only the constant
            coefficients differ per group.

    Returns:
        A list aligned index-for-index with ``input_data``; each element is
        ``{"brier_score": <predicted float>}``.

    Raises:
        KeyError: If a data point has no ``"log_flops"`` key. Points are
            processed in order, so the first offending point triggers it.
    """
    selected = COEFFS.get(group, GLOBAL_COEFFS)

    def _score_point(point: Dict[str, float]) -> Dict[str, float]:
        # Membership test (rather than try/except on indexing) so that
        # mappings with a default factory still count as "missing".
        if "log_flops" in point:
            budget = float(point["log_flops"])  # computation budget, log scale
            return {"brier_score": float(_predict(budget, selected))}
        raise KeyError("Each input item must include 'log_flops'.")

    return [_score_point(point) for point in input_data]