from __future__ import annotations
from typing import List, Dict
# Discovered scaling law (additive power-law with a constant floor):
# y = L + A * Np^{-alpha} + B * D^{-beta} + C * V^{-gamma}
# where
# y = unigram_normalized_loss
# Np = non_vocab_parameters
# D = num_characters
# V = vocab_size
# The functional form is shared across groups; coefficients may vary by group.
# This repository's dataset only contains a single group ("all_data"). We also
# provide a "default" set of parameters that mirrors the same fit.
# Fit obtained by cross-validated grid-search on exponents (alpha, beta,
# gamma) and least-squares on coefficients with nonnegativity encouraged
# for A, B, C. Vocabulary-size effect is negligible in this dataset
# (C ~ 0), so gamma is included for completeness but contributes little.
#
# Refined fit (nonnegative A,B,C with intercept re-fit), full-data:
#   exponents:    alpha=0.06, beta=0.35, gamma=0.02
#   coefficients: L=-6.45718219, A=2.51363526, B=3866.31610, C=0.0
#
# Both groups share the same fit, so define it once and copy it into each
# entry (copies keep the per-group dicts independent, as before).
_SHARED_FIT: Dict[str, float] = {
    "alpha": 0.06,
    "beta": 0.35,
    "gamma": 0.02,
    "L": -6.45718219,
    "A": 2.51363526,
    "B": 3866.31610,
    "C": 0.0,
}
PARAMS_BY_GROUP = {
    "all_data": dict(_SHARED_FIT),
    "default": dict(_SHARED_FIT),
}
def _predict_one(x: Dict[str, float], p: Dict[str, float]) -> float:
# Extract and guard against non-positive inputs (should not occur in sane data)
V = float(x.get("vocab_size", 0.0))
Np = float(x.get("non_vocab_parameters", 0.0))
D = float(x.get("num_characters", 0.0))
eps = 1e-12
V = max(V, eps)
Np = max(Np, eps)
D = max(D, eps)
alpha = p["alpha"]; beta = p["beta"]; gamma = p["gamma"]
L = p["L"]; A = p["A"]; B = p["B"]; C = p["C"]
y = (
L
+ A * (Np ** (-alpha))
+ B * (D ** (-beta))
+ C * (V ** (-gamma))
)
return float(y)
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.
    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law must be the same for all groups,
               but the constant parameters/coefficients can differ per group.
    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    # Unseen group names fall back to the shared "default" parameter set.
    params = PARAMS_BY_GROUP.get(group, PARAMS_BY_GROUP["default"])
    return [
        {"unigram_normalized_loss": _predict_one(point, params)}
        for point in input_data
    ]