from __future__ import annotations
from typing import List, Dict
# Functional form of the discovered scaling law (identical for every group):
#
#     L = L0_g + A_g * V^(-alpha) + B_g * Pnv^(-beta) + C_g * C^(-gamma)
#
# with the inputs
#     V   = vocab_size
#     Pnv = non_vocab_parameters
#     C   = num_characters
#
# The exponents alpha, beta and gamma are shared across groups; the intercept
# L0_g and the coefficients A_g, B_g, C_g are fitted per group.

# Shared exponents (selected via grid search minimizing RMSE).
_ALPHA, _BETA, _GAMMA = 0.2, 0.2, 0.4

# Only one group was present in the fitting dataset, so it doubles as the
# default group.
_DEFAULT_GROUP = "all_data"

# Per-group coefficients fitted on the provided dataset.
# Layout: group name -> (L0, A, B, C)
_COEFS: Dict[str, tuple[float, float, float, float]] = {
    _DEFAULT_GROUP: (
        -5.547737600980133,   # L0: constant offset
        -1.8596813255288938,  # A: multiplies the vocab_size term
        17.1014092331671,     # B: multiplies the non_vocab_parameters term
        9830.897391235507,    # C: multiplies the num_characters term
    ),
}
def _predict_one(
    vocab_size: float,
    non_vocab_parameters: float,
    num_characters: float,
    coefs: tuple[float, float, float, float],
) -> float:
    """Evaluate the scaling law for a single data point.

    The result is
    ``L0 + A * V**(-_ALPHA) + B * Pnv**(-_BETA) + Cc * C**(-_GAMMA)``
    where ``V``, ``Pnv`` and ``C`` are the three inputs after conversion to
    ``float`` and clamping to a minimum of ``1e-12``.

    Args:
        vocab_size: Value used as ``V``.
        non_vocab_parameters: Value used as ``Pnv``.
        num_characters: Value used as ``C``.
        coefs: The four coefficients ``(L0, A, B, Cc)``.

    Returns:
        The value of the expression above.
    """
    # Negative exponents are applied below, so each base is clamped to a tiny
    # positive floor to guard against non-positive inputs.
    floor = 1e-12
    vocab, params, chars = (
        max(float(raw), floor)
        for raw in (vocab_size, non_vocab_parameters, num_characters)
    )

    intercept, vocab_coef, param_coef, char_coef = coefs

    vocab_term = vocab_coef * vocab ** (-_ALPHA)
    param_term = param_coef * params ** (-_BETA)
    char_term = char_coef * chars ** (-_GAMMA)
    return intercept + vocab_term + param_term + char_term
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict the output variable for each data point using the scaling law.

    Args:
        input_data: A list of dictionaries, one per data point, mapping input
                    variable names to values. The keys read are
                    "vocab_size", "non_vocab_parameters" and
                    "num_characters"; a missing key is read as 0.0.
        group: The name of the experimental group whose coefficients should
               be used. The functional form is the same for every group;
               only the coefficients differ. A group with no entry in
               ``_COEFS`` uses the coefficients of ``_DEFAULT_GROUP``.

    Returns:
        A list with one dictionary per input data point, in the same order,
        each holding the prediction under the key "unigram_normalized_loss".
    """
    group_coefs = _COEFS.get(group)
    # Unseen groups reuse the default group's coefficients, which keeps the
    # functional form intact.
    coefs = _COEFS[_DEFAULT_GROUP] if group_coefs is None else group_coefs

    return [
        {
            "unigram_normalized_loss": float(
                _predict_one(
                    float(point.get("vocab_size", 0.0)),
                    float(point.get("non_vocab_parameters", 0.0)),
                    float(point.get("num_characters", 0.0)),
                    coefs,
                )
            )
        }
        for point in input_data
    ]