import numpy as np
# Fitted parameters for each model-size group.
# Model (same functional form for every group and every output domain i):
#     L_i = A * (sum_j T_ij * p_j) ** (-alpha) + C
# where p_j is the input value "proportion_domain_j" and T_i is the transfer
# vector for output domain i (its own-domain entry T_ii is 1.0 in every fit
# below). Hoisted to module level so the table is built once, not per call.
_PARAMS = {
    "70M": {
        "loss_domain_1": {
            "A": 1.4238451169174666,
            "alpha": 0.09398513070233284,
            "C": 1.1250493207278296,
            "T": [
                1.0,
                0.0028653843430912034,
                0.0,
                0.006560914041313572,
                0.007218091947616502,
            ],
        },
        "loss_domain_2": {
            "A": 1.6028897145876133,
            "alpha": 0.18249317423847652,
            "C": 1.7098198790278034,
            "T": [
                0.13978044088307642,
                1.0,
                0.2888140021128257,
                0.09238120511024654,
                0.2702341876935682,
            ],
        },
        "loss_domain_3": {
            "A": 1.3944974988308458,
            "alpha": 0.07536985492525757,
            "C": 1.4540031052227687,
            "T": [
                0.00042972756094378624,
                0.004428392340713848,
                1.0,
                0.0015108566548276629,
                0.004618708293329572,
            ],
        },
        "loss_domain_4": {
            "A": 0.7156208007948365,
            "alpha": 0.14819891555573603,
            "C": 0.7188126729944071,
            "T": [
                0.004709723482930698,
                0.0,
                0.007159224662733469,
                1.0,
                0.006354517927118956,
            ],
        },
        "loss_domain_5": {
            "A": 1.6903071983073346,
            "alpha": 0.07574314873996338,
            "C": 1.7459964167978976,
            "T": [
                0.0,
                0.1294716976819937,
                0.09092332659047013,
                0.014567119409274644,
                1.0,
            ],
        },
    },
    "160M": {
        "loss_domain_1": {
            "A": 1.167794250478621,
            "alpha": 0.09893031045785339,
            "C": 1.0967302669578558,
            "T": [
                1.0,
                0.0,
                0.0,
                0.004702404391421327,
                0.006374420575153449,
            ],
        },
        "loss_domain_2": {
            "A": 1.520792165137527,
            "alpha": 0.19100210139002716,
            "C": 1.4717790036178258,
            "T": [
                0.14117673959388272,
                1.0,
                0.30984366873545993,
                0.09323203270576641,
                0.2930905074144338,
            ],
        },
        "loss_domain_3": {
            "A": 1.1827621056082325,
            "alpha": 0.08515955574666216,
            "C": 1.3810689151399624,
            "T": [
                0.0004496159377778205,
                0.005469210248664284,
                1.0,
                0.00014486145067144012,
                0.005546525753692289,
            ],
        },
        "loss_domain_4": {
            "A": 0.5950466323744031,
            "alpha": 0.15657006642589474,
            "C": 0.6274992666731508,
            "T": [
                0.0038211623746128476,
                0.0,
                0.004047765748103023,
                1.0,
                0.006932201209717277,
            ],
        },
        "loss_domain_5": {
            "A": 1.5374892790861532,
            "alpha": 0.08558831045269366,
            "C": 1.546810504279026,
            "T": [
                0.010196510229647623,
                0.04817027475788599,
                0.0331566521815975,
                0.07622206331237742,
                1.0,
            ],
        },
    },
    "305M": {
        "loss_domain_1": {
            "A": 1.0636633714879822,
            "alpha": 0.1022580547558815,
            "C": 1.0643051121057456,
            "T": [
                1.0,
                0.0020628543899588276,
                0.0,
                0.005055971810713113,
                0.005517181906203229,
            ],
        },
        "loss_domain_2": {
            "A": 1.5091813171571034,
            "alpha": 0.21086460000369067,
            "C": 1.317332273323408,
            "T": [
                0.1761437051777317,
                1.0,
                0.328313250063059,
                0.12228845673566739,
                0.3242993594399039,
            ],
        },
        "loss_domain_3": {
            "A": 1.2930356452032414,
            "alpha": 0.06407715343277973,
            "C": 1.1862706116692965,
            "T": [
                0.00010662729851481288,
                0.0016648522472207873,
                1.0,
                0.0,
                0.002186394372089585,
            ],
        },
        "loss_domain_4": {
            "A": 0.5311234622127226,
            "alpha": 0.16396855241073996,
            "C": 0.5956951080126597,
            "T": [
                0.0007484858832986562,
                0.0,
                0.0020986448436579866,
                1.0,
                0.0074776280923507365,
            ],
        },
        "loss_domain_5": {
            "A": 1.4576330665204935,
            "alpha": 0.08684367375738934,
            "C": 1.4616723046222218,
            "T": [
                0.0,
                0.04878929729092485,
                0.027998535239229693,
                0.07915718776179,
                1.0,
            ],
        },
    },
    "410M": {
        "loss_domain_1": {
            "A": 1.0716651256430547,
            "alpha": 0.0979215135041499,
            "C": 1.0023484878829527,
            "T": [
                1.0,
                0.0,
                0.0,
                0.003578328625604031,
                0.00520215592199023,
            ],
        },
        "loss_domain_2": {
            "A": 1.4082722528529894,
            "alpha": 0.21260676990383326,
            "C": 1.3641437773794949,
            "T": [
                0.16418103324714028,
                1.0,
                0.37481870992512745,
                0.09587346431048725,
                0.32578174634303375,
            ],
        },
        "loss_domain_3": {
            "A": 1.3149711554737487,
            "alpha": 0.062094311988635076,
            "C": 1.1032193950712292,
            "T": [
                0.0,
                0.001564258864235631,
                1.0,
                8.56970983439856e-05,
                0.0018986297882148968,
            ],
        },
        "loss_domain_4": {
            "A": 0.49842678668584145,
            "alpha": 0.1778525216820956,
            "C": 0.5822681110448701,
            "T": [
                0.0060231123550343905,
                0.0007592880970631944,
                0.0026719690465761368,
                1.0,
                0.0086213744689824,
            ],
        },
        # NOTE(review): unlike the other sizes, this fit has C < 0 and a much
        # larger A; it may be a poorly constrained fit — confirm before
        # extrapolating far outside the fitted mixtures.
        "loss_domain_5": {
            "A": 3.233311673093248,
            "alpha": 0.04148885832669186,
            "C": -0.39223243148328407,
            "T": [
                0.01291802719802647,
                0.0,
                0.3021193873339115,
                0.0,
                1.0,
            ],
        },
    },
}

# Input keys read from each data point, in the order matching the entries of
# every "T" vector above.
_PROPORTION_KEYS = (
    "proportion_domain_1",
    "proportion_domain_2",
    "proportion_domain_3",
    "proportion_domain_4",
    "proportion_domain_5",
)

# Lower bound on the effective proportion: with a negative exponent, an
# effective proportion of 0 would otherwise give an infinite loss.
_MIN_EFFECTIVE_PROPORTION = 1e-9


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    For every output domain ``i`` the predicted loss is

        L_i = A * (sum_j T_ij * p_j) ** (-alpha) + C

    where ``p_j`` is the data point's ``proportion_domain_j`` value and
    ``A``, ``alpha``, ``C`` and ``T_i`` are the fitted constants for
    ``group`` and domain ``i``. The effective proportion ``sum_j T_ij * p_j``
    is floored at 1e-9 before exponentiation.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
            point containing input variable names as keys and their
            corresponding values. The keys read are ``proportion_domain_1``
            through ``proportion_domain_5``; a missing key counts as 0.0 and
            any other keys are ignored.
        group: The name of the experimental group for which to make predictions.
            The functional form of the law must be the same for all groups,
            but the constant parameters/coefficients can differ per group.
            Fitted groups: "70M", "160M", "305M", "410M".

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variables
        ``loss_domain_1`` .. ``loss_domain_5`` as Python floats.

    Raises:
        ValueError: If ``group`` is not one of the fitted groups.
    """
    if group not in _PARAMS:
        raise ValueError(f"Unknown group: {group}")

    # Coefficients are the same for every data point, so unpack them and
    # build the transfer vectors once per call rather than once per item.
    domain_models = [
        (domain_key, params["A"], params["alpha"], params["C"], np.array(params["T"]))
        for domain_key, params in _PARAMS[group].items()
    ]

    predictions = []
    for item in input_data:
        # Proportion vector; absent domains default to 0.0.
        p = np.array([item.get(key, 0.0) for key in _PROPORTION_KEYS])
        pred_item = {}
        for domain_key, A, alpha, C, T in domain_models:
            # Effective proportion, floored so the negative power stays finite.
            p_eff = max(np.dot(p, T), _MIN_EFFECTIVE_PROPORTION)
            # Cast to a plain float to honour the declared return type.
            pred_item[domain_key] = float(A * p_eff ** -alpha + C)
        predictions.append(pred_item)
    return predictions