def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict per-domain losses from per-domain proportions with a fitted scaling law.

    For each domain ``i`` in 1..5 the prediction is

        loss_domain_i = C + A / (proportion_domain_i + B)

    where ``(A, B, C)`` are constants fitted separately for every
    (group, domain) pair. The functional form is the same for all groups;
    only the constants differ.

    Args:
        input_data: A list of data points. Each data point is a dictionary
            that may contain the keys ``proportion_domain_1`` ..
            ``proportion_domain_5``; any other keys are ignored.
        group: The name of the experimental group whose fitted constants
            should be used (one of the keys of the table below).

    Returns:
        A list with one dictionary per input data point, in the same order.
        Each dictionary always contains the keys ``loss_domain_1`` ..
        ``loss_domain_5``. A value is NaN when no prediction can be made:
        the group is unknown, the proportion for that domain is missing, or
        ``proportion + B`` is exactly zero (the law is undefined there).
    """
    # Fitted parameters (A, B, C) for each group and domain.
    # Structure: fitted_params[group][domain_key] = {'A': ..., 'B': ..., 'C': ...}
    fitted_params = {
        '70M': {
            'domain_1': {'A': 0.0488, 'B': 0.0576, 'C': 2.5662},
            'domain_2': {'A': 0.0141, 'B': 0.0636, 'C': 3.5963},
            'domain_3': {'A': 0.0084, 'B': 0.0155, 'C': 3.0578},
            'domain_4': {'A': 0.0288, 'B': 0.0377, 'C': 1.5025},
            'domain_5': {'A': 0.1127, 'B': 0.1952, 'C': 3.3529},
        },
        '160M': {
            'domain_1': {'A': 0.0402, 'B': 0.0519, 'C': 2.2834},
            'domain_2': {'A': 0.0083, 'B': 0.0445, 'C': 3.2866},
            'domain_3': {'A': 0.0073, 'B': 0.0143, 'C': 2.7768},
            'domain_4': {'A': 0.0255, 'B': 0.0375, 'C': 1.2831},
            'domain_5': {'A': 0.1205, 'B': 0.2034, 'C': 2.9952},
        },
        '305M': {
            'domain_1': {'A': 0.0374, 'B': 0.0498, 'C': 2.1469},
            'domain_2': {'A': 0.0097, 'B': 0.0528, 'C': 3.1226},
            'domain_3': {'A': 0.0059, 'B': 0.0117, 'C': 2.6482},
            'domain_4': {'A': 0.0240, 'B': 0.0370, 'C': 1.1838},
            'domain_5': {'A': 0.1097, 'B': 0.1856, 'C': 2.8383},
        },
        '410M': {
            'domain_1': {'A': 0.0350, 'B': 0.0476, 'C': 2.0943},
            'domain_2': {'A': 0.0057, 'B': 0.0351, 'C': 3.0684},
            'domain_3': {'A': 0.0059, 'B': 0.0115, 'C': 2.5829},
            'domain_4': {'A': 0.0241, 'B': 0.0379, 'C': 1.1439},
            'domain_5': {'A': 0.1109, 'B': 0.1828, 'C': 2.7604},
        },
    }

    nan = float('nan')
    # The group is the same for every data point, so resolve it once. An
    # unknown group maps to an empty table, which makes every domain fall
    # through to the NaN branch below.
    group_params = fitted_params.get(group, {})

    predictions = []
    for data_point in input_data:
        predicted_losses = {}
        for i in range(1, 6):
            prop_key = f'proportion_domain_{i}'
            loss_key = f'loss_domain_{i}'
            params = group_params.get(f'domain_{i}')

            if params is None or prop_key not in data_point:
                # Missing proportion or no fitted constants: no prediction.
                predicted_losses[loss_key] = nan
                continue

            denominator = data_point[prop_key] + params['B']
            if denominator == 0:
                # Every fitted B is positive, so this only happens for a
                # negative proportion equal to -B. The law has a pole there;
                # report NaN for this entry instead of letting a
                # ZeroDivisionError abort the whole batch.
                predicted_losses[loss_key] = nan
                continue

            # Scaling law: Loss = C + A / (Proportion + B)
            predicted_losses[loss_key] = params['C'] + params['A'] / denominator
        predictions.append(predicted_losses)
    return predictions