# Auto-generated scaling law implementation
# Functional form (shared across groups):
# For each k in 1..5:
# loss_domain_k = intercept_k(group) + sum_j weights_kj(group) * log(proportion_domain_j + EPS)
# Where EPS = 1e-09 for numerical stability when proportions are very small.
from math import log
from typing import List, Dict
EPS = 1e-09
# Per-group fitted coefficients for the log-linear scaling law.
# Top-level keys are model-scale group names ("70M", "160M", "305M", "410M")
# plus a "DEFAULT" fallback used by _select_group when no group matches.
# Each group maps loss_domain_k -> {"intercept": float, "weights": [w_k1..w_k5]},
# where weights[j] multiplies log(proportion_domain_{j+1} + EPS) per PROP_ORDER.
COEFFS: Dict[str, Dict[str, Dict[str, object]]] = {
"160M": {
"loss_domain_1": {
"intercept": 2.4300949275826182,
"weights": [
-0.031068969226910582,
-0.001251970877444727,
0.0030660145955268667,
-0.0031923386958230627,
0.005666961729262434
]
},
"loss_domain_2": {
"intercept": 3.337559043041566,
"weights": [
0.0016673078859499708,
-0.0074170755968738225,
3.951373748949835e-05,
0.0021209041901585806,
-0.003650890207385123
]
},
"loss_domain_3": {
"intercept": 2.8401818458556174,
"weights": [
0.002460792554312405,
0.0007625714631167231,
-0.023536811266176717,
0.002221882872273453,
-0.0031186174972866886
]
},
"loss_domain_4": {
"intercept": 1.416827910017676,
"weights": [
-8.308604862155428e-06,
0.0012190503004574596,
0.00032399676863148686,
-0.027078844363725037,
0.004010845050737586
]
},
"loss_domain_5": {
"intercept": 3.224543906049842,
"weights": [
0.003905427350717464,
0.0034498008479645842,
-0.0011438176738371758,
0.0038638961855052053,
-0.019328119858345046
]
}
},
"305M": {
"loss_domain_1": {
"intercept": 2.280243138043643,
"weights": [
-0.030323002210805715,
-0.001211700929338305,
0.0028947830670476163,
-0.003324913376723421,
0.005113800264772562
]
},
"loss_domain_2": {
"intercept": 3.1750876801913237,
"weights": [
0.00150916925021668,
-0.007107484565596503,
2.2128083411678632e-05,
0.0017728620875118904,
-0.003972454524642374
]
},
"loss_domain_3": {
"intercept": 2.6913398939739,
"weights": [
0.0022708344859972973,
0.0008607850265603164,
-0.024331792555795703,
0.001910605212900081,
-0.0029724526398083326
]
},
"loss_domain_4": {
"intercept": 1.3069649592755754,
"weights": [
-0.0002767577345750335,
0.0010150273630255805,
0.0003398702171026176,
-0.02596334582661601,
0.003801676322267153
]
},
"loss_domain_5": {
"intercept": 3.0528011051041433,
"weights": [
0.0037594659546520076,
0.0034195441188783746,
-0.0011368975113422175,
0.0034318689601916725,
-0.01984609557908483
]
}
},
"410M": {
"loss_domain_1": {
"intercept": 2.222420776701192,
"weights": [
-0.0299923842986247,
-0.0007219779189471512,
0.002445903709102556,
-0.0027711721407406966,
0.004740901934744721
]
},
"loss_domain_2": {
"intercept": 3.104708939329123,
"weights": [
0.0017336451412636902,
-0.006704834426433428,
-0.00036440290744305196,
0.0021542694049179936,
-0.004373253420089363
]
},
"loss_domain_3": {
"intercept": 2.627413222019661,
"weights": [
0.0023403919816331487,
0.0010679528945524257,
-0.02491463877855555,
0.002264951898769409,
-0.003300544791425321
]
},
"loss_domain_4": {
"intercept": 1.268317554857684,
"weights": [
-0.0005477525270926719,
0.001495543684942191,
0.00019725923029861728,
-0.025223453889896486,
0.0034567257553997934
]
},
"loss_domain_5": {
"intercept": 2.9787362389662326,
"weights": [
0.0038069030391143168,
0.003769753651834665,
-0.0015455123055104936,
0.003789696705093632,
-0.020506176044717026
]
}
},
"70M": {
"loss_domain_1": {
"intercept": 2.738245966324205,
"weights": [
-0.033247355909198334,
-0.0017297973305077973,
0.0035157990300376508,
-0.003918220474936535,
0.006317934127290131
]
},
"loss_domain_2": {
"intercept": 3.6680087213236505,
"weights": [
0.0015828220455031585,
-0.008139958599724776,
0.00014447856565839406,
0.0016881696288695383,
-0.0033595783088106986
]
},
"loss_domain_3": {
"intercept": 3.1304945178100705,
"weights": [
0.0026532896643926083,
0.0007600084197338211,
-0.024637174874146756,
0.001757289151853728,
-0.0030248438725573442
]
},
"loss_domain_4": {
"intercept": 1.6498551367668615,
"weights": [
-0.0002889005529800089,
0.001194603021950033,
0.00037667659236133267,
-0.030432861224971082,
0.004231626123725879
]
},
"loss_domain_5": {
"intercept": 3.5718538704762426,
"weights": [
0.004053652875855752,
0.0033280287851346077,
-0.0011542215280026637,
0.0035637709733213644,
-0.01878431722025504
]
}
},
"DEFAULT": {
"loss_domain_1": {
"intercept": 2.417751202162915,
"weights": [
-0.03115792791138483,
-0.001228861764059495,
0.0029806251004286725,
-0.0033016611720559285,
0.005459899514017462
]
},
"loss_domain_2": {
"intercept": 3.3213410959714156,
"weights": [
0.0016232360807333748,
-0.007342338297157132,
-3.957063022087023e-05,
0.0019340513278645007,
-0.00383904411523189
]
},
"loss_domain_3": {
"intercept": 2.8223573699148123,
"weights": [
0.002431327171583865,
0.0008628294509908216,
-0.024355104368668678,
0.0020386822839491676,
-0.003104114700269422
]
},
"loss_domain_4": {
"intercept": 1.4104913902294491,
"weights": [
-0.00028042985487746743,
0.001231056092593816,
0.0003094507020985136,
-0.027174626326302152,
0.003875218313032603
]
},
"loss_domain_5": {
"intercept": 3.2069837801491152,
"weights": [
0.003881362305084885,
0.0034917818509530582,
-0.0012451122546731377,
0.0036623082060279685,
-0.019616177175600488
]
}
}
}
# Canonical ordering of the input proportion keys; weights[j] in COEFFS
# multiplies log(row[PROP_ORDER[j]] + EPS) when law() builds its feature vector.
PROP_ORDER: List[str] = ["proportion_domain_1", "proportion_domain_2", "proportion_domain_3", "proportion_domain_4", "proportion_domain_5"]
# Output keys emitted by law() for every input row, in prediction order.
LOSS_ORDER: List[str] = ["loss_domain_1", "loss_domain_2", "loss_domain_3", "loss_domain_4", "loss_domain_5"]
def _select_group(group: str) -> Dict[str, Dict[str, float]]:
    """Return the coefficient table for *group*.

    Resolution order: exact key match in COEFFS, then case-insensitive
    match, and finally the "DEFAULT" table. A None group is tolerated and
    falls through to the default.
    """
    # Fast path: the caller passed a key exactly as stored.
    if group in COEFFS:
        return COEFFS[group]
    # Second chance: ignore case ("160m" should still hit "160M").
    wanted = (group or "").lower()
    hit = next((COEFFS[name] for name in COEFFS if name.lower() == wanted), None)
    return COEFFS["DEFAULT"] if hit is None else hit
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.
    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law is the same for all groups, while
               coefficients differ by group.
    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s) for keys in LOSS_ORDER.
    """
    group_coeffs = _select_group(group)
    predictions: List[Dict[str, float]] = []
    for point in input_data:
        # Feature vector: log of each proportion (clamped at 0, EPS-shifted
        # for stability), in the canonical PROP_ORDER ordering.
        features = []
        for prop_name in PROP_ORDER:
            clamped = max(float(point.get(prop_name, 0.0)), 0.0)
            features.append(log(clamped + EPS))
        # Each loss is an affine function of the shared feature vector.
        predictions.append({
            loss_name: float(
                group_coeffs[loss_name]["intercept"]
                + sum(w * f for w, f in zip(group_coeffs[loss_name]["weights"], features))
            )
            for loss_name in LOSS_ORDER
        })
    return predictions