def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Predict output variables from input variables via a discovered scaling law.

    The scaling law is: sft_loss = a - b * log(sft_data_size), where the
    coefficients (a, b) are fitted per experimental group but the functional
    form is shared across all groups.

    Args:
        input_data: A list of dictionaries, one per data point. Each dictionary
            must contain the key 'sft_data_size' with a positive value
            (math.log raises ValueError for non-positive sizes).
        group: The name of the experimental group whose fitted coefficients
            should be used. Must be one of the keys of the internal parameter
            table, e.g. "('gpt2', 'flan')".

    Returns:
        A list of dictionaries parallel to `input_data`, each containing the
        single key 'sft_loss' with the predicted loss value.

    Raises:
        ValueError: If `group` is not in the fitted parameter table, or if a
            data point's 'sft_data_size' is not positive.
        KeyError: If a data point lacks the 'sft_data_size' key.
    """
    import math
    # Fitted coefficients per (model, dataset) group: sft_loss = a - b * log(size).
    params = {
        "('MBZUAI/LaMini-GPT-124M', 'flan')": {'a': 5.9882655454224425, 'b': 0.28582842496758415},
        "('MBZUAI/LaMini-GPT-124M', 'gigaword')": {'a': 4.737755543864644, 'b': 0.2633818404689799},
        "('MBZUAI/LaMini-GPT-124M', 'wikiword')": {'a': 3.391898489600237, 'b': 0.1380796078259761},
        "('MBZUAI/LaMini-GPT-774M', 'flan')": {'a': 4.783598285711526, 'b': 0.21223122975241257},
        "('MBZUAI/LaMini-GPT-774M', 'gigaword')": {'a': 4.204573308225127, 'b': 0.23807369864162195},
        "('MBZUAI/LaMini-GPT-774M', 'wikiword')": {'a': 2.652870267476408, 'b': 0.09602793910113468},
        "('cerebras/Cerebras-GPT-1.3B', 'flan')": {'a': 3.4606313078949356, 'b': 0.12168554598363243},
        "('cerebras/Cerebras-GPT-1.3B', 'gigaword')": {'a': 3.820355957753342, 'b': 0.19060958263967723},
        "('cerebras/Cerebras-GPT-1.3B', 'wikiword')": {'a': 2.950335145432975, 'b': 0.10204513762986303},
        "('cerebras/Cerebras-GPT-256M', 'flan')": {'a': 4.2106324635240435, 'b': 0.14467116014683556},
        "('cerebras/Cerebras-GPT-256M', 'gigaword')": {'a': 4.218447739714603, 'b': 0.22564133332313316},
        "('cerebras/Cerebras-GPT-256M', 'wikiword')": {'a': 3.928449799282596, 'b': 0.16694369370158185},
        "('facebook/bart-base', 'flan')": {'a': 5.722811839840599, 'b': 0.2788501405317051},
        "('facebook/bart-base', 'gigaword')": {'a': 5.410755825061724, 'b': 0.3347248646956692},
        "('facebook/bart-base', 'wikiword')": {'a': 4.492633730872094, 'b': 0.241355286392432},
        "('facebook/bart-large', 'flan')": {'a': 4.453518961316551, 'b': 0.1981433193894148},
        "('facebook/bart-large', 'gigaword')": {'a': 5.242535976034646, 'b': 0.33392092315554134},
        "('facebook/bart-large', 'wikiword')": {'a': 2.7193492501969665, 'b': 0.1081314220510578},
        "('facebook/opt-1.3b', 'flan')": {'a': 3.003527124075447, 'b': 0.10226900929213374},
        "('facebook/opt-1.3b', 'gigaword')": {'a': 4.06738375106869, 'b': 0.2250385091589151},
        "('facebook/opt-1.3b', 'wikiword')": {'a': 2.227609752184879, 'b': 0.06738256750517932},
        "('facebook/opt-350m', 'flan')": {'a': 4.096427283011327, 'b': 0.15677711010750922},
        "('facebook/opt-350m', 'gigaword')": {'a': 4.7892172166877485, 'b': 0.28416051186028063},
        "('facebook/opt-350m', 'wikiword')": {'a': 2.969051299948199, 'b': 0.10931922575322839},
        "('facebook/opt-6.7b', 'flan')": {'a': 2.206063530252997, 'b': 0.03597445331653896},
        "('facebook/opt-6.7b', 'gigaword')": {'a': 2.1626361802570546, 'b': 0.02785806595901192},
        "('facebook/opt-6.7b', 'wikiword')": {'a': 1.9623193503403864, 'b': 0.05153332239800601},
        "('google/mt5-base', 'flan')": {'a': 4.09880753315647, 'b': 0.16562129414738644},
        "('google/mt5-base', 'gigaword')": {'a': 3.3355539188668133, 'b': 0.08234869144321573},
        "('google/mt5-base', 'wikiword')": {'a': 3.8616481733826316, 'b': 0.19493726215183346},
        "('google/mt5-large', 'flan')": {'a': 3.228231895001962, 'b': 0.11550187850708908},
        "('google/mt5-large', 'gigaword')": {'a': 3.4146676288749753, 'b': 0.09533853606014771},
        "('google/mt5-large', 'wikiword')": {'a': 3.3530996656383256, 'b': 0.15291379237134867},
        "('gpt2', 'flan')": {'a': 6.1798663808272165, 'b': 0.29617705498721825},
        "('gpt2', 'gigaword')": {'a': 4.817199540073447, 'b': 0.2808748579274315},
        "('gpt2', 'wikiword')": {'a': 3.4933775690283317, 'b': 0.14749971597990033},
        "('t5-base', 'flan')": {'a': 3.2855166640801268, 'b': 0.1171722866230908},
        "('t5-base', 'gigaword')": {'a': 1.4802511540185386, 'b': 0.06962929672117146},
        "('t5-base', 'wikiword')": {'a': 2.132300509025821, 'b': 0.06785004659432728},
        "('t5-small', 'flan')": {'a': 3.752075117502899, 'b': 0.1344429088427537},
        "('t5-small', 'gigaword')": {'a': 1.6089334542349931, 'b': 0.07229854242895724},
        "('t5-small', 'wikiword')": {'a': 2.5944002606124483, 'b': 0.09053553066033125},
    }
    if group not in params:
        raise ValueError(f'Unknown group: {group}')
    a, b = params[group]['a'], params[group]['b']
    # Pure construction: apply the law point-wise via a comprehension.
    return [
        {'sft_loss': a - b * math.log(item['sft_data_size'])}
        for item in input_data
    ]