import numpy as np


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
            point containing input variable names as keys and their
            corresponding values.
        group: The name of the experimental group for which to make predictions.
            The functional form of the law must be the same for all groups,
            but the constant parameters/coefficients can differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    # Group-specific parameters.
    # These were fitted using differential evolution on the training data.
    group_params = {
        'all_data': {
            'E': 1.50388274e+00,
            'A': 4.79640622e+01,
            'alpha': 2.31705082e-01,
            'B': 4.96341790e+01,
            'beta': 1.87783075e-01,
            'C': 2.68264892e-02,
            'lr_scale': 9.70220088e-03,
            'lr_exp': -1.04437884e-01,
            'F': 5.62388220e-03,
            'bsz_scale': 4.05875928e-09,
            'bsz_exp': -4.83464543e-01,
        },
    }
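    # Additional groups could be added here with their own fitted constants;
    # the functional form of the law below stays the same across all groups.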
    # Get parameters for the specified group, falling back to 'all_data'
    # if the requested group is not found.
    params = group_params.get(group, group_params['all_data'])
    # Extract parameters
    E = params['E']
    A = params['A']
    alpha = params['alpha']
    B = params['B']
    beta = params['beta']
    C = params['C']
    lr_scale = params['lr_scale']
    lr_exp = params['lr_exp']
    F = params['F']
    bsz_scale = params['bsz_scale']
    bsz_exp = params['bsz_exp']
    # Make predictions for each data point
    results = []
    for data_point in input_data:
        # Extract input features
        N = data_point['non_embedding_param_size']
        D = data_point['data_size']
        lr = data_point['lr']
        bsz = data_point['bsz']

        # Optimal learning rate (scales with model size)
        lr_opt = lr_scale * (N ** lr_exp)
        # Optimal batch size (scales with model size)
        bsz_opt = bsz_scale * (N ** bsz_exp)

        # Compute the scaling law:
        #   L = E + A/N^alpha + B/D^beta + C*(log(lr/lr_opt))^2 + F*log(bsz_opt/bsz)

        # Base (irreducible) loss
        base_loss = E
        # Model size term (larger models achieve lower loss)
        model_term = A / (N ** alpha)
        # Data size term (more data achieves lower loss)
        data_term = B / (D ** beta)
        # Learning rate penalty (quadratic in log space; penalizes deviation from the optimum)
        lr_penalty = C * ((np.log(lr) - np.log(lr_opt)) ** 2)
        # Batch size effect (logarithmic penalty for a suboptimal batch size)
        bsz_effect = F * np.log(bsz_opt / bsz)

        # Total predicted loss
        lm_loss = base_loss + model_term + data_term + lr_penalty + bsz_effect
        results.append({'lm_loss': float(lm_loss)})

    return results
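

# Minimal usage sketch. The input values below are purely illustrative
# (not drawn from any real training run) and only demonstrate the expected
# dictionary keys: non_embedding_param_size, data_size, lr, bsz.
if __name__ == "__main__":
    example_input = [
        {
            'non_embedding_param_size': 1.0e8,  # N: non-embedding parameter count
            'data_size': 2.0e9,                 # D: training data size
            'lr': 3.0e-4,                       # learning rate
            'bsz': 256,                         # batch size (units as defined by the dataset)
        },
    ]
    predictions = law(example_input, group='all_data')
    print(predictions)  # -> [{'lm_loss': <predicted loss>}]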