# Fitted coefficients (a, b, c) of the quadratic law
#     brier_score = a * log_flops**2 + b * log_flops + c
# one entry per experimental group. The functional form is shared across groups;
# only the coefficients differ. Defined once at module level so the table is not
# rebuilt on every call.
_GROUP_COEFFICIENTS: dict[str, tuple[float, float, float]] = {
    "mmlu": (0.01147626428052324, -0.06297043488789657, -0.4803646502198346),
    "parsinlu_qa_mc": (-0.05656739537407177, 0.09890583732640087, -0.4349507180682014),
    "arithmetic": (-0.12997814962868393, 0.23537009797522834, -0.2475326777122077),
    "hindu_knowledge": (-0.03440238896008072, -0.031143510554884634, -0.4103174193780912),
    "analogical_similarity": (-0.01917587967269841, 0.027911287483472436, -0.5405750537735577),
    "conceptual_combinations": (-0.07148356706471519, 0.096925955228611, -0.4093455431314182),
    "hellaswag": (-0.03367064575568235, 0.09805145434945435, -0.06719686154646046),
    "arc": (-0.036868206393668744, 0.11761949039897288, -0.10711223271542948),
    "abstract_narrative_understanding": (-0.0010020957189679732, 0.18472699005645862, -0.5431407140744655),
}


def law(
    input_data: list[dict[str, float]], group: str
) -> "list[dict[str, float | None]]":
    """
    Predict ``brier_score`` from ``log_flops`` using a per-group quadratic law.

    The law is ``brier_score = a * log_flops**2 + b * log_flops + c``. The
    coefficients ``(a, b, c)`` come from ``_GROUP_COEFFICIENTS[group]``.

    Args:
        input_data: Data points, each a dict that should contain a
            ``'log_flops'`` entry.
        group: Name of the experimental group whose coefficients to use.

    Returns:
        One dict per element of ``input_data``, in the same order, of the form
        ``{'brier_score': value}``. ``value`` is ``None`` in two cases: when
        ``group`` is unknown (all entries), or when a data point has no
        ``'log_flops'`` key or has it set to ``None`` (that entry only).

    Warns:
        RuntimeWarning: If ``group`` has no fitted coefficients.
    """
    import warnings

    coeffs = _GROUP_COEFFICIENTS.get(group)
    if coeffs is None:
        # No extrapolation is possible for an unseen group. Keep the
        # best-effort contract (None per point) but emit a filterable warning
        # instead of printing to stdout.
        warnings.warn(
            f"Group '{group}' not found in fitted coefficients. "
            "Returning None for brier_score.",
            RuntimeWarning,
            stacklevel=2,
        )
        return [{"brier_score": None} for _ in input_data]

    a, b, c = coeffs
    predictions: "list[dict[str, float | None]]" = []
    for data_point in input_data:
        log_flops = data_point.get("log_flops")
        if log_flops is None:
            # Missing input: return a placeholder so that output stays aligned
            # with input_data.
            predictions.append({"brier_score": None})
            continue
        predictions.append({"brier_score": a * (log_flops ** 2) + b * log_flops + c})
    return predictions