def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict per-domain validation losses from per-domain training-data proportions.

    The functional form of the law is the same for every group; only the fitted
    constants differ per group. For each domain i with proportion p_i:

        loss_i = L_0_i                                if p_i == 0
        loss_i = L_inf_i + C_i * p_i ** (-alpha_i)    if p_i > 0

    Args:
        input_data: A list of data points. Each point is a dict that may contain
            the keys ``proportion_domain_1`` ... ``proportion_domain_5``. A
            domain whose proportion key is absent is skipped, so no prediction
            is emitted for it.
        group: The experimental group whose fitted constants are used. It must
            be one of ``'70M'``, ``'160M'``, ``'305M'``, ``'410M'``.

    Returns:
        A list parallel to ``input_data``. Element k maps ``loss_domain_i`` to
        the predicted loss for every ``proportion_domain_i`` present in point k.

    Raises:
        KeyError: If ``group`` is not one of the fitted groups.
        ValueError: If a supplied proportion is negative. Without this check,
            the power law would return a complex number instead of a float.
    """
    # Fitted parameters for each group and domain.
    # Structure: params[group][domain] = {'L_0': ..., 'L_inf': ..., 'C': ..., 'alpha': ...}
    # NOTE(review): several L_inf values are exactly 0.100000, which looks like a
    # lower bound reached during fitting rather than a free estimate. Confirm.
    params: dict[str, dict[int, dict[str, float]]] = {
        '70M': {
            1: {'L_0': 3.414908, 'L_inf': 0.100000, 'C': 2.447942, 'alpha': 0.055697},
            2: {'L_0': 3.818429, 'L_inf': 3.293314, 'C': 0.289679, 'alpha': 0.128688},
            3: {'L_0': 3.600640, 'L_inf': 0.100000, 'C': 2.809106, 'alpha': 0.031725},
            4: {'L_0': 2.266335, 'L_inf': 1.184513, 'C': 0.285163, 'alpha': 0.269767},
            5: {'L_0': 3.931742, 'L_inf': 2.131576, 'C': 1.306624, 'alpha': 0.087635},
        },
        '160M': {
            1: {'L_0': 3.060407, 'L_inf': 0.100000, 'C': 2.162561, 'alpha': 0.055281},
            2: {'L_0': 3.472137, 'L_inf': 2.968259, 'C': 0.301832, 'alpha': 0.095188},
            3: {'L_0': 3.285555, 'L_inf': 0.100000, 'C': 2.538393, 'alpha': 0.031901},
            4: {'L_0': 1.963058, 'L_inf': 0.100000, 'C': 1.118286, 'alpha': 0.089544},
            5: {'L_0': 3.594913, 'L_inf': 0.100000, 'C': 2.985780, 'alpha': 0.041148},
        },
        '305M': {
            1: {'L_0': 2.898031, 'L_inf': 0.100000, 'C': 2.025209, 'alpha': 0.056080},
            2: {'L_0': 3.306184, 'L_inf': 2.808446, 'C': 0.299215, 'alpha': 0.101365},
            3: {'L_0': 3.155623, 'L_inf': 0.100000, 'C': 2.414588, 'alpha': 0.030772},
            4: {'L_0': 1.832974, 'L_inf': 0.100000, 'C': 1.022308, 'alpha': 0.092419},
            5: {'L_0': 3.434413, 'L_inf': 0.100000, 'C': 2.819827, 'alpha': 0.042872},
        },
        '410M': {
            1: {'L_0': 2.831881, 'L_inf': 0.100000, 'C': 1.971811, 'alpha': 0.055358},
            2: {'L_0': 3.230276, 'L_inf': 2.748866, 'C': 0.304027, 'alpha': 0.075641},
            3: {'L_0': 3.098252, 'L_inf': 0.100000, 'C': 2.348189, 'alpha': 0.031759},
            4: {'L_0': 1.779367, 'L_inf': 0.524991, 'C': 0.569721, 'alpha': 0.146568},
            5: {'L_0': 3.374611, 'L_inf': 0.100000, 'C': 2.743305, 'alpha': 0.044620},
        },
    }

    # Resolve the group once, rather than once per domain per data point. The
    # exception type stays KeyError, but the message names the valid groups.
    try:
        group_params = params[group]
    except KeyError:
        raise KeyError(
            f"Unknown group {group!r}; expected one of {sorted(params)}"
        ) from None

    def _predict_loss(proportion: float, domain: int, coeffs: dict[str, float]) -> float:
        """
        Return the predicted loss for one domain, given its training proportion.

        - L_0 is the loss when the domain is absent from training (p == 0).
        - For p > 0 the loss is L_inf + C * p**(-alpha). It diverges as p -> 0+
          and equals L_inf + C at p == 1. L_inf is the formal p -> infinity limit.
        - alpha controls how fast the loss falls as the proportion grows.
        """
        if proportion < 0:
            # A negative float base with a fractional exponent yields a complex
            # number in Python, which would silently violate the float return type.
            raise ValueError(
                f"proportion_domain_{domain} must be non-negative, got {proportion!r}"
            )
        if proportion == 0:
            return coeffs['L_0']
        return coeffs['L_inf'] + coeffs['C'] * (proportion ** (-coeffs['alpha']))

    results: list[dict[str, float]] = []
    for data_point in input_data:
        predictions: dict[str, float] = {}
        # Iterate over the fitted domains (1..5, in insertion order) instead of
        # hard-coding the range, so the table and the loop cannot drift apart.
        for domain, coeffs in group_params.items():
            proportion_key = f'proportion_domain_{domain}'
            if proportion_key in data_point:
                predictions[f'loss_domain_{domain}'] = _predict_loss(
                    data_point[proportion_key], domain, coeffs
                )
        results.append(predictions)
    return results