from __future__ import annotations
import math
# Global, group-invariant exponents discovered from model selection
# loss ≈ d_g + a_g * P^{-ALPHA} + b_g * T^{-BETA} + c_g * U^{-GAMMA} + e_g * (T/U)^{-DELTA}
ALPHA = 0.5905144353732534
BETA = 0.6010038076798887
GAMMA = 0.04387501406139306
DELTA = 0.9108617183578192
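# Informal reading (a gloss, not part of the fitted model): P = parameters,
# T = training tokens, U = unique tokens, so T/U is roughly how many times
# the unique data is repeated during training.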
# Default coefficients, used as a fallback when the on-disk training data is
# unavailable for a group. They were fit on the provided dataset's single
# group ("all_data").
_DEFAULT_COEFS: dict[str, list[float]] = {
"all_data": [
-1.48205727e+00, # d
1.91025244e+04, # a
1.87355582e+05, # b
1.08182446e+01, # c
3.85869099e-01, # e
]
}
# Cache of per-group coefficients once they have been fitted from on-disk data.
_COEF_CACHE: dict[str, list[float]] = {}
def _fit_group_from_disk(group: str) -> list[float] | None:
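    # `datasets` and `numpy` are optional here; if either import fails, the
    # caller falls back to the canned coefficients in _DEFAULT_COEFS.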
try:
from datasets import load_from_disk # type: ignore
import numpy as np # type: ignore
except Exception:
return None
try:
ds = load_from_disk("/app/data")
if hasattr(ds, "keys"):
d = ds.get("train", next(iter(ds.values())))
else:
d = ds
# Filter to requested group if present
if "group" in d.column_names:
df = d.to_pandas()
            if group in df["group"].values:
gdf = df[df["group"] == group]
else:
# Fallback: use all rows to provide a generic estimate
gdf = df
else:
gdf = d.to_pandas()
P = gdf["params"].to_numpy(dtype=float)
T = gdf["tokens"].to_numpy(dtype=float)
U = gdf["unique_tokens"].to_numpy(dtype=float)
y = gdf["loss"].to_numpy(dtype=float)
u_safe = np.maximum(U, 1.0)
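        # Design matrix: one column per additive term of the law, in the
        # coefficient order [d, a, b, c, e] returned below.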
X = np.stack([
np.ones_like(P),
P ** (-ALPHA),
T ** (-BETA),
            u_safe ** (-GAMMA),
(T / u_safe) ** (-DELTA),
], axis=1)
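        # The exponents are fixed globally, so the model is linear in the
        # coefficients and can be fit by ordinary least squares.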
coef, *_ = np.linalg.lstsq(X, y, rcond=None)
return [float(v) for v in coef]
except Exception:
return None
def _get_coefs(group: str) -> list[float]:
if group in _COEF_CACHE:
return _COEF_CACHE[group]
coef = _fit_group_from_disk(group)
if coef is None:
# Fallbacks
if group in _DEFAULT_COEFS:
coef = _DEFAULT_COEFS[group]
elif "all_data" in _DEFAULT_COEFS:
coef = _DEFAULT_COEFS["all_data"]
else:
            # Last-resort neutral default: predict a constant loss of 4.0.
            coef = [4.0, 0.0, 0.0, 0.0, 0.0]
_COEF_CACHE[group] = coef
return coef
def _safe_pow(x: float, power: float) -> float:
# Avoid infinities for non-positive bases; clamp to small positive epsilon
if not (x > 0.0):
x = 1e-18
try:
return x ** power
except Exception:
# Fallback for extreme cases
return math.exp(power * math.log(max(x, 1e-18)))
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values. Expected keys: 'params', 'tokens', 'unique_tokens'.
group: The name of the experimental group for which to make predictions.
The functional form of the law is the same for all groups, but
the coefficients differ per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s) under key 'loss'.
"""
if not input_data:
return []
d0, a, b, c, e = _get_coefs(group)
    preds: list[dict[str, float]] = []
for item in input_data:
P = float(item.get("params", 0.0))
T = float(item.get("tokens", 0.0))
U = float(item.get("unique_tokens", 1.0))
U = max(U, 1.0)
x1 = _safe_pow(P, -ALPHA)
x2 = _safe_pow(T, -BETA)
x3 = _safe_pow(U, -GAMMA)
        x4 = _safe_pow(T / U, -DELTA)  # U >= 1.0 after the clamp above
yhat = d0 + a * x1 + b * x2 + c * x3 + e * x4
preds.append({"loss": float(yhat)})
return preds
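
# Minimal usage sketch. The (params, tokens, unique_tokens) values below are
# illustrative placeholders rather than points from the dataset; if /app/data
# is not available, the call falls back to the canned "all_data" coefficients.
if __name__ == "__main__":
    demo_points = [
        {"params": 1.0e8, "tokens": 2.0e9, "unique_tokens": 1.0e9},
        {"params": 1.0e9, "tokens": 2.0e10, "unique_tokens": 5.0e9},
    ]
    for point, pred in zip(demo_points, law(demo_points, "all_data")):
        print(point, "->", pred["loss"])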