from __future__ import annotations
import math
from typing import Dict, List
# Quadratic-in-logs scaling law coefficients per experimental group.
# Fitted on the provided training dataset.
#
# Layout: group name -> {feature name -> coefficient}. The feature names are
# the keys built by ``_features`` and enumerated in ``_FEATURES``. The
# prediction is sum(coef[name] * feature[name]) over those names, i.e. an
# intercept plus linear, squared, and pairwise-product terms in
# ln(vocab_size), ln(non_vocab_parameters) and ln(num_characters).
# ``law`` falls back to the "all_data" entry for any group not listed here.
_COEFS: Dict[str, Dict[str, float]] = {
    "all_data": {
        "const": 43.653023403132735,
        "lnV": 0.7794957511938669,
        "lnP": 0.5846007123502984,
        "lnC": -4.504394566402747,
        "lnV2": 0.028553981965242906,
        "lnP2": 0.025813565754701645,
        "lnC2": 0.13736040362700275,
        "lnV_lnP": 0.02259283815603192,
        "lnV_lnC": -0.07386461582128809,
        "lnP_lnC": -0.08135643672419962,
    }
}
_FEATURES = (
"const",
"lnV",
"lnP",
"lnC",
"lnV2",
"lnP2",
"lnC2",
"lnV_lnP",
"lnV_lnC",
"lnP_lnC",
)
def _features(example: Dict[str, float]) -> Dict[str, float]:
v = float(example["vocab_size"]) # V
p = float(example["non_vocab_parameters"]) # P
c = float(example["num_characters"]) # C
# Natural logs; guard against non-positive with tiny epsilon
eps = 1e-12
lnV = math.log(v if v > 0 else eps)
lnP = math.log(p if p > 0 else eps)
lnC = math.log(c if c > 0 else eps)
return {
"const": 1.0,
"lnV": lnV,
"lnP": lnP,
"lnC": lnC,
"lnV2": lnV * lnV,
"lnP2": lnP * lnP,
"lnC2": lnC * lnC,
"lnV_lnP": lnV * lnP,
"lnV_lnC": lnV * lnC,
"lnP_lnC": lnP * lnC,
}
def _predict_one(ex: Dict[str, float], coefs: Dict[str, float]) -> float:
    """Evaluate the fitted law for one data point.

    Computes ``coefs[name] * feature[name]`` for every name in ``_FEATURES``
    (a coefficient absent from *coefs* counts as 0.0) and adds the products
    together in ``_FEATURES`` order.
    """
    terms = _features(ex)
    contributions = [coefs.get(name, 0.0) * terms[name] for name in _FEATURES]

    # Accumulate strictly left to right: builtin sum() uses compensated
    # summation for floats on Python 3.12+, which could alter the last bits.
    total = 0.0
    for part in contributions:
        total += part
    return total
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """Predict ``unigram_normalized_loss`` for each point with the fitted scaling law.

    The law is a quadratic polynomial in the natural logs of ``vocab_size``,
    ``non_vocab_parameters`` and ``num_characters``. It has an intercept,
    three linear terms, three squares and three pairwise products. The
    functional form is shared by every group; only the coefficients differ.

    Args:
        input_data: One dict per data point. Each must contain the three keys
            above; values are converted with ``float``, and values that are
            not strictly positive are clamped to 1e-12 before the log.
        group: Experimental group whose coefficients to use. A group with no
            entry in ``_COEFS`` falls back to the ``"all_data"`` coefficients.

    Returns:
        A list parallel to *input_data*; each element is
        ``{"unigram_normalized_loss": <prediction>}``.

    Raises:
        ValueError: If the coefficient mapping resolved for *group* (after
            the ``"all_data"`` fallback) is missing or empty.
        KeyError: If a data point lacks one of the required input keys.
    """
    if group in _COEFS:
        table = _COEFS[group]
    else:
        table = _COEFS.get("all_data", {})

    if not table:
        raise ValueError(f"No coefficients available for group '{group}' and no default group present.")

    return [
        {"unigram_normalized_loss": float(_predict_one(point, table))}
        for point in input_data
    ]