from __future__ import annotations
from typing import Dict, List
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """Predict loss for each input point from a fitted scaling law.

    The functional form is a multiplicative power law with an asymptotic floor::

        loss = A + K * params**a * tokens**b * unique_tokens**c

    The exponents (a, b, c), the scale K, and the asymptote A are constant
    within an experimental group, but may differ across groups. Only the
    'all_data' group has coefficients in this implementation; any other group
    name falls back to them.

    The law is only defined for strictly positive inputs. A row in which any of
    the three inputs is absent, ``None``, NaN, zero, or negative yields the
    asymptote ``A`` instead of an invalid or NaN prediction.

    Args:
        input_data: A list of dictionaries with keys 'params', 'tokens',
            'unique_tokens'.
        group: The experimental group name.

    Returns:
        A list of dictionaries with a single key 'loss', one per input row and
        in the same order as ``input_data``.

    Raises:
        TypeError, ValueError: If a value that is present (and not ``None``)
            cannot be converted with ``float()``.
    """
    # Coefficients fitted on the provided dataset (group 'all_data').
    # Format: group -> (A, K, a, b, c)
    COEFFICIENTS: Dict[str, tuple[float, float, float, float, float]] = {
        # Derived from least-squares in log-space on /app/data
        "all_data": (
            2.554117,
            605300.7661134443,
            -0.22335831735443584,
            -0.15792362123370007,
            -0.2764050828072919,
        ),
    }
    A, K, a, b, c = COEFFICIENTS.get(group, COEFFICIENTS["all_data"])

    def _as_float(row: Dict[str, float], key: str) -> float:
        # An absent key and an explicit None both count as "missing"; map them
        # to 0.0 so the positivity check below sends the row to the fallback.
        raw = row.get(key)
        return 0.0 if raw is None else float(raw)

    out: List[Dict[str, float]] = []
    for row in input_data:
        p = _as_float(row, "params")
        t = _as_float(row, "tokens")
        u = _as_float(row, "unique_tokens")
        # Test for "strictly positive" rather than "<= 0": every comparison
        # with NaN is False, so this form also rejects NaN inputs instead of
        # letting them propagate into the prediction.
        if p > 0 and t > 0 and u > 0:
            pred = A + K * (p ** a) * (t ** b) * (u ** c)
        else:
            # Outside the domain of the law: report the asymptotic floor.
            pred = A
        out.append({"loss": float(pred)})
    return out