from typing import List, Dict
# Fitted scaling-law coefficients, keyed by experimental group name.
#
# Every group uses the same multiplicative power-law form:
#
#   lm_loss = a
#             * lr ** beta_lr
#             * bsz ** beta_bsz
#             * data_size ** beta_data
#             * non_embedding_param_size ** beta_param
GROUP_COEFFS: Dict[str, Dict[str, float]] = {
    'all_data': {
        'a': 21.7693221886775,  # multiplicative prefactor
        'beta_lr': 0.00863691905384939,  # exponent applied to `lr`
        'beta_bsz': -0.0005162836622543786,  # exponent applied to `bsz`
        'beta_data': -0.04700957690670226,  # exponent applied to `data_size`
        # exponent applied to `non_embedding_param_size`
        'beta_param': -0.05174150134631458,
    },
}
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """Predict language-modeling loss for each data point with the fitted power law.

    The prediction for one point is

        a * lr**beta_lr * bsz**beta_bsz * data_size**beta_data
          * non_embedding_param_size**beta_param

    where the coefficients are looked up in ``GROUP_COEFFS`` for ``group``.
    The functional form is shared by all groups; only the coefficients vary.

    Args:
        input_data: One dictionary per data point. Each must provide the keys
            ``'lr'``, ``'bsz'``, ``'data_size'`` and
            ``'non_embedding_param_size'``.
        group: Name of the experimental group whose coefficients are used.

    Returns:
        A list parallel to ``input_data``; each element is a dictionary with
        the single key ``'lm_loss'`` holding the predicted loss.

    Raises:
        ValueError: If ``group`` has no entry in ``GROUP_COEFFS``.
        KeyError: If a data point lacks one of the required input keys.

    Note:
        Input values are not range-checked; they are fed straight into ``**``.
    """
    params = GROUP_COEFFS.get(group)
    if params is None:
        raise ValueError(f"Unknown group '{group}'")

    # Pull the coefficients out once; they do not change between data points.
    prefactor = params['a']
    exp_lr = params['beta_lr']
    exp_bsz = params['beta_bsz']
    exp_data = params['beta_data']
    exp_param = params['beta_param']

    def _predict(point: Dict[str, float]) -> float:
        # Keep the factors in this exact left-to-right order: floating-point
        # multiplication is not associative, so reordering could change the
        # last bits of the result.
        return (
            prefactor
            * point['lr'] ** exp_lr
            * point['bsz'] ** exp_bsz
            * point['data_size'] ** exp_data
            * point['non_embedding_param_size'] ** exp_param
        )

    return [{'lm_loss': _predict(point)} for point in input_data]