def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Predict per-domain losses from mixture proportions via a power law.

    Discovered law: for each domain i in {1..5},
        loss_domain_i = a_{group,i} + b_{group,i} * proportion_domain_i ** alpha_i
    The exponent alpha_i is domain-specific but shared across groups, while
    (a_{group,i}, b_{group,i}) are fitted per group.

    Args:
        input_data: One dict per data point with keys
            'proportion_domain_1' .. 'proportion_domain_5' (missing keys
            are treated as 0.0).
        group: Experimental group name selecting the (a, b) coefficients.

    Returns:
        One dict per input row with keys 'loss_domain_1' .. 'loss_domain_5'.
    """
    # Domain-specific exponents, shared across all groups (fitted once).
    alphas = {1: 0.226, 2: 0.272, 3: 0.236, 4: 0.235, 5: 0.343}
    # Per-group (a, b) coefficients for each domain, fitted from the dataset.
    coeffs = {
        "160M": {
            1: {"a": 3.0607589078884847, "b": -0.8406224674207222},
            2: {"a": 3.471957561424479, "b": -0.23709796451470122},
            3: {"a": 3.2856010648519973, "b": -0.7919275425273328},
            4: {"a": 1.9632078046951371, "b": -0.8321226336323998},
            5: {"a": 3.600060737641489, "b": -0.5302231304455584},
        },
        "305M": {
            1: {"a": 2.896951436073815, "b": -0.8170959564908562},
            2: {"a": 3.306317389829822, "b": -0.22521283957225652},
            3: {"a": 3.155092174041798, "b": -0.8182930011802386},
            4: {"a": 1.8328824818924194, "b": -0.7963908513267552},
            5: {"a": 3.4340665068448346, "b": -0.5313252100720468},
        },
        "410M": {
            1: {"a": 2.8291888357597386, "b": -0.8073757705491997},
            2: {"a": 3.2297361776335225, "b": -0.21719584738930717},
            3: {"a": 3.097659192469288, "b": -0.8335641687702692},
            4: {"a": 1.779637332326639, "b": -0.775555774148788},
            5: {"a": 3.371561997175875, "b": -0.5469883726664775},
        },
        "70M": {
            1: {"a": 3.4193040905517047, "b": -0.9041352514360005},
            2: {"a": 3.8189889954933474, "b": -0.25910738407437617},
            3: {"a": 3.600895922417036, "b": -0.8317098214628572},
            4: {"a": 2.266520379741139, "b": -0.9332890679011832},
            5: {"a": 3.937342662537917, "b": -0.5157344418970146},
        },
    }
    # Unknown group: fall back to the per-domain mean of (a, b) over the
    # known groups, computed in the groups' insertion order.
    if group not in coeffs:
        known = list(coeffs.values())
        coeffs[group] = {
            i: {
                "a": sum(c[i]["a"] for c in known) / len(known),
                "b": sum(c[i]["b"] for c in known) / len(known),
            }
            for i in range(1, 6)
        }
    predictions = []
    for sample in input_data:
        row_out = {}
        for i in range(1, 6):
            prop = float(sample.get(f"proportion_domain_{i}", 0.0))
            c = coeffs[group][i]
            row_out[f"loss_domain_{i}"] = c["a"] + c["b"] * prop ** alphas[i]
        predictions.append(row_out)
    return predictions
from __future__ import annotations
import math
from typing import Dict, List
# ---------------------------------------------------------------------------
# Log-linear law variant (used by the _resolve_group-based `law` below).
# ---------------------------------------------------------------------------
# Discovered functional form (same for all groups):
#   loss_domain_i = a[group][i] + b[group][i] * log(proportion_domain_i + EPS)
# A small epsilon handles zero proportions.
EPS = 0.003125  # min positive proportion observed / 10
# Fitted coefficients (a, b) per group and domain i in {1..5}
COEFS: Dict[str, Dict[int, tuple[float, float]]] = {
    "160M": {
        1: (2.2424684059708717, -0.1412039367794934),
        2: (3.2615055117992546, -0.036505329124803795),
        3: (2.503204375028289, -0.13515973441168078),
        4: (1.167057464525046, -0.1374903829439023),
        5: (3.1050702723961345, -0.09354145265121051),
    },
    "305M": {
        1: (2.101566717029658, -0.137250480526824),
        2: (3.1064301933569487, -0.034670424237437165),
        3: (2.347069418836081, -0.1395498701909293),
        4: (1.0709224798061698, -0.13158540389099785),
        5: (2.9376949827197802, -0.09412677108369148),
    },
    "410M": {
        1: (2.0433633841009002, -0.1355817799554127),
        2: (3.036930564759394, -0.033447198672712085),
        3: (2.2745949072179825, -0.14214431175511827),
        4: (1.0375379829435523, -0.1281673168777551),
        5: (2.860506493994419, -0.09695843157676008),
    },
    "70M": {
        1: (2.538957210154492, -0.15195781604593908),
        2: (3.589039351229478, -0.03988167092985459),
        3: (2.7789903294958576, -0.1420028367096475),
        4: (1.3734832511282675, -0.15423705944877764),
        5: (3.455746153156199, -0.09113184327935854),
    },
}
# Fallback coefficients: per-domain average of (a, b) across known groups
# (used if group not found and no numeric nearest-match is possible)
AVG_COEFS = {
    i: (
        sum(COEFS[g][i][0] for g in COEFS) / len(COEFS),
        sum(COEFS[g][i][1] for g in COEFS) / len(COEFS),
    )
    for i in range(1, 6)
}
def _resolve_group(group: str) -> Dict[int, tuple[float, float]]:
    """Return the (a, b) coefficient table to use for *group*.

    Known group names resolve to their fitted table directly. Otherwise, if
    the name contains a number (e.g. '300M'), the known group whose numeric
    part is closest is used. As a last resort the cross-group average table
    is returned.
    """
    try:
        return COEFS[group]
    except KeyError:
        pass
    import re
    number_pat = re.compile(r"(\d+(?:\.\d+)?)")
    found = number_pat.search(group)
    if found is None:
        return AVG_COEFS  # no numeric hint -> cross-group average

    wanted = float(found.group(1))

    def _numeric(name: str) -> float:
        hit = number_pat.search(name)
        # Keys without a number sort as "infinitely far" from the target.
        return float(hit.group(1)) if hit else float("inf")

    closest = min(COEFS, key=lambda name: abs(_numeric(name) - wanted))
    return COEFS[closest]
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Predict per-domain losses according to the discovered scaling law.

    The functional form loss_i = a + b * log(p_i + EPS) is shared by all
    groups; the (a, b) pairs are group-specific and resolved through
    ``_resolve_group`` (exact match, nearest numeric size, or average).

    Args:
        input_data: A list of dictionaries, where each dictionary is a single
            data point containing input variable names as keys and their
            corresponding values.
        group: The name of the experimental group for which to make
            predictions.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with
        each dictionary containing the predicted output variable(s).
    """
    table = _resolve_group(group)
    predictions: List[Dict[str, float]] = []
    for point in input_data:
        losses: Dict[str, float] = {}
        for dom in range(1, 6):
            proportion = float(point.get(f"proportion_domain_{dom}", 0.0))
            intercept, slope = table[dom]
            losses[f"loss_domain_{dom}"] = intercept + slope * math.log(proportion + EPS)
        predictions.append(losses)
    return predictions
from typing import List, Dict
# ---------------------------------------------------------------------------
# Quadratic scaling law in the five mixture proportions (used by the
# _choose_group / _poly2_features-based `law` below).
# ---------------------------------------------------------------------------
# For group g and domain k (k in 1..5):
#   loss_k = b_k(g) + sum_i A_{k,i}(g) * p_i + sum_{i<=j} C_{k,ij}(g) * p_i * p_j
# The feature order for the quadratic terms matches sklearn PolynomialFeatures(degree=2, include_bias=False):
#   [p1, p2, p3, p4, p5,
#    p1^2, p1*p2, p1*p3, p1*p4, p1*p5,
#    p2^2, p2*p3, p2*p4, p2*p5,
#    p3^2, p3*p4, p3*p5,
#    p4^2, p4*p5,
#    p5^2]
# Group-specific coefficients (5 outputs x 20 features) and intercepts (5,)
COEFFS: Dict[str, Dict[str, list]] = {
    '70M': {
        'intercept': [2.67722332, 3.61228621, 2.31848257, 1.43615424, 3.64312409],
        # Each row below is one output (loss_domain_1 .. loss_domain_5);
        # columns follow the 20-feature order documented above.
        'coef': [
            [-1.93535793e+00, 9.21334406e-01, 6.07166334e-01, 1.86436004e-01, 2.20421182e-01,
             4.17217173e+00, -2.50227074e+00, -8.71756745e-01, -1.45116852e+00, -1.28233366e+00,
             7.50528746e-01, 1.88319583e+00, 7.09251012e-01, 8.06295581e-02, 1.85529330e-01,
             -4.30245275e-01, -1.59556803e-01, 1.88349071e-01, 1.17024971e+00, 4.11432370e-01],
            [ 1.87287409e-01, -5.24704674e-01, 8.19666923e-02, 2.31939372e-01, 2.35112001e-02,
             1.93535049e-01, -6.08606198e-01, 1.18874357e-01, 3.03342030e-01, 1.80142171e-01,
             8.76312565e-01, 2.24295717e-02, -4.31900052e-01, -3.82940560e-01, 4.86317090e-03,
             -4.36420286e-03, -5.98362049e-02, 2.27452590e-01, 1.37409008e-01, 1.48736786e-01],
            [ 6.06770406e-01, 1.02165804e+00, -3.42499870e+00, 8.48286144e-01, 9.48284111e-01,
             1.17883254e+00, 9.85326844e-02, -2.30473008e+00, 7.43287853e-01, 8.90847414e-01,
             2.23385663e+00, -5.42165653e-01, -6.21380452e-01, -1.47185169e-01, 8.79046349e-01,
             -1.00148149e+00, -4.55667827e-01, 1.32253646e+00, 4.05323777e-01, 2.54965916e-01],
            [ 2.25923880e-01, 5.66280635e-01, 3.05380370e-01, -1.18219021e+00, 8.46053253e-02,
             2.43370013e-02, 8.12183458e-01, -5.58959006e-01, -1.43199391e+00, 1.38035634e+00,
             5.92080565e-01, 1.49390820e+00, -3.14016004e+00, 8.08268452e-01, 1.55579420e-01,
             -1.08989614e+00, 3.04747891e-01, 7.59681480e+00, -3.11695492e+00, 7.08187563e-01],
            [ 9.59227049e-02, 2.87475915e-01, 7.95688772e-02, 2.43121301e-01, -7.06088798e-01,
             8.03086961e-02, -2.38179581e-01, 4.09505305e-02, 3.70796209e-01, -1.57953149e-01,
             4.77055325e-01, 3.25008377e-01, 2.93142645e-01, -5.69550850e-01, -2.33055409e-02,
             -7.62471460e-02, -1.86837343e-01, 1.05339029e-01, -4.49909436e-01, 6.58161980e-01],
        ],
    },
    '160M': {
        'intercept': [2.38262016, 3.30071035, 2.07434882, 1.22527058, 3.31341378],
        'coef': [
            [-1.89040168e+00, 9.35907232e-01, 5.60322596e-01, 1.88444025e-01, 2.05727823e-01,
             4.04736833e+00, -2.51535118e+00, -8.90374549e-01, -1.39350885e+00, -1.13853543e+00,
             7.41695287e-01, 1.84860556e+00, 8.63861779e-01, -2.90421544e-03, 1.62855446e-01,
             -4.13720211e-01, -1.47043649e-01, -1.33076166e-02, 1.14511892e+00, 3.49092199e-01],
            [ 1.06414210e-01, -4.42355222e-01, 1.03243547e-01, 2.22326118e-01, 1.03713464e-02,
             2.92828097e-01, -7.03779930e-01, 8.34271237e-03, 2.86433970e-01, 2.22589362e-01,
             8.46832622e-01, 1.48524491e-01, -2.84465358e-01, -4.49467047e-01, 2.13811575e-02,
             -2.55730882e-02, -4.94317254e-02, 6.67143569e-02, 1.79216238e-01, 1.07464519e-01],
            [ 5.10457626e-01, 1.02965945e+00, -3.28863673e+00, 8.27651985e-01, 9.20867673e-01,
             1.21104590e+00, -1.00467440e-01, -2.27536431e+00, 7.92543737e-01, 8.82699738e-01,
             2.29667733e+00, -4.60952149e-01, -4.56400842e-01, -2.49197452e-01, 8.27965658e-01,
             -1.01341930e+00, -3.66866637e-01, 1.04265170e+00, 4.62276688e-01, 1.91955335e-01],
            [ 1.85394175e-01, 5.25641652e-01, 2.36807484e-01, -1.05531359e+00, 1.07470277e-01,
             4.73960405e-02, 6.76640181e-01, -4.99158421e-01, -1.30202159e+00, 1.26253796e+00,
             4.71471468e-01, 1.33050402e+00, -2.67005203e+00, 7.17078014e-01, 9.60652672e-02,
             -1.01089455e+00, 3.20291164e-01, 6.69954159e+00, -2.77188702e+00, 5.79450152e-01],
            [ 4.30213821e-03, 3.81322064e-01, 1.04070391e-01, 2.40299130e-01, -7.29993724e-01,
             2.03468508e-01, -3.88921326e-01, -6.90314075e-02, 3.50874845e-01, -9.20884817e-02,
             5.96457642e-01, 4.54420685e-01, 4.50055879e-01, -7.30690815e-01, -9.18767274e-03,
             -1.08473047e-01, -1.63658167e-01, -5.80983095e-02, -3.94060239e-01, 6.50503979e-01],
        ],
    },
    '305M': {
        'intercept': [2.20308636, 3.13378774, 2.0099542, 1.21090615, 3.12985208],
        'coef': [
            [-1.36612527e+00, 5.08760751e-01, 2.58565850e-01, 1.92426844e-01, 4.06371822e-01,
             3.02406051e+00, -1.36341416e+00, -5.63422763e-01, -9.77163560e-01, -1.48618530e+00,
             3.33590741e-01, 7.44731825e-01, 2.29382605e-01, 5.64469737e-01, 1.08830884e-01,
             -2.29650547e-01, 1.98076451e-01, 1.70915773e-01, 9.98942574e-01, 1.31068355e-01],
            [ 1.42155086e-01, -4.77001278e-01, 5.52264024e-02, 2.48408466e-01, 3.12113236e-02,
             2.58768506e-01, -3.66936157e-01, -1.32405322e-02, 1.23555478e-01, 1.40007791e-01,
             2.60734738e-01, 5.87994884e-02, -1.60567442e-01, -2.69031906e-01, 2.22863705e-02,
             -7.35552021e-03, -5.26340404e-03, 2.24335488e-01, 6.84404625e-02, 9.70583802e-02],
            [ 6.18506100e-01, 8.68693819e-01, -3.08914591e+00, 7.89619410e-01, 8.12326585e-01,
             1.02643703e+00, 1.14784592e-01, -1.46267084e+00, 1.80033336e-01, 7.59921974e-01,
             1.17239480e+00, -3.63996838e-01, -2.03911468e-01, 1.49422736e-01, 1.90418054e-01,
             -5.68054905e-01, -8.84841389e-01, 8.13849714e-01, 5.67702734e-01, 2.20120530e-01],
            [ 2.12633425e-01, 4.83463627e-01, 2.22311810e-01, -8.90787050e-01, -2.76218129e-02,
             -9.29330545e-02, 1.99613262e-01, -1.14370686e-01, -7.24874326e-01, 9.45198229e-01,
             1.74912693e-01, 6.25137716e-01, -1.05562696e+00, 5.39426918e-01, 8.52526881e-02,
             -3.88847751e-01, 1.51398435e-02, 3.43073712e+00, -2.15217513e+00, 6.24788323e-01],
            [ 8.99898182e-02, 2.46945404e-01, 7.24765027e-02, 2.55936575e-01, -6.65348300e-01,
             1.05642777e-01, -1.69323103e-02, -2.45232236e-02, 2.31633418e-01, -2.05830842e-01,
             2.57824563e-01, 1.81182032e-01, 3.00383115e-01, -4.75511996e-01, 1.31255309e-02,
             1.56611021e-03, -9.88739467e-02, 2.04976089e-01, -4.82622158e-01, 5.97490643e-01],
        ],
    },
    '410M': {
        'intercept': [2.1439425, 3.06228202, 1.93134234, 1.17352651, 3.0537303],
        'coef': [
            [-1.36744907e+00, 5.27110822e-01, 2.03069104e-01, 2.31281112e-01, 4.05988034e-01,
             3.00741578e+00, -1.41296472e+00, -5.90698287e-01, -9.03407996e-01, -1.46779385e+00,
             3.36508765e-01, 7.13154207e-01, 3.03122142e-01, 5.87290423e-01, 1.12114917e-01,
             -2.17566149e-01, 1.86064416e-01, 6.33336092e-02, 9.85799506e-01, 1.14627539e-01],
            [ 1.31827459e-01, -4.52442821e-01, 5.13599152e-03, 2.81922129e-01, 3.35572415e-02,
             2.56204289e-01, -4.04028535e-01, -4.49805105e-02, 1.62119123e-01, 1.62513091e-01,
             2.84260949e-01, 2.83929538e-02, -1.02570552e-01, -2.58497637e-01, 2.72658790e-02,
             -3.64530850e-03, -1.89702227e-03, 1.75991829e-01, 5.00270374e-02, 8.14117726e-02],
            [ 6.14742174e-01, 8.65449930e-01, -3.13247704e+00, 8.30315141e-01, 8.21969797e-01,
             1.11981200e+00, -7.51476148e-02, -1.43672744e+00, 2.46278198e-01, 7.60527037e-01,
             1.21696703e+00, -4.08532829e-01, -8.73310035e-02, 2.19494346e-01, 1.95911154e-01,
             -5.50356121e-01, -9.32771804e-01, 6.76874073e-01, 5.44849994e-01, 2.29870224e-01],
            [ 1.90124684e-01, 4.83340554e-01, 1.97559385e-01, -8.65266788e-01, -5.75783540e-03,
             -7.48247755e-02, 1.37556138e-01, -1.54467045e-01, -5.85386297e-01, 8.67246663e-01,
             1.95357715e-01, 6.05246418e-01, -1.03503337e+00, 5.80213650e-01, 9.95312192e-02,
             -4.27655041e-01, 7.49038343e-02, 3.29049485e+00, -2.10768693e+00, 5.79564950e-01],
            [ 7.34423141e-02, 2.79758269e-01, 1.55572901e-02, 2.98213415e-01, -6.66971288e-01,
             1.21456923e-01, -7.12563637e-02, -7.33905276e-02, 2.86905077e-01, -1.90272795e-01,
             2.93628007e-01, 1.67735765e-01, 3.62290116e-01, -4.72639254e-01, 1.73275864e-02,
             8.76400932e-04, -9.69919342e-02, 1.48793013e-01, -5.00651192e-01, 5.93583887e-01],
        ],
    },
}
# Canonical ordering of the five input proportion keys; indexes into each
# 'coef' row's first five (linear) columns.
NAME_ORDER = [
    'proportion_domain_1', 'proportion_domain_2', 'proportion_domain_3',
    'proportion_domain_4', 'proportion_domain_5'
]
def _poly2_features(vals: List[float]) -> List[float]:
    """Expand exactly five inputs into the 20 degree-2 polynomial features.

    Output is the 5 linear terms followed by the 15 upper-triangle quadratic
    products, matching sklearn PolynomialFeatures(deg=2, include_bias=False).
    Raises ValueError if *vals* does not contain exactly five values.
    """
    x1, x2, x3, x4, x5 = vals
    linear = (x1, x2, x3, x4, x5)
    quadratic = []
    for i in range(5):
        for j in range(i, 5):
            quadratic.append(linear[i] * linear[j])
    return [*linear, *quadratic]
def _choose_group(group: str) -> str:
    """Map *group* onto a key of COEFFS.

    Unknown names fall back to the known model size whose numeric part is
    closest; names with no digits default to '160M'.
    """
    if group in COEFFS:
        return group
    import re
    digits = re.search(r"(\d+)", str(group))
    if digits is None:
        return '160M'
    target = int(digits.group(1))

    def _size_of(name: str) -> int:
        hit = re.search(r"(\d+)", name)
        return int(hit.group(1)) if hit else 0

    return min(COEFFS.keys(), key=lambda name: abs(_size_of(name) - target))
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """Predict per-domain losses via a group-specific quadratic law.

    Each loss is an intercept plus a dot product of the group's 20
    coefficients with the degree-2 polynomial features of the five mixture
    proportions (see _poly2_features for the feature ordering).

    Args:
        input_data: A list of dictionaries, where each dictionary is a single
            data point containing input variable names as keys and their
            corresponding values.
        group: The name of the experimental group for which to make
            predictions; unknown groups map to the nearest known size.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with
        each dictionary containing the predicted output variable(s).
    """
    resolved = _choose_group(group)
    group_params = COEFFS[resolved]
    weights = group_params['coef']        # 5 outputs x 20 features
    intercepts = group_params['intercept']  # 5 outputs
    results: List[Dict[str, float]] = []
    for record in input_data:
        proportions = [float(record.get(name, 0.0)) for name in NAME_ORDER]
        phi = _poly2_features(proportions)
        losses = [
            intercepts[k] + sum(w * f for w, f in zip(weights[k], phi))
            for k in range(5)
        ]
        results.append({f'loss_domain_{i+1}': losses[i] for i in range(5)})
    return results
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    The discovered law relates each domain's validation loss to its training mixture
    proportion via an offset power-decay form shared across groups:
        loss_domain_i = L_{g,i} + C_{g,i} * (proportion_domain_i + eps) ** (-a_{g,i})
    where g is the experimental group and i in {1..5} is the domain index.
    Coefficients (L, C, a) are group- and domain-specific, fitted from the provided dataset.
    For unknown groups, coefficients are obtained by size-aware interpolation (based on
    the numeric model size parsed from the group string), falling back to a cross-group
    average if size cannot be parsed.

    Args:
        input_data: List of dicts with keys 'proportion_domain_1'..'proportion_domain_5'.
        group: Group name (e.g., '70M', '160M', '305M', '410M'). The functional form is
            shared across groups; parameters differ per group.

    Returns:
        List of dicts, each containing keys 'loss_domain_1'..'loss_domain_5'.
    """
    # Fitted parameters per group and domain (pow-decay): y = L + C * (p + eps)^(-a)
    # Derived from analysis over /app/data.
    PARAMS = {
        '70M': {
            1: {'L': 2.3407861901608458, 'C': 0.3990228222403043, 'a': 0.07229469195307019},
            2: {'L': 3.5878039240958115, 'C': 0.06619054334605488, 'a': 0.0904125332037067},
            3: {'L': 3.0397589847616344, 'C': 0.09422712368752768, 'a': 0.12920486135303147},
            4: {'L': 1.367615184081782, 'C': 0.2717420228432777, 'a': 0.08687316390039175},
            5: {'L': 3.3039469037556852, 'C': 0.20054244651774714, 'a': 0.08493556412221587},
        },
        '160M': {
            1: {'L': 2.110715267629515, 'C': 0.3198800255123924, 'a': 0.07933154901483508},
            2: {'L': 3.1803738201935277, 'C': 0.13638680248291965, 'a': 0.05512158489827569},
            3: {'L': 2.6798939752704323, 'C': 0.15139248008383838, 'a': 0.10036697294631616},
            4: {'L': 1.1773152518677579, 'C': 0.2282095425768926, 'a': 0.08971340404690013},
            5: {'L': 2.946602702302544, 'C': 0.20862088187761854, 'a': 0.08441408092133072},
        },
        '305M': {
            1: {'L': 1.9852814769362541, 'C': 0.2993480669803183, 'a': 0.08124887473179328},
            2: {'L': 3.023213559842957, 'C': 0.13550118870333341, 'a': 0.053401380352092805},
            3: {'L': 2.5983613840895425, 'C': 0.09639191662897745, 'a': 0.12705347977640008},
            4: {'L': 1.0610600002430857, 'C': 0.2366422762032485, 'a': 0.08578219015376275},
            5: {'L': 2.7799528735574452, 'C': 0.2083464473165256, 'a': 0.08504953874822715},
        },
        '410M': {
            1: {'L': 1.9365715595185018, 'C': 0.2874401995772295, 'a': 0.08276413468982288},
            2: {'L': 2.9481641120717006, 'C': 0.13858188362142976, 'a': 0.05155133454157204},
            3: {'L': 2.5316764731368075, 'C': 0.0972583597842447, 'a': 0.1275873334591975},
            4: {'L': 1.0267251504383115, 'C': 0.23166028198558963, 'a': 0.08550863921076982},
            5: {'L': 2.700651481039514, 'C': 0.21075075625956133, 'a': 0.0863440646491576},
        },
    }
    # Cross-group per-domain averages (fallback when group is unknown and size cannot be parsed)
    AVG = {
        1: {'L': 2.093338623561279, 'C': 0.32642277857756113, 'a': 0.07890981259738035},
        2: {'L': 3.1848888540509996, 'C': 0.11916510453843443, 'a': 0.0626217082489118},
        3: {'L': 2.712422704314604, 'C': 0.10981747004614704, 'a': 0.1210531618837363},
        4: {'L': 1.1581788966577342, 'C': 0.24206353090225213, 'a': 0.08696934932795611},
        5: {'L': 2.932788490163797, 'C': 0.20706513299286314, 'a': 0.08518581211023284},
    }

    # Parse numeric model size (in millions) from group string like '70M', '1.3B', '410M'.
    # Returns None if no number can be extracted or the unit is unrecognized.
    def _parse_size_millions(g: str):
        if not isinstance(g, str):
            return None
        s = g.strip().upper()
        num = ''
        unit = ''
        for ch in s:
            # Numeric characters are collected only until the first letter is
            # seen; any other characters (separators etc.) are simply skipped.
            if (ch.isdigit() or ch == '.' or ch == '+') and unit == '':
                num += ch
            elif ch.isalpha():
                unit += ch
        try:
            val = float(num) if num else None
        except Exception:
            val = None
        if val is None:
            return None
        if 'B' in unit:
            return val * 1000.0  # billions -> millions
        if 'M' in unit or unit == '':
            return val  # bare numbers are assumed to already be in millions
        return None

    # Retrieve parameters for a group, with size-aware interpolation when needed
    def _get_params_for_group(g: str):
        if g in PARAMS:
            return PARAMS[g]
        # Try interpolation based on parsed size
        known = sorted(((k, v) for k, v in PARAMS.items()), key=lambda kv: _parse_size_millions(kv[0]) or float('inf'))
        sizes = [
            (_parse_size_millions(k) if _parse_size_millions(k) is not None else float('inf'))
            for k, _ in known
        ]
        size = _parse_size_millions(g)
        if size is None or any(x == float('inf') for x in sizes):
            return AVG
        # Clamp to range if outside
        if size <= sizes[0]:
            return known[0][1]
        if size >= sizes[-1]:
            return known[-1][1]
        # Find enclosing bracket for interpolation
        lo_idx = 0
        for i in range(len(sizes) - 1):
            if sizes[i] <= size <= sizes[i + 1]:
                lo_idx = i
                break
        hi_idx = lo_idx + 1
        s0, s1 = sizes[lo_idx], sizes[hi_idx]
        t = (size - s0) / (s1 - s0) if s1 > s0 else 0.0
        # Linear interpolation of each parameter per domain
        interp = {}
        for di in range(1, 6):
            L0 = known[lo_idx][1][di]['L']; L1 = known[hi_idx][1][di]['L']
            C0 = known[lo_idx][1][di]['C']; C1 = known[hi_idx][1][di]['C']
            a0 = known[lo_idx][1][di]['a']; a1 = known[hi_idx][1][di]['a']
            interp[di] = {
                'L': L0 + (L1 - L0) * t,
                'C': C0 + (C1 - C0) * t,
                'a': a0 + (a1 - a0) * t,
            }
        return interp

    coeffs = _get_params_for_group(group)
    # Small epsilon to regularize p=0 behavior; chosen to be tiny relative to [0,1]
    eps = 1e-6
    outputs: list[dict[str, float]] = []
    for row in input_data:
        out_row: dict[str, float] = {}
        for i in range(1, 6):
            # Missing proportion keys default to 0.0.
            p = float(row.get(f'proportion_domain_{i}', 0.0))
            L = float(coeffs[i]['L'])
            C = float(coeffs[i]['C'])
            a = float(coeffs[i]['a'])
            # Ensure non-negative proportion and stable power
            if p < 0.0:
                p = 0.0
            y = L + C * (p + eps) ** (-a)
            out_row[f'loss_domain_{i}'] = float(y)
        outputs.append(out_row)
    return outputs
from math import log10
from typing import Dict, List
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    The law is a linear model for each domain's validation loss as a function of
    the five domain mixture proportions. The functional form is the same across
    experimental groups, but coefficients differ by group.

    For domain k in {1..5} and proportions p_j that sum to 1:
        loss_domain_k = intercept[g][k] + sum_j coef[g][k,j] * proportion_domain_j

    To generalize to unseen groups, the group-specific parameters are also
    modeled as linear functions of log10(model_size) fitted from the observed
    groups. If a group string encodes a size (e.g., "550M", "1.3B"), we use this
    mapping to synthesize parameters for that group; otherwise we fall back to
    the average of known-group parameters.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
            point containing input variable names as keys and their
            corresponding values. Expected keys: 'proportion_domain_1' .. 'proportion_domain_5'.
        group: The name of the experimental group for which to make predictions.
            The functional form of the law is the same for all groups,
            but the constant parameters/coefficients can differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s) with keys
        'loss_domain_1' .. 'loss_domain_5'.
    """
    # Learned coefficients from the provided dataset (one linear model per loss per known group)
    COEFS: Dict[str, Dict] = {
        "160M": {
            "intercepts": {
                "loss_domain_1": 2.645699436335381,
                "loss_domain_2": 3.2982774153184558,
                "loss_domain_3": 2.132434044675029,
                "loss_domain_4": 1.3893657641467703,
                "loss_domain_5": 3.3747611697919497,
            },
            "coefs": {
                "loss_domain_1": {
                    "proportion_domain_1": -1.119495690625232,
                    "proportion_domain_2": 0.7054294745261317,
                    "proportion_domain_3": -0.28927157238419526,
                    "proportion_domain_4": 0.584990034524239,
                    "proportion_domain_5": 0.11834775395905708,
                },
                "loss_domain_2": {
                    "proportion_domain_1": 0.2957017578513571,
                    "proportion_domain_2": -0.7264047971224965,
                    "proportion_domain_3": -0.015757968287745085,
                    "proportion_domain_4": 0.338268574484209,
                    "proportion_domain_5": 0.10819243307467594,
                },
                "loss_domain_3": {
                    "proportion_domain_1": 1.2789627462923447,
                    "proportion_domain_2": 0.9909852370273848,
                    "proportion_domain_3": -4.473015910337291,
                    "proportion_domain_4": 1.1335508900914408,
                    "proportion_domain_5": 1.0695170369261213,
                },
                "loss_domain_4": {
                    "proportion_domain_1": 0.366883293836219,
                    "proportion_domain_2": 0.6809673154560535,
                    "proportion_domain_3": 0.3411068917199623,
                    "proportion_domain_4": -1.7833344940309523,
                    "proportion_domain_5": 0.3943769930187157,
                },
                "loss_domain_5": {
                    "proportion_domain_1": 0.02781842295027701,
                    "proportion_domain_2": 0.02677570609375618,
                    "proportion_domain_3": 0.20298408706514612,
                    "proportion_domain_4": 0.1189855046814393,
                    "proportion_domain_5": -0.37656372079061845,
                },
            },
        },
        "305M": {
            "intercepts": {
                "loss_domain_1": 2.497559209494426,
                "loss_domain_2": 3.14536997892604,
                "loss_domain_3": 1.968805837824704,
                "loss_domain_4": 1.2833765608574346,
                "loss_domain_5": 3.210803311891907,
            },
            "coefs": {
                "loss_domain_1": {
                    "proportion_domain_1": -1.0843715656873325,
                    "proportion_domain_2": 0.6721539878426451,
                    "proportion_domain_3": -0.25559189326785264,
                    "proportion_domain_4": 0.5586611040857702,
                    "proportion_domain_5": 0.1091483670267703,
                },
                "loss_domain_2": {
                    "proportion_domain_1": 0.2802593171012234,
                    "proportion_domain_2": -0.705345666382081,
                    "proportion_domain_3": 0.004212615597317157,
                    "proportion_domain_4": 0.32402103174261404,
                    "proportion_domain_5": 0.0968527019409263,
                },
                "loss_domain_3": {
                    "proportion_domain_1": 1.2989242757705795,
                    "proportion_domain_2": 1.0459184205941086,
                    "proportion_domain_3": -4.612543266518341,
                    "proportion_domain_4": 1.1685296372435943,
                    "proportion_domain_5": 1.0991709329100599,
                },
                "loss_domain_4": {
                    "proportion_domain_1": 0.35041793228110724,
                    "proportion_domain_2": 0.628523569343145,
                    "proportion_domain_3": 0.33852456736125136,
                    "proportion_domain_4": -1.6973603028352418,
                    "proportion_domain_5": 0.37989423384973636,
                },
                "loss_domain_5": {
                    "proportion_domain_1": 0.021137176480833044,
                    "proportion_domain_2": 0.024298666101822186,
                    "proportion_domain_3": 0.22155472081459832,
                    "proportion_domain_4": 0.11202244079079128,
                    "proportion_domain_5": -0.37901300418804473,
                },
            },
        },
        "410M": {
            "intercepts": {
                "loss_domain_1": 2.432148695798992,
                "loss_domain_2": 3.070321096782969,
                "loss_domain_3": 1.889657943760822,
                "loss_domain_4": 1.254798446077109,
                "loss_domain_5": 3.1335037706175104,
            },
            "coefs": {
                "loss_domain_1": {
                    "proportion_domain_1": -1.0761151307100763,
                    "proportion_domain_2": 0.7075633674705835,
                    "proportion_domain_3": -0.32021933049462503,
                    "proportion_domain_4": 0.584248540833674,
                    "proportion_domain_5": 0.10452255290044439,
                },
                "loss_domain_2": {
                    "proportion_domain_1": 0.28229210571466595,
                    "proportion_domain_2": -0.6691020114749978,
                    "proportion_domain_3": -0.05223402368178615,
                    "proportion_domain_4": 0.3465500387546539,
                    "proportion_domain_5": 0.09249389068746447,
                },
                "loss_domain_3": {
                    "proportion_domain_1": 1.3285868291288925,
                    "proportion_domain_2": 1.0730826822922905,
                    "proportion_domain_3": -4.69986390494191,
                    "proportion_domain_4": 1.1805807813503215,
                    "proportion_domain_5": 1.1176136121704066,
                },
                "loss_domain_4": {
                    "proportion_domain_1": 0.3110943108385622,
                    "proportion_domain_2": 0.6544173298915442,
                    "proportion_domain_3": 0.339609109190628,
                    "proportion_domain_4": -1.6597557356277692,
                    "proportion_domain_5": 0.3546349857070334,
                },
                "loss_domain_5": {
                    "proportion_domain_1": 0.02023592248594768,
                    "proportion_domain_2": 0.06340878588569265,
                    "proportion_domain_3": 0.15519736526922107,
                    "proportion_domain_4": 0.14343010613675986,
                    "proportion_domain_5": -0.38227217977762123,
                },
            },
        },
        "70M": {
            "intercepts": {
                "loss_domain_1": 2.9696038495486428,
                "loss_domain_2": 3.619059870697084,
                "loss_domain_3": 2.3841489786591965,
                "loss_domain_4": 1.6263083817663921,
                "loss_domain_5": 3.712126062067348,
            },
            "coefs": {
                "loss_domain_1": {
                    "proportion_domain_1": -1.2024018889698118,
                    "proportion_domain_2": 0.7225748684796697,
                    "proportion_domain_3": -0.2714083488994125,
                    "proportion_domain_4": 0.6116971390006383,
                    "proportion_domain_5": 0.1395382303889172,
                },
                "loss_domain_2": {
                    "proportion_domain_1": 0.31154794402259256,
                    "proportion_domain_2": -0.7696911526036795,
                    "proportion_domain_3": -0.029834711619449426,
                    "proportion_domain_4": 0.3475522996421775,
                    "proportion_domain_5": 0.14042562055835917,
                },
                "loss_domain_3": {
                    "proportion_domain_1": 1.3538168492045202,
                    "proportion_domain_2": 1.0328737163751711,
                    "proportion_domain_3": -4.676968208089681,
                    "proportion_domain_4": 1.1539207687720698,
                    "proportion_domain_5": 1.1363568737379204,
                },
                "loss_domain_4": {
                    "proportion_domain_1": 0.3998862788060258,
                    "proportion_domain_2": 0.747594278597924,
                    "proportion_domain_3": 0.40318411868033077,
                    "proportion_domain_4": -1.9906253596015544,
                    "proportion_domain_5": 0.439960683517272,
                },
                "loss_domain_5": {
                    "proportion_domain_1": 0.040341011044849134,
                    "proportion_domain_2": 0.014732615064055155,
                    "proportion_domain_3": 0.18686886153522414,
                    "proportion_domain_4": 0.11511773826952129,
                    "proportion_domain_5": -0.3570602259136497,
                },
            },
        },
    }
    # Parameter scaling w.r.t. model size: each parameter theta is fit as
    #   theta = beta0 + beta1 * log10(model_size)
    # (model_size in raw parameter count, e.g. 160M -> 1.6e8).
    PARAM_MAP = {
        'intercepts': {
            'loss_domain_1': {'beta0': 8.410693789482913, 'beta1': -0.6968428684065728},
            'loss_domain_2': {'beta0': 9.159453070486071, 'beta1': -0.709122366641981},
            'loss_domain_3': {'beta0': 7.408200748969857, 'beta1': -0.64133116180932},
            'loss_domain_4': {'beta0': 5.432810327260841, 'beta1': -0.4880602455582854},
            'loss_domain_5': {'beta0': 9.564414879713457, 'beta1': -0.7489965370842978},
        },
        'coefs': {
            'loss_domain_1': {
                'proportion_domain_1': {'beta0': -2.497615657116547, 'beta1': 0.16617474851983763},
                'proportion_domain_2': {'beta0': 1.022994012405379, 'beta1': -0.03874502688716343},
                'proportion_domain_3': {'beta0': -0.013900448056208633, 'beta1': -0.03260965165867218},
                'proportion_domain_4': {'beta0': 0.9963146912448967, 'beta1': -0.04964843319473607},
                'proportion_domain_5': {'beta0': 0.4922074015224826, 'beta1': -0.0451716367792662},
            },
            'loss_domain_2': {
                'proportion_domain_1': {'beta0': 0.6366841433930845, 'beta1': -0.041541148709937525},
                'proportion_domain_2': {'beta0': -1.7192214869138862, 'beta1': 0.12086845627811361},
                'proportion_domain_3': {'beta0': 0.01801791154655993, 'beta1': -0.004998619018912918},
                'proportion_domain_4': {'beta0': 0.43967439538852, 'beta1': -0.012137270708032335},
                'proportion_domain_5': {'beta0': 0.6248450365857227, 'beta1': -0.06219141784123091},
            },
            'loss_domain_3': {
                'proportion_domain_1': {'beta0': 1.623227750301966, 'beta1': -0.03718726484996544},
                'proportion_domain_2': {'beta0': 0.5786824938578662, 'beta1': 0.05515336511355352},
                'proportion_domain_3': {'beta0': -4.340004860354267, 'beta1': -0.0332577630477254},
                'proportion_domain_4': {'beta0': 0.842562300321158, 'beta1': 0.03820434891706093},
                'proportion_domain_5': {'beta0': 1.2955323158732615, 'beta1': -0.022912686132921745},
            },
            'loss_domain_4': {
                'proportion_domain_1': {'beta0': 1.2168426320766839, 'beta1': -0.10375482434593292},
                'proportion_domain_2': {'beta0': 1.8340679378772013, 'beta1': -0.1395259506615147},
                'proportion_domain_3': {'beta0': 1.0335256104698727, 'beta1': -0.08180936075274063},
                'proportion_domain_4': {'beta0': -5.334693091404859, 'beta1': 0.428635948410881},
                'proportion_domain_5': {'beta0': 1.2502569109811015, 'beta1': -0.10354581265069283},
            },
            'loss_domain_5': {
                'proportion_domain_1': {'beta0': 0.24988780248626472, 'beta1': -0.026851221126731266},
                'proportion_domain_2': {'beta0': -0.3681011545143189, 'beta1': 0.048319731256692484},
                'proportion_domain_3': {'beta0': 0.30625861824469436, 'beta1': -0.013830485288663432},
                'proportion_domain_4': {'beta0': -0.07581595507286405, 'beta1': 0.023918795432970087},
                'proportion_domain_5': {'beta0': -0.11222931114377548, 'beta1': -0.03155682027426793},
            },
        },
    }

    # Parse a raw parameter count from strings like '70M' (-> 7e7) or '1.3B'
    # (-> 1.3e9); bare digit strings are taken verbatim. Returns None when
    # the format is not recognized.
    def _parse_group_size(g: str) -> float | None:
        if not isinstance(g, str):
            return None
        s = g.strip().upper()
        try:
            if s.endswith('B'):
                return float(s[:-1]) * 1e9
            if s.endswith('M'):
                return float(s[:-1]) * 1e6
            if s.isdigit():
                return float(s)
        except ValueError:
            return None
        return None

    # Synthesize the full parameter set for an arbitrary model size via the
    # per-parameter linear-in-log10(size) fits in PARAM_MAP.
    def _params_from_size(model_size: float) -> Dict[str, Dict[str, float]]:
        logS = log10(model_size)
        losses = [f"loss_domain_{i}" for i in range(1, 6)]
        props = [f"proportion_domain_{i}" for i in range(1, 6)]
        out = {"intercepts": {}, "coefs": {}}
        for l in losses:
            b = PARAM_MAP['intercepts'][l]
            out["intercepts"][l] = b['beta0'] + b['beta1'] * logS
            out["coefs"][l] = {}
            for p in props:
                bp = PARAM_MAP['coefs'][l][p]
                out["coefs"][l][p] = bp['beta0'] + bp['beta1'] * logS
        return out

    # Choose parameters for this group: exact match first, then synthesis
    # from the parsed size, finally the cross-group average.
    if group in COEFS:
        group_params = COEFS[group]
    else:
        size = _parse_group_size(group)
        if size is not None and size > 0:
            group_params = _params_from_size(size)
        else:
            # Fallback: average coefficients across known groups
            groups = list(COEFS.keys())
            losses = [f"loss_domain_{i}" for i in range(1, 6)]
            props = [f"proportion_domain_{i}" for i in range(1, 6)]
            avg = {"intercepts": {}, "coefs": {}}
            for loss in losses:
                avg["intercepts"][loss] = sum(COEFS[g]["intercepts"][loss] for g in groups) / len(groups)
                avg["coefs"][loss] = {p: sum(COEFS[g]["coefs"][loss][p] for g in groups) / len(groups) for p in props}
            group_params = avg

    losses = [f"loss_domain_{i}" for i in range(1, 6)]
    props = [f"proportion_domain_{i}" for i in range(1, 6)]
    outputs: list[dict[str, float]] = []
    for row in input_data:
        pred: dict[str, float] = {}
        for loss in losses:
            # Intercept plus dot product with the proportions (missing
            # proportion keys default to 0.0).
            val = group_params["intercepts"][loss]
            for p in props:
                val += group_params["coefs"][loss][p] * float(row.get(p, 0.0))
            pred[loss] = float(val)
        outputs.append(pred)
    return outputs