from __future__ import annotations
from typing import List, Dict
# Power-law exponents used by every group (selected by cross-validated grid search).
_A = 0.26075
_B = 0.50575

# Per-group coefficient tuples, ordered (c0, cN, cS, cNS):
#   c0  - asymptotic loss as num_params, parallel_size -> infinity
#   cN  - amplitude of the num_params^{-A} term
#   cS  - amplitude of the parallel_size^{-B} term
#   cNS - amplitude of the interaction term num_params^{-A} * parallel_size^{-B}
_COEFS: Dict[str, tuple[float, float, float, float]] = {
    "pile": (1.39800173, 114.189821, 0.0789779439, 5.29151618),
    "stack": (0.764687078, 63.5153262, 0.0446666145, 5.06084916),
}

# Coefficients used for group names missing from _COEFS: the element-wise
# mean of every known group's tuple (columns are summed in dict order).
_mean_coefs = tuple(sum(column) / len(_COEFS) for column in zip(*_COEFS.values()))
def _predict_single(n: float, s: float, coefs: tuple[float, float, float, float]) -> float:
    """Evaluate the scaling law at one (num_params, parallel_size) point.

    Computes ``c0 + cN*n^{-A} + cS*s^{-B} + cNS*(n^{-A} * s^{-B})`` using the
    module-wide exponents ``_A`` and ``_B``.

    Args:
        n: Parameter count (num_params).
        s: Parallel size (parallel_size).
        coefs: Coefficient tuple ``(c0, cN, cS, cNS)``.

    Returns:
        The predicted loss. If ``n`` or ``s`` is zero or negative, the power
        terms are skipped and ``coefs[0]`` (the intercept) is returned as is.
        NaN inputs fail both ``<= 0`` comparisons and so go through the full
        formula.
    """
    if n <= 0 or s <= 0:
        # Negative powers are undefined or unbounded here; fall back to the intercept.
        return coefs[0]
    decay_n = n ** -_A
    decay_s = s ** -_B
    intercept, amp_n, amp_s, amp_ns = coefs
    # Accumulate left to right, matching ((c0 + N) + S) + NS.
    loss = intercept + amp_n * decay_n
    loss += amp_s * decay_s
    loss += amp_ns * (decay_n * decay_s)
    return loss
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict the loss for each input record using the fitted scaling law.

    All groups share one functional form with common exponents A and B; only
    the intercept and the amplitudes differ per group:
        loss = c0_g + cN_g * num_params^{-A} + cS_g * parallel_size^{-B} + cNS_g * (num_params^{-A} * parallel_size^{-B})

    Args:
        input_data: Records carrying 'num_params' and 'parallel_size'. A missing
            key is treated as 0.0, which makes that row return the group's
            intercept c0.
        group: Group name such as 'pile' or 'stack'. A name without fitted
            coefficients falls back to the mean coefficients of the known groups.

    Returns:
        One {'loss': value} dict per input record, in input order.
    """
    params = _COEFS.get(group, _mean_coefs)
    return [
        {
            "loss": float(
                _predict_single(
                    float(record.get("num_params", 0.0)),
                    float(record.get("parallel_size", 0.0)),
                    params,
                )
            )
        }
        for record in input_data
    ]