from __future__ import annotations
from typing import Dict, List
def _predict_loss(num_params: float, parallel_size: float, coeffs: Dict[str, float]) -> float:
    """Evaluate the scaling law for a single data point.

    The law is
    ``L = L_inf + k * (num_params * parallel_size**gamma) ** (-alpha)``.

    Args:
        num_params: Value of the ``num_params`` input variable.
        parallel_size: Value of the ``parallel_size`` input variable.
        coeffs: Mapping that provides the constants ``"L_inf"``, ``"k"``,
            ``"alpha"`` and ``"gamma"``.

    Returns:
        The predicted loss as a ``float``.

    Raises:
        KeyError: If ``coeffs`` lacks one of the four constants.
        ZeroDivisionError: If the effective size is zero while ``alpha`` is
            positive (zero raised to a negative power).
    """
    asymptote, scale, decay, parallel_exp = (
        coeffs[name] for name in ("L_inf", "k", "alpha", "gamma")
    )
    # Parallelism acts as a multiplier on the parameter count, damped by gamma.
    effective_size = float(num_params) * float(parallel_size) ** float(parallel_exp)
    power_term = effective_size ** (-decay)
    return float(asymptote + scale * power_term)
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict the loss for every input row using the group's fitted scaling law.

    All groups share the functional form
    ``loss = L_inf + k * (num_params * parallel_size**gamma) ** (-alpha)``;
    only the four constants differ from group to group.

    Args:
        input_data: One dictionary per data point. The keys ``"num_params"``
            and ``"parallel_size"`` are read; a missing ``"num_params"``
            falls back to ``0.0`` and a missing ``"parallel_size"`` to ``1.0``.
        group: Name of the experimental group. ``"pile"`` and ``"stack"``
            have fitted constants; any other name uses the per-parameter
            mean of the known groups.

    Returns:
        A list aligned with ``input_data`` whose entries are
        ``{"loss": <predicted loss>}`` dictionaries.

    NOTE: every ``alpha`` in the table is positive, so a row whose
    ``num_params`` is zero (including the ``0.0`` fallback for a missing key)
    makes ``_predict_loss`` raise ``ZeroDivisionError``.
    """
    # Fitted constants per group (derived from /app/data).
    fitted: Dict[str, Dict[str, float]] = {
        # group: { L_inf, k, alpha, gamma }
        "pile": {
            "L_inf": 0.9994680998203908,
            "k": 14.076803963829533,
            "alpha": 0.1260671701767887,
            "gamma": 0.5,
        },
        "stack": {
            "L_inf": 0.6239972998774463,
            "k": 11.721235202089247,
            "alpha": 0.15204205303795576,
            "gamma": 0.5,
        },
    }

    params = fitted.get(group)
    if params is None:
        # Unseen group: average each constant over the known groups.
        known = list(fitted.values())
        params = {
            name: sum(entry[name] for entry in known) / len(known)
            for name in ("L_inf", "k", "alpha", "gamma")
        }

    return [
        {
            "loss": _predict_loss(
                float(point.get("num_params", 0.0)),
                float(point.get("parallel_size", 1.0)),
                params,
            )
        }
        for point in input_data
    ]