import math
from typing import List, Dict
# Fitted coefficients of the scaling law, keyed by experimental group name.
# Per group: 'a' is the multiplicative scale factor, 'p' the exponent applied
# to num_experts, and 'q' the exponent applied to the dense parameter count.
_COEFS: Dict[str, Dict[str, float]] = dict(
    all_data=dict(
        a=10.069179203081296,
        p=-0.0274019225275534,
        q=-0.0705124860796179,
    ),
)
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict validation loss with the power law ``loss = a * E**p * D**q``.

    ``E`` is a data point's ``'num_experts'`` value and ``D`` its
    ``'dense_parameter_count'`` value; ``a``, ``p`` and ``q`` are the fitted
    coefficients stored in ``_COEFS`` for ``group``. The functional form is
    the same for every group; only the coefficients differ.

    Args:
        input_data: A list of dictionaries, one per data point, mapping input
            variable names to their values. Each dictionary must contain
            ``'num_experts'`` and ``'dense_parameter_count'``.
        group: Name of the experimental group whose coefficients to use.

    Returns:
        A list of dictionaries in the same order as ``input_data``, each
        holding the predicted value under the key ``'loss_validation'``.

    Raises:
        ValueError: If ``group`` is unknown, or if a data point has a
            non-positive ``'num_experts'`` or ``'dense_parameter_count'``.
        KeyError: If a data point lacks one of the required input variables.
    """
    if group not in _COEFS:
        raise ValueError(f"Unknown group '{group}'. Available groups: {list(_COEFS.keys())}")
    params = _COEFS[group]
    a, p, q = params['a'], params['p'], params['q']

    predictions: List[Dict[str, float]] = []
    for index, datum in enumerate(input_data):
        # Subscript rather than .get(): a missing variable should raise a
        # KeyError naming the key, not a confusing TypeError from None ** p.
        E = float(datum['num_experts'])
        D = float(datum['dense_parameter_count'])
        # With fractional exponents a negative base silently produces a
        # complex number, and 0.0 ** (negative) raises ZeroDivisionError;
        # reject non-positive counts up front with a clear message.
        # (NaN compares False here, so it still propagates as NaN.)
        if E <= 0 or D <= 0:
            raise ValueError(
                f"input_data[{index}]: 'num_experts' and 'dense_parameter_count' "
                f"must be positive, got {E!r} and {D!r}"
            )
        # scaling law: loss = a * E^p * D^q
        predictions.append({'loss_validation': a * (E ** p) * (D ** q)})
    return predictions