def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict the loss for each data point with a per-group fitted scaling law.

    The functional form is shared by every group; only the constants differ:

        loss = a * N^(-b) + c * P^(-d)

    where N is the data point's ``num_params`` and P is its ``parallel_size``.

    Args:
        input_data: A list of dictionaries, one per data point. Each must
            contain the keys ``'num_params'`` and ``'parallel_size'``.
        group: The name of the experimental group whose coefficients are
            used (``'stack'`` or ``'pile'``).

    Returns:
        A list of dictionaries, in the same order as ``input_data``, each of
        the form ``{'loss': <predicted value>}``.

    Raises:
        ValueError: If ``group`` is not a known group, or if a data point has
            a negative ``num_params`` or ``parallel_size``.
        KeyError: If a data point lacks ``'num_params'`` or ``'parallel_size'``.
        ZeroDivisionError: If ``num_params`` or ``parallel_size`` is a plain
            Python zero (zero raised to a negative power).
    """
    # Fitted coefficients (a, b, c, d) of the scaling law, keyed by group.
    params = {
        'stack': {
            'a': 77.1529985547,
            'b': 0.2687228347,
            'c': 0.8221758142,
            'd': 0.0296937458,
        },
        'pile': {
            'a': 111.9689899826,
            'b': 0.2576609995,
            'c': 1.4757763183,
            'd': 0.0254075589,
        },
    }

    if group not in params:
        raise ValueError(f"Unknown group: {group}. Available groups: {list(params.keys())}")

    coeffs = params[group]
    a = coeffs['a']
    b = coeffs['b']
    c = coeffs['c']
    d = coeffs['d']

    predictions = []
    for data_point in input_data:
        N = data_point['num_params']
        P = data_point['parallel_size']

        # A negative base with a fractional exponent makes Python return a
        # complex number, which would silently leak out as a non-float "loss".
        # Reject it explicitly instead.
        if N < 0 or P < 0:
            raise ValueError(
                "num_params and parallel_size must be non-negative, "
                f"got num_params={N}, parallel_size={P}"
            )

        loss = a * (N ** (-b)) + c * (P ** (-d))
        predictions.append({'loss': loss})

    return predictions