def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict loss from model and data scale using a fitted additive power law.

    The functional form is shared by every group; only the fitted constants
    differ per group:

        loss = E + A / N**alpha + B / D**beta + C / U**gamma

    where N = ``params``, D = ``tokens`` and U = ``unique_tokens``.

    Args:
        input_data: A list of dictionaries, one per data point. Each must
            provide the keys ``'params'``, ``'tokens'`` and
            ``'unique_tokens'``; any other keys are ignored. Values should
            be positive (a zero makes the corresponding term undefined).
        group: Name of the experimental group whose fitted constants are
            used. Only ``'all_data'`` is currently defined.

    Returns:
        A list parallel to ``input_data``; each element is a dictionary of
        the form ``{'loss': <predicted loss>}``.

    Raises:
        KeyError: If ``group`` has no fitted constants, or if a data point
            is missing one of the required input keys.
    """
    # Fitted constants for each group.
    # Model: L = E + A/N^alpha + B/D^beta + C/U^gamma
    parameters = {
        'all_data': {
            'E': 1.8541292226,
            'A': 5.1841032365e+03,
            'alpha': 0.5065258787,
            'B': 1.0843212340e+05,
            'beta': 0.5635613914,
            'C': 1.4148096648e+01,
            'gamma': 0.1292096864,
        },
    }

    try:
        fitted = parameters[group]
    except KeyError:
        # Still a KeyError (callers catching it keep working), but say which
        # groups actually exist instead of echoing only the bad key.
        known = ', '.join(sorted(parameters))
        raise KeyError(
            f"unknown group {group!r}; available groups: {known}"
        ) from None

    E = fitted['E']
    A = fitted['A']
    alpha = fitted['alpha']
    B = fitted['B']
    beta = fitted['beta']
    C = fitted['C']
    gamma = fitted['gamma']

    # Term order is kept as E + N-term + D-term + U-term so floating-point
    # results are unchanged.
    return [
        {
            'loss': E
            + A / (point['params'] ** alpha)
            + B / (point['tokens'] ** beta)
            + C / (point['unique_tokens'] ** gamma)
        }
        for point in input_data
    ]