from __future__ import annotations
import math
from typing import List, Dict
# Fixed feature ordering
PROP_COLS = [
"proportion_domain_1",
"proportion_domain_2",
"proportion_domain_3",
"proportion_domain_4",
"proportion_domain_5",
]
LOSS_COLS = [
"loss_domain_1",
"loss_domain_2",
"loss_domain_3",
"loss_domain_4",
"loss_domain_5",
]
EPS = 1e-6
# Coefficients learned per group for the law:
# loss_i = intercept[i] + sum_j coef_linear[i][j] * p_j + sum_j coef_log[i][j] * log(p_j + EPS)
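# Equivalently, in illustrative matrix notation only: with p the 5-vector of
# domain proportions, loss = intercept + coef_linear @ p + coef_log @ log(p + EPS),
# applied row-wise per output domain; each group supplies its own 5x5 matrices.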
COEFFS: Dict[str, Dict[str, list]] = {
"160M": {
"intercept": [
2.469311683337708,
3.3141620411008277,
2.5975875154705848,
1.3440867180535057,
3.2488739962835567,
],
"coef_linear": [
[-0.39242870031019833, 0.1449840040105368, 0.20870621607378334, 0.012988956774962533, 0.02574952345091407],
[0.2073851738757125, -0.5034944849958123, 0.08099673956846185, 0.2119097462425706, 0.0032028253090607807],
[0.4222507474605135, 0.33048349239799873, -1.3184032511886987, 0.3397661471062194, 0.22590286422400102],
[0.1075280031010361, 0.3278752202366596, 0.018326424467131473, -0.5403846909284411, 0.08665504312361116],
[0.1224578633506513, -0.06992992306569604, 0.0648541522733341, 0.08654508830086936, -0.20392718085915945],
],
"coef_log": [
[-0.039451752022555374, -0.0003854857497984469, -2.3239743517545694e-05, 9.268231255609287e-06, -0.0006293642768779598],
[-0.0015843455126219829, -0.00597505925571199, -0.00010878237745062993, -0.0007202157067082326, -0.0012285972839189082],
[-0.0009734332588850447, -0.001936822498506686, -0.027443305577813045, -0.00024645647285300213, -0.00019926772803499236],
[-0.0006024744943890134, -0.002147785787884586, 0.001399812773972361, -0.036472059131277504, 0.00012750772191223904],
[-0.001567815576140436, 0.0013055621917748808, 0.0002487312848513498, -0.0008614874408401778, -0.019870896443806487],
],
},
"305M": {
"intercept": [
2.3392247012746834,
3.1651345666056483,
2.471987105632863,
1.2404678308980266,
3.0887017193916093,
],
"coef_linear": [
[-0.3945995646360234, 0.04212797569256443, 0.3597852823915539, 0.004507385349434609, -0.011821078797535627],
[0.18765244585849242, -0.5607080638027755, 0.16385032928665508, 0.22772936777546302, -0.018524079117843765],
[0.36498559947643294, 0.36326950855260254, -1.247281529045098, 0.3544474379638183, 0.16457898305227767],
[0.11479489142933053, 0.2241274675743544, 0.07534052854957383, -0.4984903999878992, 0.08422751243463553],
[0.1034302676552572, -0.1442429588936119, 0.1542111292796102, 0.10260754874495605, -0.2160059867862149],
],
"coef_log": [
[-0.0389843240976756, 0.0003898475662999871, -0.0012326552175473988, -0.0008170951320506675, -0.0006305864869774297],
[-0.0018382099319297328, -0.004966654576883016, -0.0008004862412949112, -0.0016726743862113481, -0.0014239105552697226],
[-0.0013605116194238868, -0.0029875971857020777, -0.029138080972677064, -0.0016317402099057068, 0.001163162472447215],
[-0.0011594434613557832, -0.0010890215347730992, 0.0008814829783619934, -0.035207303872518685, 0.00014797726343401387],
[-0.0015994149367300917, 0.002183961698325075, -0.0005510279268070304, -0.00175006550618083, -0.020723679414693993],
],
},
"410M": {
"intercept": [
2.2845576475924543,
3.10221083581893,
2.4040537489237623,
1.2320388989073703,
3.0194029194493215,
],
"coef_linear": [
[-0.40161868178180443, 0.04851096556048266, 0.37552617435827934, -0.007771674366947659, -0.014646783770016363],
[0.16564665878501697, -0.5418667012877614, 0.19196166965559713, 0.21461472108487065, -0.030356348237732297],
[0.3827566078799856, 0.34563333912754424, -1.207292188578679, 0.2962894338651403, 0.18261280770604105],
[0.054055822096378214, 0.18257490953749397, 0.25515869822947196, -0.5390509589695227, 0.04726152910616942],
[0.08444528786706501, -0.1235789613695045, 0.1680640099151793, 0.09513264245956578, -0.22406297887230905],
],
"coef_log": [
[-0.03838578515244451, 0.0010474524802569906, -0.0020612475600514644, 0.0001902706294946067, -0.0012861227733191377],
[-0.0012829345286925373, -0.004688819508647834, -0.0016346691987556602, -0.0009769593878491815, -0.002091652532429498],
[-0.0012514981651361474, -0.0022513525226212174, -0.03034764820962916, -0.00021614146982995423, -0.00017869542804964955],
[-0.0010235048622236945, -7.088093693356411e-05, -0.0007898197931760238, -0.033703719578066345, -0.0007674227907133403],
[-0.001315884970387432, 0.0024028568720025913, -0.0013333912060313298, -0.0011109928776418308, -0.021811398074324508],
],
},
"70M": {
"intercept": [
2.7857859114105317,
3.631804815517477,
2.8681805224896912,
1.5890762093073625,
3.585150303901379,
],
"coef_linear": [
[-0.4416445868640977, 0.07076027431299382, 0.302121611856706, 0.021520016473356454, 0.04724268422103695],
[0.21834311045644716, -0.553392474651338, 0.04818153280501451, 0.24033451148528265, 0.04653331990458732],
[0.4651202418236946, 0.2657143234624336, -1.3670739924126776, 0.36683688818700816, 0.2694025389395758],
[0.0850252828881378, 0.2990024146623459, 0.16822113994472984, -0.6239023579571726, 0.07165352046195206],
[0.14123915472576296, -0.13497556494772603, 0.03364295958750099, 0.12770730941446795, -0.16761385878000656],
],
"coef_log": [
[-0.041246477328631105, 0.0006531144880363961, -0.0006596475145338669, -0.00019599814522888677, -0.0015631188541267603],
[-0.0009803943328683558, -0.005672467237098692, -8.71136475631502e-05, -0.0009074144501494191, -0.0019413115294711764],
[-0.0006290227608540234, -0.0005120063062147314, -0.02905249764872596, -0.0007835652353273532, -0.000662736071510837],
[-0.0008408582373940847, -0.0019902435925866755, 0.00039188729846795716, -0.0409361614036341, -0.0005053380487605633],
[-0.0009055762689076869, 0.0025986071507895507, 0.00015711172405491434, -0.0011903203768379186, -0.019717110434476673],
],
},
}
def _predict_point(p: Dict[str, float], coeff: Dict[str, list]) -> Dict[str, float]:
# Build feature vectors in fixed order
P = [float(p.get(k, 0.0)) for k in PROP_COLS]
logP = [math.log(x + EPS) for x in P]
y = []
for i in range(5):
val = coeff["intercept"][i]
# linear terms
for j in range(5):
val += coeff["coef_linear"][i][j] * P[j]
# log terms
for j in range(5):
val += coeff["coef_log"][i][j] * logP[j]
y.append(val)
return {LOSS_COLS[i]: y[i] for i in range(5)}
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values.
group: The name of the experimental group for which to make predictions.
The functional form of the law must be the same for all groups,
but the constant parameters/coefficients can differ per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s).
"""
if group not in COEFFS:
        # Fallback: if an unknown group label is passed, average the
        # coefficients of the available groups on the fly. This keeps the
        # functional form identical across groups while remaining robust.
groups = list(COEFFS.values())
avg = {
"intercept": [sum(g["intercept"][i] for g in groups) / len(groups) for i in range(5)],
"coef_linear": [
[sum(g["coef_linear"][i][j] for g in groups) / len(groups) for j in range(5)]
for i in range(5)
],
"coef_log": [
[sum(g["coef_log"][i][j] for g in groups) / len(groups) for j in range(5)]
for i in range(5)
],
}
coeff = avg
else:
coeff = COEFFS[group]
outputs: List[Dict[str, float]] = []
for row in input_data:
outputs.append(_predict_point(row, coeff))
return outputs
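# Minimal usage sketch (not part of the fitted law itself): the proportions
# below are illustrative placeholders that sum to 1, chosen only to show the
# expected input/output shapes for `law`.
if __name__ == "__main__":
    example = [
        {
            "proportion_domain_1": 0.2,
            "proportion_domain_2": 0.2,
            "proportion_domain_3": 0.2,
            "proportion_domain_4": 0.2,
            "proportion_domain_5": 0.2,
        }
    ]
    for group in ("70M", "160M", "305M", "410M"):
        print(group, law(example, group))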