def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Predict language-modeling loss from a discovered scaling law.

    The law models final loss as a function of total model parameters and the
    degree of model parallelism:

        loss = (c - delta * ln(parallel_size)) * num_params ** b

    where ``c`` and ``b`` are group-specific fitted coefficients and ``delta``
    is a universal coefficient controlling the parallelism benefit.

    Args:
        input_data: A list of dictionaries, one per data point, each with keys
            'num_params' (total model parameters) and 'parallel_size' (degree
            of parallelism: 1, 2, 4, ...). Values are coerced via float().
        group: Experimental group to predict for. Valid values: 'stack', 'pile'.

    Returns:
        A list parallel to ``input_data``; each element is a dict with the
        single predicted key 'loss'. Empty input yields an empty list.

    Raises:
        ValueError: If ``group`` is not a known group. (A math domain error
            also propagates if 'parallel_size' <= 0, since ln is undefined
            there; callers are expected to supply positive sizes.)
        KeyError: If a data point lacks 'num_params' or 'parallel_size'.
    """
    import math

    # Group-specific coefficients fitted from training data.
    params = {
        'stack': {'c': 4.639114, 'b': -0.068405},
        'pile': {'c': 7.632338, 'b': -0.064275},
    }
    # Universal coefficient for the parallelism benefit.
    delta = 0.123483

    # Fail fast on an unknown group rather than raising a KeyError below.
    if group not in params:
        raise ValueError(f"Unknown group: {group}. Valid groups are: {list(params.keys())}")

    c = params[group]['c']
    b = params[group]['b']

    # Apply the scaling law to every data point:
    # loss = (c - delta * ln(parallel_size)) * num_params^b
    return [
        {
            'loss': (c - delta * math.log(float(point['parallel_size'])))
            * float(point['num_params']) ** b
        }
        for point in input_data
    ]