from typing import List, Dict
# Fitted parameters per group for the scaling law:
# L = L0 + A*params**(-alpha) + B*tokens**(-beta) + C*unique_tokens**(-gamma)
_FITTED = {
'all_data': {'L0': 1.854266820557524, 'A': 5185.946367775831, 'B': 108444.27084241492, 'C': 14.150551846023221, 'alpha': 0.5065481743109205, 'beta': 0.5635672529657598, 'gamma': 0.12922388708956437}
}
_DEFAULT = {'L0': 1.85427, 'A': 5185.95, 'B': 108444, 'C': 14.1506, 'alpha': 0.506548, 'beta': 0.563567, 'gamma': 0.129224}
def _predict_one(x: Dict[str, float], p: Dict[str, float]) -> Dict[str, float]:
N = float(x.get("params", 0.0))
T = float(x.get("tokens", 0.0))
U = float(x.get("unique_tokens", 0.0))
# Guard against non-positive inputs
N = max(N, 1e-12); T = max(T, 1e-12); U = max(U, 1e-12)
L0 = p["L0"]; A=p["A"]; B=p["B"]; C=p["C"]
alpha=p["alpha"]; beta=p["beta"]; gamma=p["gamma"]
loss = L0 + A*(N**(-alpha)) + B*(T**(-beta)) + C*(U**(-gamma))
return {"loss": float(loss)}
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """Predict output variables for each input point using the fitted scaling law.

    The functional form is shared by all groups; only the coefficients differ
    per group. Coefficients are taken from ``_FITTED[group]``; if ``group`` is
    not a key of ``_FITTED``, ``_DEFAULT`` is used instead.

    Args:
        input_data: Data points, each a dict of input variable names to values
            (the per-point evaluation reads "params", "tokens" and
            "unique_tokens").
        group: Name of the experimental group whose coefficients to use.

    Returns:
        One dict per entry of ``input_data``, in the same order, each holding
        the predicted output variable(s) (currently the key "loss").
    """
    coefficients = _FITTED[group] if group in _FITTED else _DEFAULT
    return [_predict_one(point, coefficients) for point in input_data]