import numpy as np

def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values.
group: The name of the experimental group for which to make predictions.
The functional form of the law must be the same for all groups,
but the constant parameters/coefficients can differ per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s).
"""
    # Per-group coefficients fitted on the training data.
    # Shared functional form: sft_loss = a + b * ln(sft_data_size)
params = {
"('MBZUAI/LaMini-GPT-124M', 'flan')": {
"a": 5.988265549351065,
"b": -0.2858284253987973
},
"('MBZUAI/LaMini-GPT-774M', 'flan')": {
"a": 4.783598287214052,
"b": -0.21223122991733276
},
"('cerebras/Cerebras-GPT-256M', 'flan')": {
"a": 4.210632462544093,
"b": -0.14467116003927433
},
"('cerebras/Cerebras-GPT-1.3B', 'flan')": {
"a": 3.4606313079569837,
"b": -0.12168554599239426
},
"('facebook/bart-base', 'flan')": {
"a": 5.722811837645894,
"b": -0.27885014029903604
},
"('facebook/bart-large', 'flan')": {
"a": 4.453518961526505,
"b": -0.19814331941245988
},
"('facebook/opt-1.3b', 'flan')": {
"a": 3.0035271247006574,
"b": -0.10226900935941804
},
"('facebook/opt-350m', 'flan')": {
"a": 4.096427281007177,
"b": -0.1567771098875299
},
"('facebook/opt-6.7b', 'flan')": {
"a": 2.2060635294933304,
"b": -0.035974453233156484
},
"('gpt2', 'flan')": {
"a": 6.179866386147315,
"b": -0.29617705608594097
},
"('t5-base', 'flan')": {
"a": 3.2855166649939935,
"b": -0.11717228671842463
},
"('t5-small', 'flan')": {
"a": 3.752075115263242,
"b": -0.13444290860067154
},
"('google/mt5-base', 'flan')": {
"a": 4.0988075329513345,
"b": -0.16562129412487037
},
"('google/mt5-large', 'flan')": {
"a": 3.2282318950626876,
"b": -0.11550187851501488
},
"('MBZUAI/LaMini-GPT-124M', 'gigaword')": {
"a": 4.737755542012152,
"b": -0.2633818402656468
},
"('MBZUAI/LaMini-GPT-774M', 'gigaword')": {
"a": 4.2045733043013165,
"b": -0.23807369821093685
},
"('cerebras/Cerebras-GPT-256M', 'gigaword')": {
"a": 4.218447739736505,
"b": -0.22564133332553715
},
"('cerebras/Cerebras-GPT-1.3B', 'gigaword')": {
"a": 3.820355959611436,
"b": -0.19060958285317242
},
"('facebook/bart-base', 'gigaword')": {
"a": 5.410755825604152,
"b": -0.3347248647552073
},
"('facebook/bart-large', 'gigaword')": {
"a": 5.242535980974371,
"b": -0.3339209236977352
},
"('facebook/opt-1.3b', 'gigaword')": {
"a": 4.067383747817735,
"b": -0.22503850880208404
},
"('facebook/opt-350m', 'gigaword')": {
"a": 4.789217216189481,
"b": -0.28416051180558977
},
"('facebook/opt-6.7b', 'gigaword')": {
"a": 2.1626361790690503,
"b": -0.027858065828614262
},
"('gpt2', 'gigaword')": {
"a": 4.8171995484035675,
"b": -0.28087485939518764
},
"('t5-base', 'gigaword')": {
"a": 1.480251152476475,
"b": -0.0696292965522014
},
"('t5-small', 'gigaword')": {
"a": 1.6089334546678618,
"b": -0.07229854246705678
},
"('google/mt5-base', 'gigaword')": {
"a": 3.3355539218016768,
"b": -0.08234869176487448
},
"('google/mt5-large', 'gigaword')": {
"a": 3.4146676286886763,
"b": -0.09533853604323887
},
"('MBZUAI/LaMini-GPT-124M', 'wikiword')": {
"a": 3.391898490213529,
"b": -0.1380796078939516
},
"('MBZUAI/LaMini-GPT-774M', 'wikiword')": {
"a": 2.652870267456121,
"b": -0.09602793909920557
},
"('cerebras/Cerebras-GPT-256M', 'wikiword')": {
"a": 3.9284497987057874,
"b": -0.16694369363861158
},
"('cerebras/Cerebras-GPT-1.3B', 'wikiword')": {
"a": 2.950335143562661,
"b": -0.10204513742983291
},
"('facebook/bart-base', 'wikiword')": {
"a": 4.4926337354168595,
"b": -0.2413552868912743
},
"('facebook/bart-large', 'wikiword')": {
"a": 2.7193492499816334,
"b": -0.10813142202742225
},
"('facebook/opt-1.3b', 'wikiword')": {
"a": 2.227609751673505,
"b": -0.06738256744904991
},
"('facebook/opt-350m', 'wikiword')": {
"a": 2.969051299001184,
"b": -0.10931922565009441
},
"('facebook/opt-6.7b', 'wikiword')": {
"a": 1.9623193491235948,
"b": -0.05153332226677372
},
"('gpt2', 'wikiword')": {
"a": 3.4933775691623454,
"b": -0.14749971599228653
},
"('t5-base', 'wikiword')": {
"a": 2.132300508433401,
"b": -0.06785004652930211
},
"('t5-small', 'wikiword')": {
"a": 2.594400260204647,
"b": -0.09053553061557014
},
"('google/mt5-base', 'wikiword')": {
"a": 3.861648173041152,
"b": -0.19493726211114437
},
"('google/mt5-large', 'wikiword')": {
"a": 3.353099664653985,
"b": -0.15291379226040927
}
}
    if group not in params:
        # Unknown group: the functional form is shared across groups, but no
        # coefficients were fitted for this one. Warn and fall back to NaN so
        # the caller can tell the predictions are not meaningful.
        print(f"Warning: unknown group '{group}'; returning NaN predictions.")
        a, b = float("nan"), float("nan")
    else:
        a = params[group]["a"]
        b = params[group]["b"]
predictions = []
    for item in input_data:
        x = item.get("sft_data_size")
        if x is None:
            # The input variable is missing for this point: emit an empty dict.
            predictions.append({})
            continue
        # Apply the law: sft_loss = a + b * ln(sft_data_size).
        # ln(x) is undefined for x <= 0, so guard against invalid sizes.
        if x <= 0:
            pred_y = float("nan")
        else:
            pred_y = a + b * np.log(x)
        predictions.append({"sft_loss": pred_y})
return predictions
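

# Illustrative usage and fitting sketch. These are assumptions, not part of
# the discovered law: the group keys appear to be stringified (model, task)
# tuples, and the per-group coefficients would plausibly be recovered by an
# ordinary least-squares fit of sft_loss against ln(sft_data_size).
# fit_group_params and the demo data below are hypothetical sketches under
# those assumptions.
def fit_group_params(sizes: list[float], losses: list[float]) -> dict[str, float]:
    """Fit sft_loss = a + b * ln(sft_data_size) by least squares."""
    # np.polyfit returns [slope, intercept] for a degree-1 fit, i.e. [b, a].
    b, a = np.polyfit(np.log(sizes), losses, deg=1)
    return {"a": float(a), "b": float(b)}


if __name__ == "__main__":
    # Hypothetical data points; the sizes are illustrative only.
    demo_input = [
        {"sft_data_size": 1_000.0},
        {"sft_data_size": 10_000.0},
        {"sft_data_size": 100_000.0},
    ]
    for point, pred in zip(demo_input, law(demo_input, "('gpt2', 'flan')")):
        print(point, "->", pred)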