import numpy as np
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Predict output variables from inputs via a discovered scaling law.

    The law is log-linear:
        loss = A + B*log(non_vocab_parameters) + C*log(vocab_size) + D*log(num_characters)
    The functional form is shared across groups; only the fitted constants
    (A, B, C, D) may differ per group.

    Args:
        input_data: One dict per data point, mapping input variable names
            ('non_vocab_parameters', 'vocab_size', 'num_characters') to values.
        group: Name of the experimental group whose fitted constants to use.

    Returns:
        A list parallel to input_data; each entry is a dict with the single
        key 'unigram_normalized_loss' holding the predicted value.

    Raises:
        ValueError: If the group is unknown, or any input variable is not
            strictly positive (log is undefined at or below zero).
    """
    # Fitted constants per group. Only 'all_data' was present in the
    # dataset these parameters were estimated from.
    fitted_params = {
        'all_data': {
            'A': 6.380590666656606,
            'B': 0.016411077894625814,
            'C': 0.06340182538033912,
            'D': -0.501700641788903,
        }
    }

    coeffs = fitted_params.get(group)
    if coeffs is None:
        raise ValueError(f"Parameters for group '{group}' not found.")

    # Hoist the coefficients once; they are invariant across data points.
    a, b, c, d = coeffs['A'], coeffs['B'], coeffs['C'], coeffs['D']

    def _predict(point: dict[str, float]) -> float:
        # Evaluate the log-linear form for a single data point.
        n = point['non_vocab_parameters']
        v = point['vocab_size']
        ch = point['num_characters']
        # Log-domain guard: all three inputs must be strictly positive.
        if n <= 0 or v <= 0 or ch <= 0:
            raise ValueError("Input variables (non_vocab_parameters, vocab_size, num_characters) must be positive.")
        return a + b * np.log(n) + c * np.log(v) + d * np.log(ch)

    return [{'unigram_normalized_loss': _predict(point)} for point in input_data]