"""
Data-constrained scaling law for language-model pre-training.

The functional form is

    loss = C * params**(-a) * tokens**(-b) * unique_tokens**(-c)

where ``params``, ``tokens`` and ``unique_tokens`` are the input variables of
the same names. The coefficients ``C``, ``a``, ``b`` and ``c`` are fitted per
experimental group.
"""
# NOTE: the docstring above must stay the first statement in the file;
# placed after an import it is just a discarded string and __doc__ is None.
from typing import List, Dict

# Fitted coefficients per experimental group: group name -> {'C', 'a', 'b', 'c'}.
# 'C' is the multiplicative prefactor; 'a', 'b' and 'c' are the (positive)
# exponents applied, negated, to params, tokens and unique_tokens respectively.
_COEFFICIENTS: Dict[str, Dict[str, float]] = {
    'all_data': {
        'C': 89.03635820053499,
        'a': 0.0671315603289598,
        'b': 0.05741837292779814,
        'c': 0.02821632111651355,
    },
}
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict ``loss`` for each data point with the group's fitted scaling law.

    The prediction is ``C * params**(-a) * tokens**(-b) * unique_tokens**(-c)``,
    with ``C``, ``a``, ``b`` and ``c`` taken from ``_COEFFICIENTS[group]``.

    Args:
        input_data: A list of dictionaries, one per data point, each mapping the
            keys ``'params'``, ``'tokens'`` and ``'unique_tokens'`` to strictly
            positive numbers. Any other keys are ignored.
        group: The name of the experimental group for which to make predictions.
            The functional form of the law is the same for all groups,
            but the constant parameters/coefficients differ per group.

    Returns:
        A list of dictionaries in the same order as ``input_data``, each of the
        form ``{'loss': predicted_loss}``. An empty ``input_data`` yields ``[]``.

    Raises:
        ValueError: If ``group`` has no fitted coefficients, or if any of the
            three input values of a data point is not strictly positive.
        KeyError: If a data point lacks one of the required keys (or maps it
            to ``None``).
    """
    if group not in _COEFFICIENTS:
        raise ValueError(f"Unknown group: {group}")
    coeffs = _COEFFICIENTS[group]
    C = coeffs['C']
    a = coeffs['a']
    b = coeffs['b']
    c = coeffs['c']

    predictions: List[Dict[str, float]] = []
    for entry in input_data:
        p = entry.get('params')
        t = entry.get('tokens')
        u = entry.get('unique_tokens')
        if p is None or t is None or u is None:
            raise KeyError("Input data must contain 'params', 'tokens', and 'unique_tokens'.")
        # In Python a negative base raised to a non-integer power yields a
        # complex number, and zero raised to a negative power raises
        # ZeroDivisionError. Neither is a meaningful loss, so reject
        # non-positive inputs explicitly with a clear message.
        for name, value in (('params', p), ('tokens', t), ('unique_tokens', u)):
            if value <= 0:
                raise ValueError(f"'{name}' must be strictly positive, got {value!r}.")
        loss_pred = C * (p ** (-a)) * (t ** (-b)) * (u ** (-c))
        predictions.append({'loss': loss_pred})
    return predictions