from __future__ import annotations
from typing import List, Dict
# Per-group coefficients fitted for the scaling law
#
#     loss = L_inf + A * num_params**(-alpha) * parallel_size**(-beta)
#
# The functional form is shared by every group; only these four constants
# differ. Keys of the outer dict are the group names accepted by `law`.
PARAMS: Dict[str, Dict[str, float]] = {
    'stack': {
        'L_inf': 0.4906,
        'A': 7.92697458504,
        'alpha': 0.122304765784,
        'beta': 0.0406343024103,
    },
    'pile': {
        'L_inf': 1.2938,
        'A': 40.6611812144,
        'alpha': 0.194471508223,
        'beta': 0.0568103691424,
    },
}
def _predict_one(x: Dict[str, float], par: Dict[str, float]) -> Dict[str, float]:
n = float(x.get('num_params', 0.0))
p = float(x.get('parallel_size', 1.0))
L_inf = par['L_inf']
A = par['A']
alpha = par['alpha']
beta = par['beta']
# Guardrails
n = max(n, 1e-12)
p = max(p, 1e-12)
y = L_inf + A * (n ** (-alpha)) * (p ** (-beta))
return {'loss': float(y)}
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict the loss for each input point using the fitted scaling law.

    Every group shares the functional form
    ``loss = L_inf + A * num_params**(-alpha) * parallel_size**(-beta)``;
    only the four constants differ between groups.

    Args:
        input_data: Data points, each a dict mapping input variable names
            (``'num_params'``, ``'parallel_size'``) to their values.
        group: Name of the experimental group whose coefficients are used.
            If the group is not in ``PARAMS``, each coefficient is replaced
            by its arithmetic mean over all known groups. If ``PARAMS`` is
            empty, fixed placeholder values are used instead.

    Returns:
        One dict per input point, in the same order, each holding the
        predicted ``'loss'``.
    """
    coeffs = PARAMS.get(group)
    if coeffs is None and PARAMS:
        # Unseen group: average every coefficient across the fitted groups.
        import statistics
        fitted = list(PARAMS.values())
        coeffs = {
            name: statistics.mean(entry[name] for entry in fitted)
            for name in ('L_inf', 'A', 'alpha', 'beta')
        }
    elif coeffs is None:
        # No fitted groups at all: fall back to hard-coded placeholder values.
        coeffs = {'L_inf': 0.0, 'A': 1.0, 'alpha': 0.5, 'beta': 0.5}
    return [_predict_one(point, coeffs) for point in input_data]