from __future__ import annotations
import math
from typing import Dict, List
# Small constant added inside the logarithm so that zero proportions stay finite.
_EPS = 1e-6

# Number of mixture domains modelled by the law.
_NUM_DOMAINS = 5

# Group whose coefficients are used when `law` receives a group name that is not
# a key of _COEFFS.
_FALLBACK_GROUP = "70M"

# Coefficients fitted per group on the provided dataset (/app/data), using the
# model: loss_i = a_i + b_i * log(p_i + EPS) + sum_{j != i} c_{i,j} * p_j
# Linear coefficients are stored as a full 5-length vector per domain with 0.0
# for the self-domain (j == i) entry, so the linear term can be computed as a
# plain dot product over all five proportions.
# Defined once at module level instead of being rebuilt on every call to `law`.
_COEFFS = {
    "70M": {
        1: {"a": 2.352400, "b": -0.041342, "c": [0.000000, 0.552302, 0.679733, 0.457510, 0.478500]},
        2: {"a": 3.119185, "b": -0.005609, "c": [0.733329, 0.000000, 0.567223, 0.760307, 0.571576]},
        3: {"a": 1.557687, "b": -0.029500, "c": [1.776484, 1.574088, 0.000000, 1.672027, 1.590520]},
        4: {"a": 1.005729, "b": -0.040741, "c": [0.682161, 0.804593, 0.768164, 0.000000, 0.680742]},
        5: {"a": 3.401418, "b": -0.019938, "c": [0.282951, 0.204621, 0.280657, 0.244292, 0.000000]},
    },
    "160M": {
        1: {"a": 2.084419, "b": -0.039436, "c": [0.000000, 0.515541, 0.590549, 0.410446, 0.414215]},
        2: {"a": 2.848965, "b": -0.005760, "c": [0.664815, 0.000000, 0.533358, 0.698111, 0.486927]},
        3: {"a": 1.375788, "b": -0.028472, "c": [1.645880, 1.472320, 0.000000, 1.592583, 1.466833]},
        4: {"a": 0.822570, "b": -0.036176, "c": [0.633280, 0.747330, 0.680942, 0.000000, 0.623930]},
        5: {"a": 3.044954, "b": -0.020112, "c": [0.288934, 0.234711, 0.313982, 0.265677, 0.000000]},
    },
    "305M": {
        1: {"a": 1.965386, "b": -0.039011, "c": [0.000000, 0.461256, 0.591688, 0.362942, 0.378769]},
        2: {"a": 2.675656, "b": -0.004898, "c": [0.681773, 0.000000, 0.558797, 0.717652, 0.506549]},
        3: {"a": 1.389474, "b": -0.030900, "c": [1.455301, 1.326467, 0.000000, 1.424874, 1.288538]},
        4: {"a": 0.758123, "b": -0.034855, "c": [0.586244, 0.671620, 0.645107, 0.000000, 0.580221]},
        5: {"a": 2.880988, "b": -0.021162, "c": [0.278675, 0.225879, 0.321137, 0.249162, 0.000000]},
    },
    "410M": {
        1: {"a": 1.904173, "b": -0.038724, "c": [0.000000, 0.497929, 0.520547, 0.389682, 0.371875]},
        2: {"a": 2.648743, "b": -0.005145, "c": [0.632228, 0.000000, 0.458498, 0.688205, 0.451025]},
        3: {"a": 1.311117, "b": -0.031575, "c": [1.474932, 1.346313, 0.000000, 1.429078, 1.297670]},
        4: {"a": 0.726224, "b": -0.033638, "c": [0.560347, 0.717670, 0.657147, 0.000000, 0.569629]},
        5: {"a": 2.802291, "b": -0.021963, "c": [0.276436, 0.261534, 0.247464, 0.274675, 0.000000]},
    },
}


def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict per-domain validation losses from mixture proportions.

    Each domain's loss is modelled as the sum of:
      - a group- and domain-specific intercept a_i,
      - a group- and domain-specific coefficient b_i times log(p_i + eps),
        capturing diminishing returns from allocating more mixture
        proportion to the same domain,
      - a linear combination of the proportions of the other domains
        (j != i), with group- and domain-specific coefficients c_{i,j}.

    For domain i in {1..5}:

        loss_i = a_i + b_i * log(p_i + eps) + sum_{j != i} c_{i,j} * p_j

    where p_k are the mixture proportions and eps = 1e-6 keeps the logarithm
    finite for zero proportions.

    Args:
        input_data: List of dicts with keys 'proportion_domain_1' ..
            'proportion_domain_5'. A missing key is read as 0.0. If a row's
            proportions sum to a positive value, they are renormalized to
            sum to 1 before evaluation. A negative proportion is clipped to
            0 inside the logarithm only; the linear term uses it unclipped.
        group: Model group whose coefficients are used: "70M", "160M",
            "305M" or "410M". Any other value falls back to the "70M"
            coefficients.

    Returns:
        One dict per input row, with keys 'loss_domain_1' .. 'loss_domain_5'.
    """
    # Unknown groups silently fall back to the 70M coefficients. They are NOT
    # matched to the nearest model size.
    params_by_group = _COEFFS.get(group, _COEFFS[_FALLBACK_GROUP])

    outputs: List[Dict[str, float]] = []
    for row in input_data:
        # Read proportions in a fixed domain order; missing keys count as 0.
        p = [
            float(row.get(f"proportion_domain_{k}", 0.0))
            for k in range(1, _NUM_DOMAINS + 1)
        ]
        # Normalize defensively in case inputs do not sum exactly to 1.
        total = sum(p)
        if total > 0:
            p = [pk / total for pk in p]

        pred: Dict[str, float] = {}
        for i in range(1, _NUM_DOMAINS + 1):
            par = params_by_group[i]
            # Clip at 0 so a negative proportion cannot push log() outside its
            # domain.
            log_term = math.log(max(p[i - 1], 0.0) + _EPS)
            # c holds 0.0 at the self-domain index, so this only sums j != i.
            linear_term = sum(cj * pj for cj, pj in zip(par["c"], p))
            pred[f"loss_domain_{i}"] = float(par["a"] + par["b"] * log_term + linear_term)
        outputs.append(pred)
    return outputs