from __future__ import annotations
from typing import Dict, List
import math
# Discovered scaling law (same functional form for all groups):
# loss = L0 + A * params^(-alpha) + B * Neff^(-beta)
# with an effective data term that accounts for limited uniqueness in the corpus:
# Neff = (tokens * (c * unique_tokens)) / (tokens + c * unique_tokens)
# which behaves like a smooth minimum of tokens and c * unique_tokens.
#
# Fitted coefficients per group. If an unknown group is provided, we fall back to
# the "all_data" coefficients.
_GROUP_COEFFS: Dict[str, Dict[str, float]] = {
# Fitted on the provided dataset using non-linear least squares
# L0, A, alpha, B, beta, c
"all_data": {
"L0": 2.38717219,
"A": 1.60700128e04,
"alpha": 5.81892030e-01,
"B": 9.76230068e03,
"beta": 4.22008080e-01,
"c": 2.54449411e01,
},
}
# Default group to use when the provided group is not found
_DEFAULT_GROUP = "all_data"
def _predict_single(P: float, T: float, U: float, coeffs: Dict[str, float]) -> float:
"""Apply the scaling law for a single data point.
Args:
P: params (parameter count)
T: tokens (total training tokens)
U: unique_tokens (number of unique tokens)
coeffs: dictionary with keys {L0, A, alpha, B, beta, c}
Returns:
Predicted loss (float)
"""
L0 = float(coeffs["L0"]) # irreducible loss floor
A = float(coeffs["A"]) # capacity scaling amplitude
alpha = float(coeffs["alpha"]) # capacity exponent (>0)
B = float(coeffs["B"]) # data scaling amplitude
beta = float(coeffs["beta"]) # data exponent (>0)
c = float(coeffs["c"]) # uniqueness-to-tokens coupling scale
# Numerical safety
eps = 1e-12
P = max(float(P), eps)
T = max(float(T), 0.0)
U = max(float(U), 0.0)
# Effective number of independent tokens (smooth min between T and c*U)
CU = c * U
denom = T + CU
if denom <= eps:
Neff = 0.0
else:
Neff = (T * CU) / denom
# Clamp Neff minimally to avoid division by zero in power with negative exponent
Neff = max(Neff, eps)
loss = L0 + A * (P ** (-abs(alpha))) + B * (Neff ** (-abs(beta)))
return float(loss)
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.
    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values. Required keys: 'params', 'tokens', 'unique_tokens'.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law is the same for all groups,
               but the constant parameters/coefficients can differ per group.
    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s): {'loss': <float>}.
    """
    # Unknown groups silently fall back to the default coefficient set.
    coeffs = _GROUP_COEFFS.get(group, _GROUP_COEFFS[_DEFAULT_GROUP])

    predictions: List[Dict[str, float]] = []
    for record in input_data:
        params = record.get("params")
        tokens = record.get("tokens")
        unique = record.get("unique_tokens")
        # Reject rows with missing (or explicitly None) required fields.
        if params is None or tokens is None or unique is None:
            raise ValueError("Each input dict must contain 'params', 'tokens', and 'unique_tokens'.")
        predictions.append({"loss": _predict_single(params, tokens, unique, coeffs)})
    return predictions