import math
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict loss from model size and parallel size with a per-group power law.

    The functional form is identical for every group; only the constants
    differ::

        loss = K * num_params ** a * parallel_size ** b

    Args:
        input_data: A list of dictionaries, one per data point. Each must
            provide the keys ``"num_params"`` and ``"parallel_size"`` with
            strictly positive values.
        group: The name of the experimental group whose fitted constants
            (``K``, ``a``, ``b``) should be used.

    Returns:
        A list parallel to ``input_data``; each element is a dictionary of
        the form ``{"loss": <predicted value>}``. An empty ``input_data``
        yields an empty list.

    Raises:
        ValueError: If ``group`` is unknown, if a data point lacks (or has
            ``None`` for) ``"num_params"`` or ``"parallel_size"``, or if
            either value is not strictly positive (NaN is rejected too).
    """
    # Fitted constants per experimental group.
    group_parameters = {
        "stack": {
            "K": 4.4301,
            "a": -0.0664,
            "b": -0.0221,
        },
        "pile": {
            "K": 7.6843,
            "a": -0.0645,
            "b": -0.0189,
        },
    }

    if group not in group_parameters:
        raise ValueError(f"Unknown group: {group}. Supported groups are {list(group_parameters.keys())}")

    params = group_parameters[group]
    K = params["K"]
    a = params["a"]
    b = params["b"]

    predictions = []
    for data_point in input_data:
        num_params = data_point.get("num_params")
        parallel_size = data_point.get("parallel_size")

        if num_params is None or parallel_size is None:
            raise ValueError("Input data point must contain 'num_params' and 'parallel_size'.")

        # The law has negative, fractional exponents, so both inputs must be
        # strictly positive. The check is phrased as "not (x > 0 and y > 0)"
        # rather than "x <= 0 or y <= 0" so that NaN, for which every ordered
        # comparison is False, is rejected instead of silently slipping
        # through and producing a NaN prediction.
        if not (num_params > 0 and parallel_size > 0):
            raise ValueError("Input values 'num_params' and 'parallel_size' must be positive.")

        predicted_loss = K * (num_params ** a) * (parallel_size ** b)
        predictions.append({"loss": predicted_loss})

    return predictions