def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict MoE validation loss from a fitted scaling law.

    The law has the form

        loss_validation = a * N**(-alpha) * E**(-beta) + c

    where N is ``dense_parameter_count`` and E is ``num_experts``. Every
    group shares this functional form; only the coefficients
    (a, alpha, beta, c) are group-specific.

    Args:
        input_data: Data points. Each is a dict that must contain the keys
            ``'dense_parameter_count'`` and ``'num_experts'``; any other
            keys are ignored.
        group: Name of the experimental group whose coefficients to use.
            Only ``'all_data'`` is currently defined.

    Returns:
        One dict per input point, in input order, of the form
        ``{'loss_validation': <prediction>}``.

    Raises:
        ValueError: If ``group`` has no fitted coefficients.
        KeyError: If a data point lacks one of the required keys.
    """
    # Group coefficients, obtained through nonlinear least-squares fitting.
    # Reported fit quality on the training data for 'all_data':
    # R² = 0.958, RMSE = 0.052, MAE = 0.038.
    fitted = {
        'all_data': {
            'a': 4.3475562897e+01,
            'alpha': 0.1989854424,
            'beta': 0.0739826608,
            'c': 1.6170169395,
        },
    }

    coeffs = fitted.get(group)
    if coeffs is None:
        raise ValueError(f"Unknown group: {group}. Available groups: {list(fitted)}")

    scale, n_exp, e_exp, floor = (coeffs[k] for k in ('a', 'alpha', 'beta', 'c'))

    def predict(point):
        # Both inputs are read before any arithmetic is done.
        # NOTE(review): counts are assumed positive. A zero count raises
        # ZeroDivisionError, and a negative one makes `**` return a complex number.
        n = point['dense_parameter_count']
        e = point['num_experts']
        return scale * n ** (-n_exp) * e ** (-e_exp) + floor

    return [{'loss_validation': predict(point)} for point in input_data]