import numpy as np
# Input keys whose natural logs are the features of the fitted law, in the
# order used to index the coefficient tables below.
_FEATURES = ("lr", "bsz", "data_size", "non_embedding_param_size")

# Coefficients of the degree-2 polynomial fit in log-space
# (derived from linear regression on log-transformed features).
_INTERCEPT = 4.074148228884797

# Linear coefficients, index-aligned with _FEATURES.
_LINEAR = (
    0.013795306610031,  # log_lr
    0.139224299881115,  # log_bsz
    -0.243356715515550,  # log_data_size
    0.043573332851390,  # log_non_embedding_param_size
)

# Quadratic coefficients keyed by feature-index pairs (i, j) with i <= j.
# Dict insertion order is preserved, so terms are always accumulated in this
# fixed order.
_QUADRATIC = {
    (0, 0): 0.011119851824430,  # log_lr^2
    (0, 1): -0.006260814764152,  # log_lr * log_bsz
    (0, 2): -0.001395292150337,  # log_lr * log_data_size
    (0, 3): 0.010231103653946,  # log_lr * log_non_embedding_param_size
    (1, 1): 0.009278590376023,  # log_bsz^2
    (1, 2): -0.008906902516425,  # log_bsz * log_data_size
    (1, 3): -0.003417998007062,  # log_bsz * log_non_embedding_param_size
    (2, 2): 0.008885626075669,  # log_data_size^2
    (2, 3): -0.009360216068387,  # log_data_size * log_non_embedding_param_size
    (3, 3): 0.005268771454322,  # log_non_embedding_param_size^2
}


def _log_features(data_point: dict[str, float]) -> list[float]:
    """Return the natural logs of the `_FEATURES` values of one data point.

    Raises:
        KeyError: if a required input key is missing from ``data_point``.
        ValueError: if a required input is not strictly positive (or is NaN).
    """
    logs = []
    for name in _FEATURES:
        value = data_point[name]
        # np.log only warns and returns -inf for 0 and nan for negatives/NaN,
        # which would silently propagate into a meaningless lm_loss.
        if not value > 0:
            raise ValueError(f"{name!r} must be strictly positive, got {value!r}")
        logs.append(np.log(value))
    return logs


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict ``lm_loss`` for each data point using a fitted scaling law.

    The law is a full degree-2 polynomial in the natural logs of the inputs::

        log(lm_loss) = intercept
                       + sum_i        a_i  * x_i
                       + sum_{i <= j} b_ij * x_i * x_j

    where ``x = (log lr, log bsz, log data_size, log non_embedding_param_size)``.
    The predicted loss is ``exp(log(lm_loss))``.

    Args:
        input_data: A list of dictionaries, one per data point. Each must contain
            the keys ``'lr'``, ``'bsz'``, ``'data_size'`` and
            ``'non_embedding_param_size'`` with strictly positive values.
        group: The name of the experimental group. The functional form is shared
            by all groups. Currently a single set of coefficients is used for
            every group, so this argument does not change the result.

    Returns:
        A list of dictionaries in the same order as ``input_data``. Each has the
        single key ``'lm_loss'`` mapped to a Python float.

    Raises:
        KeyError: if a data point lacks one of the required input keys.
        ValueError: if a required input is not strictly positive (its log would
            be -inf or NaN).
    """
    results = []
    for data_point in input_data:
        x = _log_features(data_point)
        log_lm_loss = _INTERCEPT
        # Accumulate sequentially in a fixed order: linear terms, then quadratic.
        for coef, xi in zip(_LINEAR, x):
            log_lm_loss += coef * xi
        for (i, j), coef in _QUADRATIC.items():
            log_lm_loss += coef * (x[i] * x[j])
        # Transform back from log-space to the original loss scale.
        results.append({"lm_loss": float(np.exp(log_lm_loss))})
    return results