# Fitted coefficients, keyed by group -> output variable.
# Each output is modeled linearly in the five mixture proportions:
#     loss_domain_i = intercept_i + sum_j proportions_ij * proportion_domain_j
# Defined once at module level so the table is not rebuilt on every call to law().
_COEFFICIENTS: dict[str, dict[str, dict[str, object]]] = {
    "160M": {
        "loss_domain_1": {
            "intercept": 2.2047495302794857,
            "proportions": (-0.6785457845693336, 1.1463793805820295, 0.15167833367170308, 1.025939940580135, 0.559297660014955),
        },
        "loss_domain_2": {
            "intercept": 2.7485645127653795,
            "proportions": (0.8454146604044328, -0.17669189456942216, 0.5339549342653283, 0.8879814770372856, 0.657905335627755),
        },
        "loss_domain_3": {
            "intercept": 1.777028370562523,
            "proportions": (1.6343684204048485, 1.3463909111398886, -4.1176102362247855, 1.4889565642039464, 1.4249227110386287),
        },
        "loss_domain_4": {
            "intercept": 1.157804803455643,
            "proportions": (0.5984442545273441, 0.9125282761471838, 0.5726678524110912, -1.5517735333398215, 0.6259379537098435),
        },
        "loss_domain_5": {
            "intercept": 2.8123009748266243,
            "proportions": (0.5902786179156017, 0.5892359010590805, 0.7654442820304688, 0.6814456996467663, 0.18589647417470898),
        },
    },
    "305M": {
        "loss_domain_1": {
            "intercept": 2.0812993412453547,
            "proportions": (-0.6681116974382604, 1.0884138560917158, 0.16066797498121796, 0.9749209723348417, 0.5254082352758425),
        },
        "loss_domain_2": {
            "intercept": 2.6211416491050334,
            "proportions": (0.8044876469222303, -0.18111733656107531, 0.528440945418323, 0.8482493615636207, 0.6210810317619353),
        },
        "loss_domain_3": {
            "intercept": 1.6406715315205858,
            "proportions": (1.6270585820746972, 1.3740527268982263, -4.2844089602142255, 1.4966639435477096, 1.4273052392141794),
        },
        "loss_domain_4": {
            "intercept": 1.0694804673811962,
            "proportions": (0.5643140257573426, 0.8424196628193863, 0.55242066083749, -1.4834642093589991, 0.5937903273259748),
        },
        "loss_domain_5": {
            "intercept": 2.675669426576589,
            "proportions": (0.5562710617961506, 0.5594325514171398, 0.7566886061299141, 0.647156326106111, 0.15612088112727496),
        },
    },
    "410M": {
        "loss_domain_1": {
            "intercept": 2.026790579832494,
            "proportions": (-0.6707570147435764, 1.1129214834370829, 0.0851387854718725, 0.9896066568001738, 0.5098806688669443),
        },
        "loss_domain_2": {
            "intercept": 2.5586009139858072,
            "proportions": (0.7940122885118273, -0.157381828677838, 0.45948615911537244, 0.8582702215518173, 0.6042140734846292),
        },
        "loss_domain_3": {
            "intercept": 1.574714953134018,
            "proportions": (1.6435298197556953, 1.3880256729190952, -4.3849209143151056, 1.4955237719771226, 1.4325566027972132),
        },
        "loss_domain_4": {
            "intercept": 1.0456653717309252,
            "proportions": (0.5202273851847438, 0.863550404237731, 0.5487421835368134, -1.450622661281582, 0.5637680600532177),
        },
        "loss_domain_5": {
            "intercept": 2.611253142181258,
            "proportions": (0.5424865509221992, 0.5856594143219438, 0.6774479937054709, 0.665680734573013, 0.13997844865863204),
        },
    },
    "70M": {
        "loss_domain_1": {
            "intercept": 2.47466987462387,
            "proportions": (-0.7074679140450373, 1.2175088434044439, 0.22352562602536008, 1.106631113925414, 0.6344722053136926),
        },
        "loss_domain_2": {
            "intercept": 3.0158832255809034,
            "proportions": (0.9147245891387732, -0.16651450748750043, 0.5733419334967286, 0.9507289447583598, 0.7436022656745435),
        },
        "loss_domain_3": {
            "intercept": 1.9867908155493292,
            "proportions": (1.7511750123143863, 1.4302318794850386, -4.279610044979821, 1.5512789318819353, 1.53371503684779),
        },
        "loss_domain_4": {
            "intercept": 1.355256984805328,
            "proportions": (0.6709376757670867, 1.0186456755589912, 0.674235515641396, -1.7195739626404851, 0.7110120804783369),
        },
        "loss_domain_5": {
            "intercept": 3.093438385056124,
            "proportions": (0.6590286880560735, 0.63342029207528, 0.805556538546447, 0.7338054152807485, 0.2616274510975776),
        },
    },
}

# Input keys, in the same order as each "proportions" tuple above.
_PROPORTION_KEYS = tuple(f"proportion_domain_{j}" for j in range(1, 6))


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    Each output ``loss_domain_i`` (i = 1..5) is a linear function of the input
    proportions ``proportion_domain_1`` .. ``proportion_domain_5``:
    ``intercept_i + sum_j coeff_ij * proportion_domain_j``. The functional form
    is shared by all groups; only the fitted coefficients differ per group.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values. Recognized keys are
                    ``proportion_domain_1`` .. ``proportion_domain_5``; a missing
                    key contributes nothing (as if its value were 0), and any
                    other keys are ignored.
        group: The name of the experimental group for which to make predictions.
               Known groups are "70M", "160M", "305M" and "410M".

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted ``loss_domain_1`` .. ``loss_domain_5``.

    Raises:
        KeyError: If ``group`` is not one of the known groups. This is raised
            even when ``input_data`` is empty.
    """
    # Validate the group once, up front, with a message naming the valid choices.
    try:
        group_coeffs = _COEFFICIENTS[group]
    except KeyError:
        raise KeyError(
            f"unknown group {group!r}; expected one of {sorted(_COEFFICIENTS)}"
        ) from None

    results = []
    for data_point in input_data:
        prediction = {}
        # Dict insertion order yields loss_domain_1 .. loss_domain_5 in order.
        for domain_key, params in group_coeffs.items():
            loss = params["intercept"]
            # Accumulate left-to-right with plain float additions. Built-in
            # sum() is deliberately not used: from Python 3.12 it applies
            # compensated summation to floats and would change the results.
            for proportion_key, weight in zip(_PROPORTION_KEYS, params["proportions"]):
                if proportion_key in data_point:
                    loss += weight * data_point[proportion_key]
            prediction[domain_key] = loss
        results.append(prediction)
    return results