from __future__ import annotations
from math import log
from typing import Dict, List

# Discovered scaling law:
# Each output loss_domain_i is modeled as a linear function of the logarithms
# of all five domain proportions (a tiny epsilon keeps the log finite when a
# proportion is zero):
#     \hat{L}_{g,i} = bias_{g,i} + sum_j w_{g,i,j} * log(p_j + eps)
# The functional form is fixed across groups; the parameters (biases and
# weights) depend on the group g. Coefficients were fit by ordinary least
# squares on the provided dataset.
EPS = 1e-12
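

# Illustrative sketch (not used at prediction time): coefficients like those
# below could be re-fit by ordinary least squares as described above. The
# helper name, the row schema (proportion_domain_1..5 inputs and
# loss_domain_1..5 targets), and the numpy dependency are assumptions of this
# example rather than requirements of the module.
def _fit_ols_example(rows: List[Dict[str, float]], domain: int) -> Dict[str, float]:
    """Re-fit one domain's coefficients by OLS (illustrative only)."""
    import numpy as np  # local import: only this sketch needs numpy

    # Design matrix: one row of [1, log(p_1+eps), ..., log(p_5+eps)] per point.
    X = np.array(
        [[1.0] + [log(float(r[f"proportion_domain_{j}"]) + EPS) for j in range(1, 6)] for r in rows]
    )
    y = np.array([float(r[f"loss_domain_{domain}"]) for r in rows])
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    return {"bias": float(beta[0]), **{f"log_p{j}": float(beta[j]) for j in range(1, 6)}}

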
# Per-group, per-output coefficients learned on the provided dataset.
# Keys: group -> domain_{i} -> {bias, log_p1..log_p5}
COEFFS: Dict[str, Dict[str, Dict[str, float]]] = {
"160M": {
"domain_1": {
"bias": 2.4422650614387416,
"log_p1": -0.02277638164592572,
"log_p2": -0.0010526071930561662,
"log_p3": 0.0023643145622796475,
"log_p4": -0.0025583225480046595,
"log_p5": 0.004378832847069199,
},
"domain_2": {
"bias": 3.3408696428701075,
"log_p1": 0.0012198254051165147,
"log_p2": -0.005388376734279845,
"log_p3": 3.717373734748179e-05,
"log_p4": 0.001569209092862663,
"log_p5": -0.002615879540461004,
},
"domain_3": {
"bias": 2.8574312884633413,
"log_p1": 0.0018047356370065415,
"log_p2": 0.0005929693186934749,
"log_p3": -0.016974557643524405,
"log_p4": 0.0016523205850350147,
"log_p5": -0.0022616680413822805,
},
"domain_4": {
"bias": 1.4313732131696744,
"log_p1": -5.1765954939435764e-05,
"log_p2": 0.00089014156596828,
"log_p3": 0.00025768087959735246,
"log_p4": -0.019712789545137503,
"log_p5": 0.0031631020272182907,
},
"domain_5": {
"bias": 3.2257864919233756,
"log_p1": 0.0029557696510542136,
"log_p2": 0.0026611908853740054,
"log_p3": -0.0008885188514808986,
"log_p4": 0.002963471212841412,
"log_p5": -0.014216518710413713,
},
},
"305M": {
"domain_1": {
"bias": 2.2924869661791267,
"log_p1": -0.022229067565166443,
"log_p2": -0.001013806799899392,
"log_p3": 0.00223467186554643,
"log_p4": -0.002645873902860267,
"log_p5": 0.003982864991197694,
},
"domain_2": {
"bias": 3.178656522099369,
"log_p1": 0.001107302828752052,
"log_p2": -0.005157973767038715,
"log_p3": 2.3050110124257784e-05,
"log_p4": 0.001316251689675477,
"log_p5": -0.0028425410208344046,
},
"domain_3": {
"bias": 2.709373598108862,
"log_p1": 0.0016635063654583334,
"log_p2": 0.0006628818961311013,
"log_p3": -0.017547370822364454,
"log_p4": 0.0014206959628518868,
"log_p5": -0.002139388852229081,
},
"domain_4": {
"bias": 1.3211315997161879,
"log_p1": -0.00024768872185789104,
"log_p2": 0.0007428895039710641,
"log_p3": 0.0002693184497085781,
"log_p4": -0.01890218970791502,
"log_p5": 0.0030105938208213993,
},
"domain_5": {
"bias": 3.0545845875446473,
"log_p1": 0.002854476470176697,
"log_p2": 0.0026446139943498434,
"log_p3": -0.0008847006391603586,
"log_p4": 0.002650748867952872,
"log_p5": -0.014590961474216001,
},
},
"410M": {
"domain_1": {
"bias": 2.2343717924971283,
"log_p1": -0.021983874135919367,
"log_p2": -0.0006526879839640017,
"log_p3": 0.0019063557364627178,
"log_p4": -0.0022375323165275815,
"log_p5": 0.003693691571213775,
},
"domain_2": {
"bias": 3.1081386488108924,
"log_p1": 0.0012757792324108383,
"log_p2": -0.004861060292668727,
"log_p3": -0.00025971561825861964,
"log_p4": 0.0015980855298738693,
"log_p5": -0.003151617877548417,
},
"domain_3": {
"bias": 2.6457079548318867,
"log_p1": 0.0017179901900188285,
"log_p2": 0.0008184656194115004,
"log_p3": -0.017971019588506255,
"log_p4": 0.0016870680564514456,
"log_p5": -0.0023896692822175354,
},
"domain_4": {
"bias": 1.2820975851894605,
"log_p1": -0.00043994437500341244,
"log_p2": 0.0010961729632694766,
"log_p3": 0.00016179126609633246,
"log_p4": -0.01836235643857456,
"log_p5": 0.0027422260133483827,
},
"domain_5": {
"bias": 2.9806037129271403,
"log_p1": 0.0028948032957862636,
"log_p2": 0.002904176943938192,
"log_p3": -0.0011826414336436638,
"log_p4": 0.00291564957025589,
"log_p5": -0.015090031561988756,
},
},
"70M": {
"domain_1": {
"bias": 2.7516102214862785,
"log_p1": -0.024366773302687862,
"log_p2": -0.00140949302483008,
"log_p3": 0.0026992835128934697,
"log_p4": -0.0031027937470278375,
"log_p5": 0.0048722418014892025,
},
"domain_2": {
"bias": 3.6719317859686043,
"log_p1": 0.0011625721914211736,
"log_p2": -0.005916469283232795,
"log_p3": 0.0001159058459700631,
"log_p4": 0.001253484647821796,
"log_p5": -0.0023991782312973043,
},
"domain_3": {
"bias": 3.148901142003107,
"log_p1": 0.0019476114702619417,
"log_p2": 0.0005957023862280053,
"log_p3": -0.01776813563794145,
"log_p4": 0.0013179243483677264,
"log_p5": -0.002183184626543552,
},
"domain_4": {
"bias": 1.6665569458710474,
"log_p1": -0.00025847451148676925,
"log_p2": 0.0008758521651152874,
"log_p3": 0.00029735558927713497,
"log_p4": -0.022152540972955964,
"log_p5": 0.0033608606565960417,
},
"domain_5": {
"bias": 3.5732020332438177,
"log_p1": 0.003066036404538771,
"log_p2": 0.002572530514690887,
"log_p3": -0.0008951939451058563,
"log_p4": 0.0027436364335349703,
"log_p5": -0.013811286335446366,
},
},
}
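

# A small structural sanity check for the coefficient table above.
# Illustrative: this assert-based helper is an assumption of this sketch, not
# part of the fitted law.
def _validate_coeffs() -> None:
    """Check that every group defines domain_1..5, each carrying a bias plus
    log_p1..log_p5 weights (the schema documented above COEFFS)."""
    expected = {"bias"} | {f"log_p{j}" for j in range(1, 6)}
    for group, domains in COEFFS.items():
        for i in range(1, 6):
            entry = domains.get(f"domain_{i}")
            assert entry is not None, f"{group} is missing domain_{i}"
            assert set(entry) == expected, f"{group}/domain_{i} has unexpected keys"
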
# Pooled fallback coefficients (across all groups), used when an unknown group
# is requested.
FALLBACK: Dict[str, Dict[str, float]] = {
"domain_1": {
"bias": 2.4301835104003184,
"log_p1": -0.02283902416242487,
"log_p2": -0.001032148750437415,
"log_p3": 0.002301156419295567,
"log_p4": -0.0026361306286050916,
"log_p5": 0.004231907802742471,
},
"domain_2": {
"bias": 3.3248991499372433,
"log_p1": 0.0011913699144251472,
"log_p2": -0.005330970019305012,
"log_p3": -2.0896481204204573e-05,
"log_p4": 0.001434257740058452,
"log_p5": -0.0027523041675352787,
},
"domain_3": {
"bias": 2.8403534958517995,
"log_p1": 0.0017834609156864088,
"log_p2": 0.00066750480511602,
"log_p3": -0.01756527092308412,
"log_p4": 0.0015195022381765178,
"log_p5": -0.0022434777005931115,
},
"domain_4": {
"bias": 1.4252898359865924,
"log_p1": -0.0002494683908218775,
"log_p2": 0.0009012640495810252,
"log_p3": 0.0002465365461698426,
"log_p4": -0.019782469166145775,
"log_p5": 0.003069195629496027,
},
"domain_5": {
"bias": 3.208544206409745,
"log_p1": 0.0029427714553889924,
"log_p2": 0.0026956280845882408,
"log_p3": -0.0009627637173476962,
"log_p4": 0.0028183765211462894,
"log_p5": -0.014427199520516213,
},
}


def _predict_one(row: Dict[str, float], coeffs: Dict[str, Dict[str, float]]) -> Dict[str, float]:
    """Apply the log-linear law to one data point using the given coefficients."""
    # Log-proportion features; EPS keeps log() finite when a proportion is zero
    # or the key is absent (missing proportions default to 0.0).
    logs = [log(float(row.get(f"proportion_domain_{j}", 0.0)) + EPS) for j in range(1, 6)]
out: Dict[str, float] = {}
for i in range(1, 6):
c = coeffs.get(f"domain_{i}")
if c is None:
# If missing, fall back to pooled per-domain coefficients
c = FALLBACK[f"domain_{i}"]
val = c["bias"]
for j in range(1, 6):
val += c[f"log_p{j}"] * logs[j - 1]
out[f"loss_domain_{i}"] = float(val)
return out


def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values.
group: The name of the experimental group for which to make predictions.
The functional form of the law is the same for all groups,
but the constant parameters/coefficients can differ per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s): loss_domain_1..5.
"""
    # Select the group's coefficients; unknown groups fall back to the pooled fit.
    coeffs = COEFFS.get(group, FALLBACK)
return [_predict_one(row, coeffs) for row in input_data]
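

# Usage sketch: the proportions below are hypothetical (any mixture over the
# five domains works); the group names come from the COEFFS table above.
if __name__ == "__main__":
    sample = [{f"proportion_domain_{j}": 0.2 for j in range(1, 6)}]
    for g in ("70M", "160M", "305M", "410M"):
        preds = law(sample, g)[0]
        print(g, {k: round(v, 4) for k, v in preds.items()})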