def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values.
group: The name of the experimental group for which to make predictions.
The functional form of the law must be the same for all groups,
but the constant parameters/coefficients can differ per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s).
"""
# Fitted parameters for each group (from curve fitting on training data)
# Scaling law: sft_loss = a * (sft_data_size)^(-b) + c
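    # Here a sets the overall scale, b the rate at which loss decays with data
    # size, and c an irreducible loss floor (effectively zero for most groups).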
parameters = {
"('MBZUAI/LaMini-GPT-124M', 'flan')": {'a': 6.9738879587354985, 'b': 0.0813415811372061, 'c': 6.209360061651158e-11},
"('MBZUAI/LaMini-GPT-774M', 'flan')": {'a': 5.416194164267256, 'b': 0.0720358167338764, 'c': 7.8874180375913e-12},
"('cerebras/Cerebras-GPT-256M', 'flan')": {'a': 4.497775144341002, 'b': 0.0491972775293856, 'c': 1.955613448956228e-11},
"('cerebras/Cerebras-GPT-1.3B', 'flan')": {'a': 3.726687324322024, 'b': 0.0513362241087863, 'c': 7.71038788371925e-12},
"('facebook/bart-base', 'flan')": {'a': 6.796594760907895, 'b': 0.0855017025058688, 'c': 3.5644296141167816e-12},
"('facebook/bart-large', 'flan')": {'a': 5.121314344755191, 'b': 0.0740782998112432, 'c': 1.3640633310585388e-11},
"('facebook/opt-1.3b', 'flan')": {'a': 3.21710353273244, 'b': 0.0490507139907979, 'c': 1.0557194948504606e-11},
"('facebook/opt-350m', 'flan')": {'a': 4.468174350569338, 'b': 0.0576116621246158, 'c': 3.8573586034728045e-11},
"('facebook/opt-6.7b', 'flan')": {'a': 2.233967790682599, 'b': 0.0191483651848653, 'c': 1.9412693674780712e-11},
"('gpt2', 'flan')": {'a': 7.198398164068978, 'b': 0.0817204499072352, 'c': 3.2510296974217244e-12},
"('t5-base', 'flan')": {'a': 3.5430516811202155, 'b': 0.0522550981730723, 'c': 1.1068146399395571e-11},
"('t5-small', 'flan')": {'a': 4.052493019544314, 'b': 0.052691031082205, 'c': 6.746480749474824e-12},
"('google/mt5-base', 'flan')": {'a': 4.571149014484463, 'b': 0.0636171899686622, 'c': 9.026631196760658e-12},
"('google/mt5-large', 'flan')": {'a': 3.492798575656628, 'b': 0.0528118147665376, 'c': 1.0836333814545436e-11},
"('MBZUAI/LaMini-GPT-124M', 'gigaword')": {'a': 5.923744647294182, 'b': 0.1052086313123345, 'c': 1.5788841830932776e-11},
"('MBZUAI/LaMini-GPT-774M', 'gigaword')": {'a': 5.350688480941363, 'b': 0.1095237488864338, 'c': 5.272610698686065e-12},
"('cerebras/Cerebras-GPT-256M', 'gigaword')": {'a': 5.237426477009148, 'b': 0.1000412217663205, 'c': 7.1741015201617725e-12},
"('cerebras/Cerebras-GPT-1.3B', 'gigaword')": {'a': 4.601802086260434, 'b': 0.0893794216309777, 'c': 7.841164216874931e-12},
"('facebook/bart-base', 'gigaword')": {'a': 7.594465198757044, 'b': 0.1333097873907668, 'c': 1.3554385113419691e-11},
"('facebook/bart-large', 'gigaword')": {'a': 7.683920586651666, 'b': 0.1431066056707549, 'c': 8.407693574380672e-12},
"('facebook/opt-1.3b', 'gigaword')": {'a': 5.229994743123611, 'b': 0.1080529832595854, 'c': 1.787133563253055e-11},
"('facebook/opt-350m', 'gigaword')": {'a': 6.499201552078441, 'b': 0.123318476159955, 'c': 1.4467022271584628e-11},
"('facebook/opt-6.7b', 'gigaword')": {'a': 2.1770797673846456, 'b': 0.0144985156639429, 'c': 1.3240519791679617e-12},
"('gpt2', 'gigaword')": {'a': 6.339049368674285, 'b': 0.1172559464274163, 'c': 3.282936462150453e-11},
"('t5-base', 'gigaword')": {'a': 1.823379316613078, 'b': 0.1674599551268627, 'c': 0.4167409405994949},
"('t5-small', 'gigaword')": {'a': 1.775741745120599, 'b': 0.1343979198744542, 'c': 0.4009175915675014},
"('google/mt5-base', 'gigaword')": {'a': 3.4396037497251, 'b': 0.0316506222790409, 'c': 1.851582046346298e-11},
"('google/mt5-large', 'gigaword')": {'a': 3.5492577898997597, 'b': 0.0368812645537255, 'c': 3.0984992301047484e-11},
"('MBZUAI/LaMini-GPT-124M', 'wikiword')": {'a': 3.7815960318544137, 'b': 0.0640731112559584, 'c': 5.133338198959336e-12},
"('MBZUAI/LaMini-GPT-774M', 'wikiword')": {'a': 2.8815066741907813, 'b': 0.0539025472962201, 'c': 1.1417422562942647e-11},
"('cerebras/Cerebras-GPT-256M', 'wikiword')": {'a': 4.45601962837555, 'b': 0.0690574152830047, 'c': 9.361215376952394e-12},
"('cerebras/Cerebras-GPT-1.3B', 'wikiword')": {'a': 3.167052217392773, 'b': 0.0501048229949453, 'c': 6.587842000399508e-13},
"('facebook/bart-base', 'wikiword')": {'a': 5.86196054927262, 'b': 0.1201853857867522, 'c': 0.2505119015404434},
"('facebook/bart-large', 'wikiword')": {'a': 2.620750839880945, 'b': 0.115203713227391, 'c': 0.7814633075466515},
"('facebook/opt-1.3b', 'wikiword')": {'a': 2.349129117465099, 'b': 0.0417628773316036, 'c': 7.147060465007951e-12},
"('facebook/opt-350m', 'wikiword')": {'a': 3.241838816258498, 'b': 0.0554639976201082, 'c': 2.0400472816702877e-11},
"('facebook/opt-6.7b', 'wikiword')": {'a': 1.7881053529873807, 'b': 0.0422082813415794, 'c': 0.2693284658571291},
"('gpt2', 'wikiword')": {'a': 3.9462672178497136, 'b': 0.0681165715966697, 'c': 7.881931987611646e-12},
"('t5-base', 'wikiword')": {'a': 2.259872089331568, 'b': 0.0445738697036901, 'c': 9.790168427672564e-12},
"('t5-small', 'wikiword')": {'a': 2.7905737239350903, 'b': 0.0507995311827284, 'c': 1.8161462756396424e-12},
"('google/mt5-base', 'wikiword')": {'a': 4.7344533518491385, 'b': 0.0926126262618999, 'c': 9.943976438405106e-13},
"('google/mt5-large', 'wikiword')": {'a': 3.9262973834353896, 'b': 0.0781626896765455, 'c': 1.0543787441757105e-12},
}
    # Fall back to aggregate parameters for unseen groups instead of raising,
    # so the function degrades gracefully on held-out groups.
    if group not in parameters:
        import statistics
        all_a = [p['a'] for p in parameters.values()]
        all_b = [p['b'] for p in parameters.values()]
        all_c = [p['c'] for p in parameters.values()]
        params = {
            'a': statistics.mean(all_a),
            'b': statistics.mean(all_b),
            # The fitted c values span many orders of magnitude (from ~1e-12
            # up to ~0.78), so the median is more robust than the mean here.
            'c': statistics.median(all_c),
        }
    else:
        params = parameters[group]
    # Apply the scaling law: sft_loss = a * (sft_data_size)^(-b) + c.
    results = []
    for data_point in input_data:
        N = data_point['sft_data_size']  # assumed to be positive
        predicted_loss = params['a'] * (N ** (-params['b'])) + params['c']
        results.append({'sft_loss': predicted_loss})
    return results
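

# A minimal sketch (not part of the original fitting code) of how the
# per-group constants above could be reproduced with scipy, assuming the
# training data for one group is available as parallel lists of
# sft_data_size and sft_loss values. scipy and numpy are extra dependencies
# that law() itself does not need.
def fit_group(sizes: list[float], losses: list[float]) -> dict[str, float]:
    import numpy as np
    from scipy.optimize import curve_fit

    def model(n, a, b, c):
        return a * n ** (-b) + c

    # Constrain all three constants to be non-negative so the fit keeps the
    # law's shape: a loss that decays with data size toward a floor c.
    (a, b, c), _ = curve_fit(
        model,
        np.asarray(sizes, dtype=float),
        np.asarray(losses, dtype=float),
        p0=(5.0, 0.1, 1e-6),
        bounds=((0.0, 0.0, 0.0), (np.inf, np.inf, np.inf)),
    )
    return {'a': float(a), 'b': float(b), 'c': float(c)}


# Illustrative call: the group key exists in the parameters table above, but
# the data sizes are made up for the demo.
if __name__ == "__main__":
    demo_points = [{'sft_data_size': s} for s in (1e3, 1e4, 1e5)]
    preds = law(demo_points, "('gpt2', 'flan')")
    for point, pred in zip(demo_points, preds):
        print(f"sft_data_size={point['sft_data_size']:.0e} -> "
              f"predicted sft_loss={pred['sft_loss']:.4f}")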