import math
# Fitted coefficients per experimental group. The functional form is shared by
# all groups; only these constants differ. Add a new entry here to support
# another group. 'C' is the multiplicative prefactor; every other key is the
# exponent applied to the input variable of the same name.
_LAW_PARAMS: dict[str, dict[str, float]] = {
    'all_data': {
        'C': 21.603153428136885,
        'lr': 0.00932414933559713,
        'bsz': 0.000412360843847143,
        'data_size': -0.04811794508621384,
        'non_embedding_param_size': -0.05004428969823731,
    },
}

# Input variables consumed by the law, in the order their factors are multiplied.
_LAW_INPUTS: tuple[str, ...] = ('lr', 'bsz', 'data_size', 'non_embedding_param_size')


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict ``lm_loss`` for each data point using a fitted multiplicative power law.

    The law is::

        lm_loss = C * lr**a_lr * bsz**a_bsz
                  * data_size**a_data_size
                  * non_embedding_param_size**a_non_embedding_param_size

    ``C`` and the exponents ``a_*`` are looked up per ``group``.

    Args:
        input_data: A list of dictionaries. Each dictionary is one data point
            and must contain the keys 'lr', 'bsz', 'data_size' and
            'non_embedding_param_size' with strictly positive values.
        group: The name of the experimental group whose fitted constants are
            used. The functional form is the same for all groups; only the
            constants differ.

    Returns:
        A list with one dictionary per input point, in the same order.
        Each dictionary has the form ``{'lm_loss': <predicted loss>}``.

    Raises:
        ValueError: If ``group`` is unknown, or if any input value is not
            strictly positive (this includes NaN).
        KeyError: If a data point is missing one of the required inputs.
    """
    params = _LAW_PARAMS.get(group)
    if params is None:
        raise ValueError(f"Unknown group: {group}")

    predictions = []
    for data_point in input_data:
        lm_loss = params['C']
        for name in _LAW_INPUTS:
            value = data_point[name]
            # Fractional powers require a positive base. A negative base
            # silently yields a complex number. Zero either raises
            # ZeroDivisionError (negative exponent) or collapses the loss to 0.
            # `not value > 0` also rejects NaN.
            if not value > 0:
                raise ValueError(
                    f"Input '{name}' must be strictly positive, got {value!r}"
                )
            # Multiply left to right in the same order as the closed-form
            # expression, so the results match it bit for bit.
            lm_loss *= value ** params[name]
        predictions.append({'lm_loss': lm_loss})
    return predictions