import numpy as np
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Predict validation loss from expert count and dense parameter count.

    The law has the same functional form for every group; only the fitted
    constants differ per group::

        loss_validation = A * num_experts**(-alpha)
                            * dense_parameter_count**(-beta) + L0

    Args:
        input_data: One dictionary per data point. Each must provide the keys
            ``'num_experts'`` and ``'dense_parameter_count'``; any other keys
            are ignored.
        group: Name of the experimental group whose fitted constants should
            be used. Only ``'all_data'`` is currently defined.

    Returns:
        A list parallel to ``input_data``; each element is a dictionary with
        the single key ``'loss_validation'``.

    Raises:
        ValueError: If ``group`` has no fitted parameters.
        KeyError: If a data point lacks one of the required input keys.

    Note:
        Both inputs are raised to negative fractional powers, so they are
        expected to be strictly positive. A zero input raises
        ``ZeroDivisionError``, and a negative input yields a complex result.
    """
    # Fitted constants, keyed by group name. Supporting another group only
    # requires adding an entry with the same four keys.
    params = {
        'all_data': {
            'A': 43.47573060740084,
            'alpha': 0.07398280187051419,
            'beta': 0.19898571805319756,
            'L0': 1.6170183728670307,
        },
    }

    # Validate the group before touching the data so that an unknown group
    # is reported even when input_data is empty.
    if group not in params:
        raise ValueError(f"Unknown group: {group}")

    fitted = params[group]
    A = fitted['A']
    alpha = fitted['alpha']
    beta = fitted['beta']
    L0 = fitted['L0']

    # Multiplicative power law in both inputs plus an irreducible-loss offset.
    return [
        {
            'loss_validation': A
            * (point['num_experts'] ** (-alpha))
            * (point['dense_parameter_count'] ** (-beta))
            + L0
        }
        for point in input_data
    ]