from __future__ import annotations
from typing import Dict, List
# Discovered scaling law (shared functional form across groups):
# loss(N, P; g) = L_inf[g] + A[g] / (N**alpha[g] * P**beta[g])
# where
# N = num_params, P = parallel_size, g = group name
# Coefficients were fitted per group on the provided dataset.
_COEFFICIENTS: Dict[str, Dict[str, float]] = {
# group: {L_inf, A, alpha, beta}
"pile": {
"L_inf": 1.5938,
"A": 927.7951136423936,
"alpha": 0.3714475799060386,
"beta": 0.10974874419589042,
},
"stack": {
"L_inf": 0.7906,
"A": 63.84810455867431,
"alpha": 0.25439743787254915,
"beta": 0.08449977586585611,
},
}
# Fallback (used if an unseen group is requested): simple average of known groups
_DEFAULT = {
k: sum(v[k] for v in _COEFFICIENTS.values()) / len(_COEFFICIENTS)
for k in ("L_inf", "A", "alpha", "beta")
}
def _predict_loss(num_params: float, parallel_size: float, coeffs: Dict[str, float]) -> float:
N = float(num_params)
P = float(parallel_size)
if N <= 0 or P <= 0:
# Guard against invalid inputs; return asymptotic value if invalid
return float(coeffs["L_inf"])
return float(coeffs["L_inf"] + coeffs["A"] / (N ** coeffs["alpha"] * P ** coeffs["beta"]))
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict the loss for each data point using the fitted scaling law.

    Every group shares the functional form
        loss = L_inf + A / (num_params**alpha * parallel_size**beta)
    and only the four constants differ from group to group.

    Args:
        input_data: Data points, each a mapping that must provide values for
            the keys 'num_params' and 'parallel_size'.
        group: Name of the experimental group whose fitted coefficients are
            used. A name absent from the fitted table falls back to the
            default (averaged) coefficients instead of raising.

    Returns:
        A list with one ``{"loss": value}`` dict per input row, in the same
        order as ``input_data``.

    Raises:
        KeyError: If a row is missing 'num_params' or 'parallel_size', or
            maps either of them to None.
    """
    params = _COEFFICIENTS[group] if group in _COEFFICIENTS else _DEFAULT

    def _one(record):
        # Convert one input row into its prediction dict.
        size = record.get("num_params")
        width = record.get("parallel_size")
        if size is None or width is None:
            raise KeyError("Each input row must contain 'num_params' and 'parallel_size'.")
        return {"loss": float(_predict_loss(float(size), float(width), params))}

    return [_one(record) for record in input_data]