def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Evaluate a per-group quartic scaling law at each input point.

    Every group shares the same functional form,

        brier_score = a*x**4 + b*x**3 + c*x**2 + d*x + e,

    where x is read from each point's 'log_flops' entry. Only the fitted
    coefficients (a, b, c, d, e) differ from one group to another.

    Args:
        input_data: Data points, one dict per point. Each dict must provide
            a 'log_flops' value; any other keys are ignored.
        group: Name of the experimental group whose fitted coefficients
            should be used.

    Returns:
        A list with one dict per input point, in input order, each of the
        form {'brier_score': <predicted value>}.

    Raises:
        ValueError: If ``group`` is not one of the fitted groups.
        KeyError: If a data point has no 'log_flops' entry.
    """
    # Fitted coefficients per group, stored as (a, b, c, d, e), i.e. ordered
    # from the x**4 coefficient down to the constant term.
    coefficient_table = {
        'abstract_narrative_understanding': (
            0.00297357, -0.01574499, 0.01733127, 0.19073961, -0.55204005,
        ),
        'analogical_similarity': (
            -0.02879407, 0.07981207, -0.04569143, -0.02378954, -0.52241379,
        ),
        'arc': (
            0.00112476, 0.01449702, -0.07559866, 0.12072136, -0.08969112,
        ),
        'arithmetic': (
            -0.15560601, 0.40313018, -0.23428989, -0.02766946, -0.16636060,
        ),
        'conceptual_combinations': (
            -0.09769580, 0.11494576, -0.00120401, 0.02484563, -0.41118812,
        ),
        'hellaswag': (
            0.00058158, 0.01328482, -0.06662193, 0.09988958, -0.05230534,
        ),
        'hindu_knowledge': (
            0.03736588, 0.18072138, -0.07440077, -0.13412345, -0.39668110,
        ),
        'mmlu': (
            0.01953715, 0.01668490, -0.03592157, -0.07667280, -0.46762245,
        ),
        'parsinlu_qa_mc': (
            0.00604901, -0.01045242, -0.06744806, 0.11201012, -0.43112421,
        ),
    }

    coeffs = coefficient_table.get(group)
    if coeffs is None:
        raise ValueError(
            f"Unknown group: {group}. Available groups: {list(coefficient_table)}"
        )
    a, b, c, d, e = coeffs

    def evaluate(x):
        # Sum the terms individually, highest power first. A rearrangement
        # such as Horner's rule would reorder the floating-point operations
        # and could change the low-order bits of the result.
        return a * (x ** 4) + b * (x ** 3) + c * (x ** 2) + d * x + e

    return [
        {'brier_score': evaluate(point['log_flops'])}
        for point in input_data
    ]