import math
__all__ = ['law']
# Precomputed (fitted) scaling-law parameters for each experimental group:
# theta0 is the intercept; alpha, beta, and gamma are the coefficients on
# ln(vocab_size), ln(non_vocab_parameters), and ln(num_characters).
_PARAMS = {
'all_data': {
'theta0': 6.380591236628991,
'alpha': 0.06340183374111474,
'beta': 0.016411064426657424,
'gamma': -0.5017006627222854,
},
}
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict unigram_normalized_loss from vocabulary size,
    non-vocabulary parameters, and number of characters using
    a log-linear scaling law:

        loss = theta0 + alpha * ln(vocab_size)
                      + beta * ln(non_vocab_parameters)
                      + gamma * ln(num_characters)

    Args:
        input_data: List of data points, each with keys
            'vocab_size', 'non_vocab_parameters', and 'num_characters'
            (all values must be positive, since their logarithms are taken).
        group: Experimental group name; must be one of the _PARAMS keys.

    Returns:
        List of dicts with key 'unigram_normalized_loss', one per input
        data point, in the same order.

    Raises:
        ValueError: If `group` is not a key of _PARAMS, or if math.log
            receives a non-positive input value.
    """
if group not in _PARAMS:
raise ValueError(f"Unknown group '{group}'. Available groups: {list(_PARAMS.keys())}")
params = _PARAMS[group]
results = []
for d in input_data:
vs = d['vocab_size']
nv = d['non_vocab_parameters']
nc = d['num_characters']
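        # Evaluate the log-linear law; math.log is the natural logarithm (ln).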
loss = (
params['theta0']
+ params['alpha'] * math.log(vs)
+ params['beta'] * math.log(nv)
+ params['gamma'] * math.log(nc)
)
results.append({'unigram_normalized_loss': loss})
return results
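

if __name__ == '__main__':
    # Minimal usage sketch for the 'all_data' group. The input values below
    # are illustrative placeholders, not data from any actual experiment.
    demo_points = [
        {'vocab_size': 32_000.0, 'non_vocab_parameters': 1.0e9,
         'num_characters': 2.0e11},
        {'vocab_size': 64_000.0, 'non_vocab_parameters': 3.0e9,
         'num_characters': 5.0e11},
    ]
    for point, prediction in zip(demo_points, law(demo_points, 'all_data')):
        print(point, '->', prediction)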