import math
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict validation loss for Mixture-of-Experts (MoE) configurations.

    Implements a power law with an interaction term, evaluated in log space:

        log(loss_validation) = a * log(num_experts)
                             + b * log(dense_parameter_count)
                             + c * log(num_experts) * log(dense_parameter_count)
                             + const

    Fitted coefficients for the 'all_data' group (rounded):
        a     = -0.100962
        b     = -0.084521
        c     =  0.004064
        const =  2.563144

    Args:
        input_data: A list of dictionaries, one per data point. Each must
            contain the keys 'num_experts' and 'dense_parameter_count',
            both strictly positive (their logarithms are taken).
        group: The name of the experimental group for which to make
            predictions. The functional form is the same for all groups;
            only 'all_data' has fitted parameters, and any unknown group
            name falls back to those parameters.

    Returns:
        A list of dictionaries, in the same order as input_data, each
        containing the predicted output under the key 'loss_validation'.

    Raises:
        KeyError: If a data point lacks 'num_experts' or
            'dense_parameter_count'.
        ValueError: If either input value is zero or negative.
    """
    # Parameters extracted from the fitted model, keyed by group name.
    params = {
        'all_data': {
            'a': -0.10096209505615544,
            'b': -0.08452076489156325,
            'c': 0.004064472374478456,
            'const': 2.56314364064211,
        },
    }

    # Unknown groups deliberately fall back to the 'all_data' fit.
    group_params = params.get(group, params['all_data'])
    a = group_params['a']
    b = group_params['b']
    c = group_params['c']
    const = group_params['const']

    results = []
    for index, data_point in enumerate(input_data):
        num_experts = data_point['num_experts']
        dense_parameter_count = data_point['dense_parameter_count']

        # math.log would reject these with an opaque "math domain error";
        # fail with the same exception type but say which record is at fault.
        if num_experts <= 0 or dense_parameter_count <= 0:
            raise ValueError(
                "num_experts and dense_parameter_count must be positive; "
                f"got num_experts={num_experts!r}, "
                f"dense_parameter_count={dense_parameter_count!r} "
                f"at input_data[{index}]"
            )

        # Compute the prediction in log space.
        log_experts = math.log(num_experts)
        log_dense = math.log(dense_parameter_count)
        interaction = log_experts * log_dense
        log_loss = a * log_experts + b * log_dense + c * interaction + const

        # Convert back to linear space.
        results.append({'loss_validation': math.exp(log_loss)})
    return results