from __future__ import annotations
from typing import List, Dict
import math
# Embedded fitted parameters, keyed first by group name and then by domain
# index (the domain index is a string, '1'..'5').
# Each domain entry selects its curve through the 'form' key:
#   power form: y = a + k * p**c    (reads 'a', 'k', 'c')
#   log form:   y = a + b * log(p)  (reads 'a', 'b')
# Every entry below currently uses the power form.
# 'r2' is stored alongside the coefficients but is not read by
# _predict_domain; presumably it is the R^2 of the fit -- TODO confirm.
PARAMS = {'160M': {'1': {'a': 3.358131379887285,
'c': 0.15000000000000036,
'form': 'power',
'k': -1.1267042149203725,
'r2': 0.9941862658661204},
'2': {'a': 3.318780734454316,
'c': -3.0,
'form': 'power',
'k': 2.2734784837905097e-05,
'r2': 0.06062614424441071},
'3': {'a': 2.931566771454378, 'c': 3.0, 'form': 'power', 'k': -174.62340395796082, 'r2': 0.7061503010752743},
'4': {'a': 3.8716421496206315,
'c': 0.050000000000000266,
'form': 'power',
'k': -2.6874577135715296,
'r2': 0.9762826201636258},
'5': {'a': 3.876397587476255,
'c': 0.20000000000000018,
'form': 'power',
'k': -0.8015051387810923,
'r2': 0.995953951272856}},
'305M': {'1': {'a': 4.832122416331462,
'c': 0.050000000000000266,
'form': 'power',
'k': -2.721490588848711,
'r2': 0.9923678438528458},
'2': {'a': 3.1584886677226094,
'c': -2.75,
'form': 'power',
'k': 4.513541219029818e-05,
'r2': 0.066949255913345},
'3': {'a': 2.7848933137166916, 'c': 3.0, 'form': 'power', 'k': -166.12136907423988, 'r2': 0.6853980994244002},
'4': {'a': -0.9540330670713865,
'c': -0.04999999999999982,
'form': 'power',
'k': 2.067152378143916,
'r2': 0.9773900180368574},
'5': {'a': 5.599492209154898,
'c': 0.050000000000000266,
'form': 'power',
'k': -2.6836119274365684,
'r2': 0.9937674933260627}},
'410M': {'1': {'a': 4.670978316018143,
'c': 0.050000000000000266,
'form': 'power',
'k': -2.613131878961905,
'r2': 0.9885683848277665},
'2': {'a': 3.091279676924703,
'c': -3.0,
'form': 'power',
'k': 1.7491883130243175e-05,
'r2': 0.03333320137761886},
'3': {'a': 2.7201830236984295, 'c': 3.0, 'form': 'power', 'k': -168.02135230322037, 'r2': 0.6489941549529987},
'4': {'a': 0.5427391497008389,
'c': -0.1499999999999999,
'form': 'power',
'k': 0.5526416189568344,
'r2': 0.9732551074881751},
'5': {'a': 3.634036329340057,
'c': 0.20000000000000018,
'form': 'power',
'k': -0.8017750604189218,
'r2': 0.9939499265653914}},
'70M': {'1': {'a': 4.297380991024045,
'c': 0.10000000000000009,
'form': 'power',
'k': -1.775827036691083,
'r2': 0.9952894285760332},
'2': {'a': 3.646005195575224,
'c': -2.6,
'form': 'power',
'k': 8.526182739288986e-05,
'r2': 0.12259002392235419},
'3': {'a': 3.2284514219048646, 'c': 3.0, 'form': 'power', 'k': -180.79876453488362, 'r2': 0.6915505904622448},
'4': {'a': 1.1444240089156006, 'c': -0.25, 'form': 'power', 'k': 0.3211850398410141, 'r2': 0.9856867978166857},
'5': {'a': 2.3099069373584133,
'c': -0.09999999999999964,
'form': 'power',
'k': 1.1287948701911399,
'r2': 0.9930440593027846}}}
# Default per-domain parameters, used by law() when the requested group is
# not a key of PARAMS. Each coefficient ('a', 'c', 'k') is the arithmetic
# mean of that coefficient over the four groups in PARAMS; no 'r2' is stored.
FALLBACK = {'1': {'a': 4.289653275815233, 'c': 0.08750000000000024, 'form': 'power', 'k': -2.059288429855518},
'2': {'a': 3.3036385686692133, 'c': -2.8375, 'form': 'power', 'k': 4.265597688783408e-05},
'3': {'a': 2.9162736326935907, 'c': 3.0, 'form': 'power', 'k': -172.39122246757617},
'4': {'a': 1.151193060291421, 'c': -0.09999999999999987, 'form': 'power', 'k': 0.06338033084255867},
'5': {'a': 3.8549582658324058, 'c': 0.08750000000000024, 'form': 'power', 'k': -0.7895243141113606}}
# Domain indices iterated by law(); each index is formatted into the
# 'proportion_domain_<i>' / 'loss_domain_<i>' keys and converted with str()
# to look up that domain's parameter dict.
DOMAINS = [1,2,3,4,5]
def _predict_domain(p: float, spec: dict) -> float:
    """Evaluate one domain's fitted curve at proportion ``p``.

    The proportion is raised to at least 1e-12 before evaluation, so the
    logarithm and negative exponents never receive zero or a negative
    value.

    Args:
        p: Proportion at which to evaluate the curve.
        spec: Fitted parameters. ``spec['form']`` selects the curve:
            ``'power'`` reads ``a``, ``k`` and ``c`` and evaluates
            ``a + k * p**c``; ``'log'`` reads ``a`` and ``b`` and
            evaluates ``a + b * log(p)``.

    Returns:
        The predicted value, or NaN when ``spec['form']`` is missing or is
        neither ``'power'`` nor ``'log'``.
    """
    clamped = max(p, 1e-12)
    kind = spec.get('form')

    if kind == 'power':
        offset = spec['a']
        scale = spec['k']
        exponent = spec['c']
        return offset + scale * (clamped ** exponent)

    if kind == 'log':
        intercept = spec['a']
        slope = spec['b']
        return intercept + slope * math.log(clamped)

    # Unknown or missing form: signal "no prediction" with NaN instead of
    # raising.
    return float('nan')
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Predict per-domain losses from per-domain proportions.

    The functional form is chosen per domain (power or log) and is shared
    across groups; only the coefficients differ between groups.

    Args:
        input_data: Rows to predict for. Each row is a dict that may carry
            the keys ``proportion_domain_1`` .. ``proportion_domain_5``;
            a missing key is treated as a proportion of 0.0.
        group: Group name used to select coefficients from ``PARAMS``.
            A name that is not in ``PARAMS`` uses ``FALLBACK`` instead.

    Returns:
        One dict per input row, in the same order, with the keys
        ``loss_domain_1`` .. ``loss_domain_5``.
    """
    # Resolve the coefficient table once; it is the same for every row.
    group_params = PARAMS.get(group, FALLBACK)
    return [
        {
            f'loss_domain_{idx}': _predict_domain(
                float(record.get(f'proportion_domain_{idx}', 0.0)),
                group_params[str(idx)],
            )
            for idx in DOMAINS
        }
        for record in input_data
    ]