import math
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Predict the unigram-normalized loss via a discovered scaling law.

    The law has the form:

        loss = a*log(vocab_size) + b*log(non_vocab_parameters)
               + c*log(num_characters) + d

    where ``a`` is a single coefficient shared across all groups, while
    ``b``, ``c``, and ``d`` are fitted per group (groups are keyed by
    vocab_size).

    Args:
        input_data: One dict per data point, with keys 'vocab_size',
            'non_vocab_parameters', and 'num_characters'.
        group: Name of the experimental group (identified by vocab_size).
            NOTE(review): this argument is currently unused — each point's
            own 'vocab_size' value selects the per-group parameters instead;
            confirm with callers whether that is intended.

    Returns:
        A list parallel to ``input_data``; each entry is a dict of the form
        {'unigram_normalized_loss': <prediction>}.
    """
    # Coefficient on log(vocab_size), shared by every group.
    shared_a = 0.0634011567
    # Per-group fitted parameters keyed by vocab_size: (b, c, d).
    per_group = {
        4096.0: (0.0103157282, -0.4387568777, 5.0740243542),
        6144.0: (0.0019714668, -0.4572441132, 5.6220287899),
        8192.0: (0.0035061757, -0.4762909418, 6.0311809571),
        10240.0: (0.0097884790, -0.4849630808, 6.1153025956),
        16384.0: (0.0080820317, -0.5083907212, 6.6778105051),
        24576.0: (0.0128570922, -0.5238156652, 6.9554173182),
        32768.0: (0.0118341620, -0.5321262189, 7.1667365668),
        48128.0: (0.0572019544, -0.5462202420, 6.6517005780),
        64512.0: (0.0299395040, -0.5454710483, 7.1626047194),
    }

    predictions = []
    for point in input_data:
        vocab = point['vocab_size']
        # Use the point's own vocab_size when it is a known group; otherwise
        # fall back to the nearest available group by absolute distance.
        key = (
            vocab
            if vocab in per_group
            else min(per_group, key=lambda size: abs(size - vocab))
        )
        b, c, d = per_group[key]
        # loss = a*log(V) + b*log(P) + c*log(C) + d  (same evaluation order
        # as the fitted form, left to right).
        loss = (
            shared_a * math.log(vocab)
            + b * math.log(point['non_vocab_parameters'])
            + c * math.log(point['num_characters'])
            + d
        )
        predictions.append({'unigram_normalized_loss': loss})
    return predictions