# Group used when the requested group has no fitted coefficients of its own.
_DEFAULT_GROUP = 'all_data'

# Fitted coefficients for
#   loss = a + b/params^alpha + c/tokens^beta + d/unique_tokens^gamma
# obtained by fitting the model to the experimental dataset. Defined once at
# module level so the table is not rebuilt on every call; treat as read-only.
_PARAMS_BY_GROUP: dict[str, dict[str, float]] = {
    'all_data': {
        'a': 1.854174103052296,
        'b': 5185.897600342049,
        'alpha': 0.5065474635986845,
        'c': 108445.00928304848,
        'beta': 0.5635676108042057,
        'd': 14.148203751260953,
        'gamma': 0.12921116039317365,
    },
}

# Floor applied to every input before exponentiation: avoids division by zero
# and keeps the base positive so the fractional powers stay real.
_MIN_INPUT = 1e-10


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Predict loss from model/data size using a fitted scaling law.

    The law is::

        loss = a + b/params^alpha + c/tokens^beta + d/unique_tokens^gamma

    The functional form is shared by all groups; only the coefficients
    may differ per group.

    Args:
        input_data: One dict per data point, with keys 'params', 'tokens'
            and 'unique_tokens'. A missing key is treated as 1.0, and any
            value below 1e-10 (including zero or negatives) is clamped to
            1e-10 before the law is evaluated.
        group: Name of the experimental group whose coefficients to use.
            An unknown group, None, or '' falls back to the 'all_data'
            coefficients.

    Returns:
        A list parallel to ``input_data``; each element is ``{'loss': value}``.
    """
    # Unknown, None and empty group names all share the default coefficients.
    coeffs = _PARAMS_BY_GROUP.get(group, _PARAMS_BY_GROUP[_DEFAULT_GROUP])
    a = coeffs['a']
    b = coeffs['b']
    alpha = coeffs['alpha']
    c = coeffs['c']
    beta = coeffs['beta']
    d = coeffs['d']
    gamma = coeffs['gamma']

    results = []
    for data_point in input_data:
        params_val = max(data_point.get('params', 1.0), _MIN_INPUT)
        tokens_val = max(data_point.get('tokens', 1.0), _MIN_INPUT)
        unique_tokens_val = max(data_point.get('unique_tokens', 1.0), _MIN_INPUT)
        loss = (
            a
            + b / (params_val ** alpha)
            + c / (tokens_val ** beta)
            + d / (unique_tokens_val ** gamma)
        )
        results.append({'loss': loss})
    return results