import json
import numpy as np
# Fitted parameters are embedded directly below rather than loaded from the
# JSON file at runtime, so the submitted law.py carries no file dependencies.
# (In a real deployment the parameters would be loaded once at import time.)
# Per-group fitted coefficients for the SFT scaling law
#     sft_loss = A * sft_data_size ** -B + C
# Keys are the string representations of (model_name, dataset_name) tuples,
# matching the `group` identifiers passed by callers.
FITTED_PARAMS: dict[str, dict[str, float]] = {
    "('MBZUAI/LaMini-GPT-124M', 'flan')": {"A": 6.973888553329567, "B": 0.08134159151639991, "C": 7.870182444280732e-15},
    "('MBZUAI/LaMini-GPT-774M', 'flan')": {"A": 5.416195597322399, "B": 0.0720358485868509, "C": 1.0602644871174933e-11},
    "('cerebras/Cerebras-GPT-256M', 'flan')": {"A": 4.497775760815009, "B": 0.04919729767645753, "C": 6.274119068597892e-10},
    "('cerebras/Cerebras-GPT-1.3B', 'flan')": {"A": 3.7266874945596746, "B": 0.05133622927325513, "C": 1.412382533841796e-10},
    "('facebook/bart-base', 'flan')": {"A": 6.796594386775143, "B": 0.08550169575715841, "C": 6.737958498228297e-18},
    "('facebook/bart-large', 'flan')": {"A": 5.121314339920321, "B": 0.07407829970041344, "C": 3.1975002498876555e-19},
    "('facebook/opt-1.3b', 'flan')": {"A": 3.2171035336850538, "B": 0.04905071402472535, "C": 1.4665945977868665e-17},
    "('facebook/opt-350m', 'flan')": {"A": 4.46817434881537, "B": 0.05761166207977083, "C": 6.466006159672916e-17},
    "('facebook/opt-6.7b', 'flan')": {"A": 2.2339677920266, "B": 0.019148365251614136, "C": 2.4205570716429996e-12},
    "('gpt2', 'flan')": {"A": 7.198397895874961, "B": 0.08172044535963456, "C": 7.277544347301891e-16},
    "('t5-base', 'flan')": {"A": 3.543052221279221, "B": 0.05225511975662033, "C": 5.365852391533393e-14},
    "('t5-small', 'flan')": {"A": 4.0524929680519675, "B": 0.05269102958942768, "C": 7.17728404772256e-18},
    "('google/mt5-base', 'flan')": {"A": 4.571149127204719, "B": 0.06361719290512968, "C": 7.778644966305842e-13},
    "('google/mt5-large', 'flan')": {"A": 3.492798571175553, "B": 0.05281181461988964, "C": 1.2586498420090986e-20},
    "('MBZUAI/LaMini-GPT-124M', 'gigaword')": {"A": 5.923747139325469, "B": 0.1052086841512751, "C": 3.089177194303035e-17},
    "('MBZUAI/LaMini-GPT-774M', 'gigaword')": {"A": 5.3506872228212465, "B": 0.10952371922324641, "C": 4.789842716916173e-16},
    "('cerebras/Cerebras-GPT-256M', 'gigaword')": {"A": 5.2374256588965595, "B": 0.100041202302744, "C": 2.5291593168013106e-18},
    "('cerebras/Cerebras-GPT-1.3B', 'gigaword')": {"A": 4.601802082329202, "B": 0.08937942152768066, "C": 5.975656142659641e-21},
    "('facebook/bart-base', 'gigaword')": {"A": 7.594466160506642, "B": 0.13330980387218538, "C": 2.0140930917717585e-19},
    "('facebook/bart-large', 'gigaword')": {"A": 7.68392069755279, "B": 0.1431066075849933, "C": 3.416340754873851e-21},
    "('facebook/opt-1.3b', 'gigaword')": {"A": 5.229996705172819, "B": 0.1080530305700941, "C": 4.162452364748143e-13},
    "('facebook/opt-350m', 'gigaword')": {"A": 6.499202936404031, "B": 0.1233185035135486, "C": 9.686843110254685e-16},
    "('facebook/opt-6.7b', 'gigaword')": {"A": 2.177079948663568, "B": 0.014498526945722394, "C": 1.6637361013813952e-20},
    "('gpt2', 'gigaword')": {"A": 6.339049550102067, "B": 0.1172559500524722, "C": 3.0102599329917646e-20},
    "('t5-base', 'gigaword')": {"A": 1.8233793781091008, "B": 0.16745997024393933, "C": 0.4167409793913292},
    "('t5-small', 'gigaword')": {"A": 1.7757421128794664, "B": 0.1343981132151212, "C": 0.40091836577062706},
    "('google/mt5-base', 'gigaword')": {"A": 3.4396037452096033, "B": 0.03165062212756719, "C": 2.272270358483561e-21},
    "('google/mt5-large', 'gigaword')": {"A": 3.5492577919728685, "B": 0.036881264614398473, "C": 3.110993196587435e-19},
    "('MBZUAI/LaMini-GPT-124M', 'wikiword')": {"A": 3.7815959596442967, "B": 0.06407311340689052, "C": 2.1990783058632668e-15},
    "('MBZUAI/LaMini-GPT-774M', 'wikiword')": {"A": 2.8815067635589147, "B": 0.05390255276559482, "C": 1.0505193779423451e-21},
    "('cerebras/Cerebras-GPT-256M', 'wikiword')": {"A": 4.456019622760662, "B": 0.06905741513102519, "C": 1.7674523244785396e-22},
    "('cerebras/Cerebras-GPT-1.3B', 'wikiword')": {"A": 3.167052219129391, "B": 0.0501048230571595, "C": 1.4069568777664217e-16},
    "('facebook/bart-base', 'wikiword')": {"A": 5.861960588216526, "B": 0.1201852341432514, "C": 0.25050927467318046},
    "('facebook/bart-large', 'wikiword')": {"A": 2.620750833375674, "B": 0.11520373428320353, "C": 0.781463479666753},
    "('facebook/opt-1.3b', 'wikiword')": {"A": 2.349125080913662, "B": 0.0417629915758353, "C": 4.408502521317795e-06},
    "('facebook/opt-350m', 'wikiword')": {"A": 3.2418388183227664, "B": 0.05546399769101606, "C": 2.38823275940463e-18},
    "('facebook/opt-6.7b', 'wikiword')": {"A": 1.788104942925995, "B": 0.04220829699230767, "C": 0.26932891639769574},
    "('gpt2', 'wikiword')": {"A": 3.9462672307524063, "B": 0.0681165719858611, "C": 8.567550533709646e-16},
    "('t5-base', 'wikiword')": {"A": 2.259872088758908, "B": 0.04457386967601375, "C": 3.691577925559672e-22},
    "('t5-small', 'wikiword')": {"A": 2.7905737414980694, "B": 0.05079953192207081, "C": 7.005925333025477e-15},
    "('google/mt5-base', 'wikiword')": {"A": 4.734453344250497, "B": 0.09261262606875871, "C": 1.7979546095664325e-22},
    "('google/mt5-large', 'wikiword')": {"A": 3.9262973862263366, "B": 0.0781626897617864, "C": 2.168723424360531e-14},
}

# Epsilon substituted for non-positive data sizes: x ** -B is undefined (or a
# ZeroDivisionError) at x <= 0, so a tiny positive value is used instead, which
# with B > 0 simply yields a very large predicted loss.
_MIN_SFT_DATA_SIZE = 1e-9


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Predict SFT loss from SFT data size via a fitted power law.

    The functional form is identical for every group:

        sft_loss = A * sft_data_size ** -B + C

    with per-group coefficients (A, B, C) looked up in ``FITTED_PARAMS``.

    Args:
        input_data: One dict per data point; each must contain the key
            ``'sft_data_size'`` with a numeric value. Non-positive sizes are
            clamped to a tiny epsilon so the power term remains computable.
        group: Group identifier — the string form of a (model, dataset)
            tuple, e.g. ``"('gpt2', 'flan')"`` — matching a key of
            ``FITTED_PARAMS``.

    Returns:
        A list parallel to ``input_data``; each element is a dict
        ``{'sft_loss': <predicted float>}``.

    Raises:
        ValueError: If no fitted parameters exist for ``group``.
        KeyError: If a data point is missing ``'sft_data_size'``.
    """
    if group not in FITTED_PARAMS:
        raise ValueError(
            f"Parameters not found for group: {group}. "
            f"Available groups: {list(FITTED_PARAMS.keys())}"
        )
    params = FITTED_PARAMS[group]
    a, b, c = params['A'], params['B'], params['C']

    predictions = []
    for data_point in input_data:
        size = data_point['sft_data_size']
        if size <= 0:
            # Clamp rather than raise so extrapolation queries never crash.
            size = _MIN_SFT_DATA_SIZE
        predictions.append({'sft_loss': a * size ** -b + c})
    return predictions