import numpy as np


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered
    scaling law.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
            point containing input variable names as keys and their corresponding
            values.
        group: The name of the experimental group for which to make predictions.
            The functional form of the law must be the same for all groups, but
            the constant parameters/coefficients can differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    # Fitted parameters for the 'all_data' group. These constants were derived
    # from fitting the model:
    #   loss = C1 / params^a + C2 / tokens^b + C3 / unique_tokens^c + C_offset
    parameters = {
        'all_data': {
            'C1': 132.54776896351294,
            'a': 0.26912805102623555,
            'C2': 34376.40665446305,
            'b': 0.4999504059374415,
            'C3': 17.02861960948566,
            'c': 0.15783847826401667,
            'C_offset': 1.6997369875249735,
        }
    }

    if group not in parameters:
        raise ValueError(
            f"Group '{group}' not recognized. "
            f"Available groups: {list(parameters.keys())}"
        )

    group_params = parameters[group]
    C1, a, C2, b, C3, c, C_offset = (
        group_params['C1'], group_params['a'],
        group_params['C2'], group_params['b'],
        group_params['C3'], group_params['c'],
        group_params['C_offset'],
    )

    predictions = []
    for data_point in input_data:
        params = data_point['params']
        tokens = data_point['tokens']
        unique_tokens = data_point['unique_tokens']

        # Evaluate the scaling law; np.power handles non-integer exponents on
        # floats robustly.
        predicted_loss = (
            C1 / np.power(params, a)
            + C2 / np.power(tokens, b)
            + C3 / np.power(unique_tokens, c)
            + C_offset
        )
        # Cast the NumPy scalar back to a plain float so the result matches the
        # annotated return type.
        predictions.append({'loss': float(predicted_loss)})

    return predictions
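

# A minimal sketch of how the per-group constants above could be re-fitted from
# raw (params, tokens, unique_tokens, loss) measurements. The original fitting
# procedure is not recorded in this file; scipy.optimize.curve_fit, the helper
# name fit_scaling_law, and the initial guesses below are illustrative
# assumptions, not necessarily the method actually used.
def fit_scaling_law(data_points: list[dict[str, float]]) -> dict[str, float]:
    from scipy.optimize import curve_fit

    def model(X, C1, a, C2, b, C3, c, C_offset):
        params, tokens, unique_tokens = X
        return (C1 / np.power(params, a)
                + C2 / np.power(tokens, b)
                + C3 / np.power(unique_tokens, c)
                + C_offset)

    X = np.array(
        [[p['params'], p['tokens'], p['unique_tokens']] for p in data_points]
    ).T
    y = np.array([p['loss'] for p in data_points])
    # Loose positive initial guesses; power-law fits are sensitive to these.
    p0 = [100.0, 0.3, 1000.0, 0.5, 10.0, 0.2, 1.0]
    popt, _ = curve_fit(model, X, y, p0=p0, maxfev=20000)
    return dict(zip(['C1', 'a', 'C2', 'b', 'C3', 'c', 'C_offset'], popt))


# Illustrative usage: the input values below are hypothetical and only
# demonstrate the expected call shape, not points from the fitted dataset.
if __name__ == "__main__":
    sample_inputs = [
        {'params': 1e8, 'tokens': 1e9, 'unique_tokens': 1e8},
        {'params': 1e9, 'tokens': 2e10, 'unique_tokens': 1e9},
    ]
    for point, pred in zip(sample_inputs, law(sample_inputs, 'all_data')):
        print(point, '->', pred)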