def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Evaluate the fitted quadratic scaling law for one experimental group.

    Every group shares the same functional form,
    ``brier_score = a * log_flops**2 + b * log_flops + c``; only the
    coefficients ``(a, b, c)`` differ from group to group.

    Args:
        input_data: Data points to evaluate. Each dictionary must provide a
            ``'log_flops'`` entry; any other keys are ignored.
        group: Name of the experimental group whose coefficients are used.

    Returns:
        One ``{'brier_score': value}`` dictionary per input data point, in
        the same order as ``input_data``.

    Raises:
        ValueError: If ``group`` has no fitted coefficients.
        KeyError: If a data point lacks the ``'log_flops'`` key.
    """
    # Per-group (a, b, c) coefficients of the quadratic in log_flops.
    # They were fitted on the training data with scipy.optimize.curve_fit.
    group_parameters = {
        'mmlu': (0.01147626, -0.06297043, -0.48036465),
        'parsinlu_qa_mc': (-0.05656740, 0.09890584, -0.43495072),
        'arithmetic': (-0.12997815, 0.23537010, -0.24753268),
        'hindu_knowledge': (-0.03440239, -0.03114351, -0.41031742),
        'analogical_similarity': (-0.01917588, 0.02791129, -0.54057506),
        'conceptual_combinations': (-0.07148357, 0.09692596, -0.40934554),
        'hellaswag': (-0.03367065, 0.09805145, -0.06719686),
        'arc': (-0.03686821, 0.11761949, -0.10711223),
        'abstract_narrative_understanding': (-0.00100210, 0.18472699, -0.54314071),
    }

    # Reject unknown groups up front, even when there is nothing to predict.
    coefficients = group_parameters.get(group)
    if coefficients is None:
        raise ValueError(f"Unknown group: {group}. Available groups: {list(group_parameters.keys())}")
    quad, lin, const = coefficients

    # Read each point's log_flops lazily, then evaluate the quadratic.
    # The expression is kept in this exact form so results stay bit-identical.
    log_flops_values = (point['log_flops'] for point in input_data)
    return [
        {'brier_score': quad * (x ** 2) + lin * x + const}
        for x in log_flops_values
    ]