import numpy as np


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values.
group: The name of the experimental group for which to make predictions.
The functional form of the law must be the same for all groups,
but the constant parameters/coefficients can differ per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s).
"""
    # Model parameters (fitted on the 'all_data' group, the only group in the
    # training data). The parameters are specific to each group, but the
    # functional form is shared: a degree-2 polynomial in log-space,
    #   loss = intercept + c1*log(N) + c2*log(V) + c3*log(D)
    #        + c4*log²(N) + c5*log(N)*log(V) + c6*log(N)*log(D)
    #        + c7*log²(V) + c8*log(V)*log(D) + c9*log²(D)
    # (see the fitting sketch after this function).
    params = {
        'all_data': {
            'intercept': 43.653023,
            'c1': 0.584601,    # log(N)
            'c2': 0.779496,    # log(V)
            'c3': -4.504395,   # log(D)
            'c4': 0.025814,    # log²(N)
            'c5': 0.022593,    # log(N)*log(V)
            'c6': -0.081356,   # log(N)*log(D)
            'c7': 0.028554,    # log²(V)
            'c8': -0.073865,   # log(V)*log(D)
            'c9': 0.137360,    # log²(D)
        }
    }
    # Look up parameters for the requested group, defaulting to 'all_data'
    # when the group is not found.
    group_params = params.get(group, params['all_data'])

    # Extract coefficients
    intercept = group_params['intercept']
    c1, c2, c3 = group_params['c1'], group_params['c2'], group_params['c3']
    c4, c5, c6 = group_params['c4'], group_params['c5'], group_params['c6']
    c7, c8, c9 = group_params['c7'], group_params['c8'], group_params['c9']
    # Prepare output
    results = []
    for data_point in input_data:
        # Extract input variables
        N = data_point['non_vocab_parameters']  # non-vocabulary parameters
        V = data_point['vocab_size']            # vocabulary size
        D = data_point['num_characters']        # characters of training data

        # Move to log-space, matching the fitted functional form
        log_N = np.log(N)
        log_V = np.log(V)
        log_D = np.log(D)

        # Apply the degree-2 polynomial scaling law
        predicted_loss = (
            intercept
            + c1 * log_N
            + c2 * log_V
            + c3 * log_D
            + c4 * log_N**2
            + c5 * log_N * log_V
            + c6 * log_N * log_D
            + c7 * log_V**2
            + c8 * log_V * log_D
            + c9 * log_D**2
        )
        # Collect the predicted output; cast the numpy scalar to a plain
        # built-in float to match the declared return type exactly.
        results.append({
            'unigram_normalized_loss': float(predicted_loss)
        })

    return results
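

# A minimal sketch of how coefficients of this form could be fitted. This is
# an assumption about the original procedure (ordinary least squares on a
# log-transformed design matrix); the function name and argument layout are
# hypothetical and are not part of the discovered law above.
def fit_log_poly2(N, V, D, loss):
    """Fit loss ~ degree-2 polynomial in (log N, log V, log D) by least squares."""
    log_N, log_V, log_D = np.log(N), np.log(V), np.log(D)
    X = np.stack([
        np.ones_like(log_N),                      # intercept
        log_N, log_V, log_D,                      # linear terms (c1..c3)
        log_N**2, log_N * log_V, log_N * log_D,   # quadratic terms with log(N) (c4..c6)
        log_V**2, log_V * log_D,                  # quadratic terms with log(V) (c7, c8)
        log_D**2,                                 # quadratic term in log(D) (c9)
    ], axis=1)
    coeffs, *_ = np.linalg.lstsq(X, np.asarray(loss, dtype=float), rcond=None)
    return coeffs  # ordered as intercept, c1, ..., c9 above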
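

# A minimal usage sketch. The input values below are hypothetical, chosen only
# to illustrate the expected call shape; they are not drawn from the fitted
# dataset.
if __name__ == "__main__":
    demo_points = [
        {'non_vocab_parameters': 1.0e8, 'vocab_size': 32_000.0, 'num_characters': 1.0e10},
        {'non_vocab_parameters': 1.0e9, 'vocab_size': 50_000.0, 'num_characters': 1.0e11},
    ]
    for point, prediction in zip(demo_points, law(demo_points, 'all_data')):
        print(point, '->', prediction)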