"""
Scaling law prediction function for model parallelism experiments.
"""
from typing import List, Dict
# Pre-fitted coefficients for each experimental group, keyed by group name.
# Each entry parameterizes the power law evaluated by `law`:
#     loss = A * num_params ** alpha * parallel_size ** beta
# 'A' is the multiplicative scale; 'alpha' and 'beta' are the exponents on
# num_params and parallel_size respectively. Both exponents are negative for
# every group below, so (for positive inputs) the predicted loss decreases as
# either num_params or parallel_size grows.
_PARAMS_MAPPING: Dict[str, Dict[str, float]] = {
    'pile': {
        'A': 7.684255507790626,
        'alpha': -0.06451510348874577,
        'beta': -0.018877634923324237,
    },
    'stack': {
        'A': 4.430089002263504,
        'alpha': -0.0663711422624959,
        'beta': -0.02210924328367317,
    },
}
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    The law is a multiplicative power law in model size and parallel size::

        loss = A * num_params ** alpha * parallel_size ** beta

    where ``A``, ``alpha`` and ``beta`` are the pre-fitted coefficients for
    ``group``.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values. Each entry must provide the keys
                    'num_params' and 'parallel_size', both strictly positive.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law is the same for all groups,
               but the constant parameters/coefficients differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s) (key 'loss').
        An empty ``input_data`` yields an empty list.

    Raises:
        ValueError: If ``group`` has no fitted coefficients, or if an entry's
            'num_params' or 'parallel_size' is not strictly positive.
        KeyError: If an entry lacks 'num_params' or 'parallel_size'.
    """
    if group not in _PARAMS_MAPPING:
        raise ValueError(f"Unknown group: {group}")
    params = _PARAMS_MAPPING[group]
    A = params['A']
    alpha = params['alpha']
    beta = params['beta']

    predictions: List[Dict[str, float]] = []
    for index, entry in enumerate(input_data):
        num_params = entry['num_params']
        parallel_size = entry['parallel_size']
        # The fitted exponents are negative and non-integer: a base of 0 would
        # raise an uninformative ZeroDivisionError, and a negative base would
        # silently produce a complex number instead of a real loss.
        if num_params <= 0 or parallel_size <= 0:
            raise ValueError(
                f"Entry {index}: 'num_params' and 'parallel_size' must be "
                f"positive, got num_params={num_params!r}, "
                f"parallel_size={parallel_size!r}"
            )
        loss_pred = A * (num_params ** alpha) * (parallel_size ** beta)
        predictions.append({'loss': loss_pred})
    return predictions