from __future__ import annotations

import math
# Coefficients learned from the observed dataset. Same functional form; group-specific coefficients.
_GROUP_COEFS = {
    'all_data': {
        'mse': 0.00731547465744217,
        'coef': {
            'const': -2.1279174866216723,
            'V_inv_sqrt': -249.95274501751913,
            'P_inv_sqrt': 85788.82318771542,
            'C_inv_sqrt': -412035.65012517374,
            'logV': 0.13411632826121078,
            'logP': -0.5280171191793501,
            'logC': 0.24657662157971197,
            'V_inv_sqrt_logP': 14.031301495867208,
            'P_inv_sqrt_logC': -3829.9875212486672,
            'C_inv_sqrt_logP': 25651.188832628683,
        },
    },
}
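
# All groups share one functional form; only the coefficients above differ.
# Writing V = vocab_size, P = non_vocab_parameters, C = num_characters, the
# prediction is a linear combination of the transformed features listed below:
#   loss ~= b_const
#           + b_V * V**-0.5 + b_P * P**-0.5 + b_C * C**-0.5
#           + b_logV * log(V) + b_logP * log(P) + b_logC * log(C)
#           + b_VP * V**-0.5 * log(P) + b_PC * P**-0.5 * log(C) + b_CP * C**-0.5 * log(P)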
_FEATURES = [
    'const', 'V_inv_sqrt', 'P_inv_sqrt', 'C_inv_sqrt',
    'logV', 'logP', 'logC',
    'V_inv_sqrt_logP', 'P_inv_sqrt_logC', 'C_inv_sqrt_logP',
]
# Feature computation matches the training pipeline
def _compute_features(d: dict[str, float]) -> list[float]:
    V = float(d.get('vocab_size', 0.0))
    P = float(d.get('non_vocab_parameters', 0.0))
    C = float(d.get('num_characters', 0.0))
    # Clamp to a tiny positive value so the logs and inverse square roots stay finite.
    eps = 1e-12
    V = V if V > eps else eps
    P = P if P > eps else eps
    C = C if C > eps else eps
    feats = {
        'const': 1.0,
        'V_inv_sqrt': V ** (-0.5),
        'P_inv_sqrt': P ** (-0.5),
        'C_inv_sqrt': C ** (-0.5),
        'logV': math.log(V),
        'logP': math.log(P),
        'logC': math.log(C),
        'V_inv_sqrt_logP': (V ** (-0.5)) * math.log(P),
        'P_inv_sqrt_logC': (P ** (-0.5)) * math.log(C),
        'C_inv_sqrt_logP': (C ** (-0.5)) * math.log(P),
    }
    return [feats[name] for name in _FEATURES]

def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
            point containing input variable names as keys and their corresponding
            values.
        group: The name of the experimental group for which to make predictions.
            The functional form of the law must be the same for all groups, but
            the constant parameters/coefficients can differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    # Select coefficients for the requested group; for an unseen group, fall back
    # to the per-coefficient average across all known groups.
    if group in _GROUP_COEFS:
        coefs = _GROUP_COEFS[group]['coef']
    else:
        keys = list(next(iter(_GROUP_COEFS.values()))['coef'].keys())
        avg = {k: 0.0 for k in keys}
        n_groups = len(_GROUP_COEFS)
        for rec in _GROUP_COEFS.values():
            for k, v in rec['coef'].items():
                avg[k] += v
        for k in avg:
            avg[k] /= max(n_groups, 1)
        coefs = avg

    beta = [coefs[name] for name in _FEATURES]
    outputs: list[dict[str, float]] = []
    for d in input_data:
        x = _compute_features(d)
        # Linear prediction in the transformed feature space.
        y = sum(b * xi for b, xi in zip(beta, x))
        outputs.append({'unigram_normalized_loss': float(y)})
    return outputs
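

if __name__ == '__main__':
    # Minimal usage sketch. The numbers below are purely illustrative inputs
    # (hypothetical vocabulary size, non-vocabulary parameter count, and
    # character count), not values from any real run.
    example_inputs = [
        {'vocab_size': 32000.0, 'non_vocab_parameters': 1.0e8, 'num_characters': 1.0e10},
        {'vocab_size': 64000.0, 'non_vocab_parameters': 5.0e8, 'num_characters': 5.0e10},
    ]
    for point, pred in zip(example_inputs, law(example_inputs, group='all_data')):
        print(point, '->', pred)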