# Auto-generated scaling law implementation
# Functional form:
# unigram_normalized_loss = a
# + bV * V^(-alpha) + bP * P^(-beta) + bN * N^(-gamma)
# + bVP * V^(-alpha) * P^(-beta)
# + bVN * V^(-alpha) * N^(-gamma)
# + bPN * P^(-beta) * N^(-gamma)
# where V = vocab_size, P = non_vocab_parameters, N = num_characters.
# Coefficients and exponents are fitted per group; a global fallback is provided.
from typing import List, Dict
import math
_PARAMS = {
"all_data": {
"alpha": 0.25,
"beta": 0.75,
"gamma": 0.5,
"coef": {
"a": -5.662989544286598,
"bV": 2.168436615813606,
"bP": 302653.3106628973,
"bN": 121210.75170193214,
"bVP": 200457.1519513687,
"bVN": -385541.00382365123,
"bPN": -11998348015.944254
},
"mse": 0.0070885392771583644,
"count": 1080
}
}
_GLOBAL = {
"alpha": 0.25,
"beta": 0.75,
"gamma": 0.5,
"coef": {
"a": -5.662989544286598,
"bV": 2.168436615813606,
"bP": 302653.3106628973,
"bN": 121210.75170193214,
"bVP": 200457.1519513687,
"bVN": -385541.00382365123,
"bPN": -11998348015.944254
},
"mse": 0.0070885392771583644,
"count": 1080
}
def _predict_single(x: dict, pars: dict) -> float:
V = float(x.get("vocab_size", 0.0))
P = float(x.get("non_vocab_parameters", 0.0))
N = float(x.get("num_characters", 0.0))
if V <= 0 or P <= 0 or N <= 0 or not all(map(math.isfinite, [V,P,N])):
return float('nan')
alpha = pars["alpha"]; beta = pars["beta"]; gamma = pars["gamma"]
a = pars["coef"]["a"]; bV = pars["coef"]["bV"]; bP = pars["coef"]["bP"]; bN = pars["coef"]["bN"]
bVP = pars["coef"]["bVP"]; bVN = pars["coef"]["bVN"]; bPN = pars["coef"]["bPN"]
fV = V ** (-alpha)
fP = P ** (-beta)
fN = N ** (-gamma)
y = a + bV*fV + bP*fP + bN*fN + bVP*(fV*fP) + bVN*(fV*fN) + bPN*(fP*fN)
return float(y)
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.
    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law is the same for all groups,
               but constant parameters/coefficients can differ per group.
    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    # Unknown group names fall back to the globally-fitted parameters.
    group_pars = _PARAMS.get(str(group), _GLOBAL)
    return [
        {"unigram_normalized_loss": _predict_single(point, group_pars)}
        for point in input_data
    ]