from __future__ import annotations
import math
from typing import Dict, Iterable, List, Tuple
# Optional, but provides fast array ops if available.
try:
import numpy as np # type: ignore
except Exception: # pragma: no cover
np = None # type: ignore
# Optional dependency. We fail gracefully if not present or data is missing.
try:
from datasets import load_from_disk # type: ignore
except Exception: # pragma: no cover
load_from_disk = None # type: ignore
DATA_PATH = "/app/data"
# Fitted parameters are stored here as:
# PARAMS[group] = {"L_inf": float, "A": float, "alpha": float}
PARAMS: Dict[str, Dict[str, float]] = {}
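# For example, after fitting PARAMS might look like (values purely illustrative):
#   {"GLOBAL": {"L_inf": 1.12, "A": 4.8, "alpha": 0.37},
#    "some_group": {"L_inf": 0.95, "A": 6.1, "alpha": 0.42}}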
def _to_np(x: Iterable[float]):
if np is None:
# Minimal shim with list semantics when numpy is unavailable
return list(float(v) for v in x)
return np.asarray(list(float(v) for v in x), dtype=float)
def _linear_regression(x: Iterable[float], y: Iterable[float]) -> Tuple[float, float]:
"""
Fit y = b + m * x by least squares.
Returns (b, m).
"""
X = _to_np(x)
Y = _to_np(y)
if np is None:
n = len(X)
if n == 0:
return 0.0, 0.0
mean_x = sum(X) / n
mean_y = sum(Y) / n
num = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(X, Y))
den = sum((xi - mean_x) ** 2 for xi in X)
m = 0.0 if den == 0 else (num / den)
b = mean_y - m * mean_x
return b, m
else:
x_mean = float(np.mean(X)) if X.size else 0.0
y_mean = float(np.mean(Y)) if Y.size else 0.0
num = float(np.sum((X - x_mean) * (Y - y_mean)))
den = float(np.sum((X - x_mean) ** 2))
m = 0.0 if den == 0.0 else (num / den)
b = y_mean - m * x_mean
return b, m
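# Usage sketch (not executed at import): for exactly collinear points the
# closed-form solution is recovered exactly, e.g.
#   _linear_regression([0.0, 1.0, 2.0], [1.0, 3.0, 5.0])  # -> (1.0, 2.0), i.e. y = 1 + 2x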
def _sse_loss(N: Iterable[float], L: Iterable[float], L_inf: float, A: float, alpha: float) -> float:
    # Mean squared error in the original (non-log) space; it is only used to
    # rank candidate fits, so the 1/n factor does not change the selection.
if np is None:
preds = [L_inf + A * (max(n, 1e-12) ** (-alpha)) for n in N]
residuals = [(lp - lt) for lp, lt in zip(preds, L)]
return sum(r * r for r in residuals) / (len(residuals) or 1)
else:
N = _to_np(N)
L = _to_np(L)
preds = L_inf + A * np.power(np.maximum(N, 1e-12), -alpha)
residuals = preds - L
return float(np.mean(residuals ** 2))
def _fit_group(N_raw: Iterable[float], L_raw: Iterable[float]) -> Tuple[float, float, float]:
"""
Fit the scaling law:
L(N) = L_inf + A * N^{-alpha}
using a grid-search over L_inf and closed-form linear regression for (log A, alpha).
"""
# Clean data: positive N, finite values
N = []
L = []
for n, l in zip(N_raw, L_raw):
if n is None or l is None:
continue
try:
n_val = float(n)
l_val = float(l)
except Exception:
continue
if not (math.isfinite(n_val) and math.isfinite(l_val)):
continue
if n_val <= 0.0:
continue
N.append(n_val)
L.append(l_val)
if len(N) < 2:
# Fallback: insufficient data
L_inf = min(L) if L else 0.0
A = max((max(L) - L_inf), 1e-6) if L else 1.0
alpha = 0.5
return float(L_inf), float(A), float(alpha)
    # Establish a reasonable L_inf search range. L_inf must stay strictly below
    # min(L) so that (L - L_inf) is positive for every observation.
    L_min = min(L)
    L_max = max(L)
    if L_min > 0:
        low = max(0.0, 0.1 * L_min)
        high = 0.99 * L_min
    else:
        # Non-positive losses: search strictly below L_min, with a margin tied
        # to the observed spread of the data.
        spread = max(L_max - L_min, 1e-6)
        low = L_min - spread
        high = L_min - 1e-6
    # If the range is still degenerate, expand conservatively below L_min
    if not math.isfinite(low) or not math.isfinite(high) or low >= high:
        high = L_min - 1e-6
        low = high - max(abs(L_min), 1.0)
# Build candidate grid for L_inf
grid_count = 101
if np is None:
L_grid = [low + (high - low) * i / (grid_count - 1) for i in range(grid_count)]
else:
L_grid = list(np.linspace(low, high, grid_count))
best = {
"sse": float("inf"),
"L_inf": None, # type: ignore
"A": None, # type: ignore
"alpha": None, # type: ignore
}
logN = [math.log(n) for n in N]
for L0 in L_grid:
# Ensure positivity of (L - L0)
y = [l - L0 for l in L]
if any(v <= 0 for v in y):
continue
logy = [math.log(v) for v in y]
b, m = _linear_regression(logN, logy) # log(y) = b + m * logN
# Here, m = -alpha and A = exp(b)
alpha = -m
if not math.isfinite(alpha) or alpha <= 0:
continue
A = math.exp(b)
if not math.isfinite(A) or A <= 0:
continue
sse = _sse_loss(N, L, L0, A, alpha)
if sse < best["sse"]:
best.update({"sse": sse, "L_inf": L0, "A": A, "alpha": alpha})
# If grid search failed (e.g., numerical issues), fallback
if best["L_inf"] is None:
        L_inf = min(L_min, max(0.0, 0.5 * L_min))
A = max(L_max - L_inf, 1e-6)
alpha = 0.5
return float(L_inf), float(A), float(alpha)
# Optional local refinement around the best L_inf
L0 = float(best["L_inf"])
span = max(1e-12, 0.1 * abs(L0) + 1e-6)
    candidates = [L0 + d for d in (-span, -span / 2, 0.0, span / 2, span)]
for Lc in candidates:
y = [l - Lc for l in L]
if any(v <= 0 for v in y):
continue
logy = [math.log(v) for v in y]
b, m = _linear_regression(logN, logy)
alpha = -m
if not math.isfinite(alpha) or alpha <= 0:
continue
A = math.exp(b)
if not math.isfinite(A) or A <= 0:
continue
sse = _sse_loss(N, L, Lc, A, alpha)
if sse < best["sse"]:
best.update({"sse": sse, "L_inf": Lc, "A": A, "alpha": alpha})
return float(best["L_inf"]), float(best["A"]), float(best["alpha"])
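# Rough self-check sketch (illustrative, not executed at import): refitting data
# generated from a known law should approximately recover its parameters, e.g.
#   Ns = [10, 100, 1_000, 10_000]
#   Ls = [2.0 + 5.0 * n ** -0.5 for n in Ns]
#   _fit_group(Ns, Ls)  # expected to be roughly (2.0, 5.0, 0.5)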
def _load_all_rows_from_disk(path: str):
if load_from_disk is None:
return []
try:
ds = load_from_disk(path)
except Exception:
return []
rows = []
try:
# DatasetDict (multiple splits)
values = getattr(ds, "values", None)
if callable(values):
for split in ds.values(): # type: ignore[attr-defined]
for ex in split:
rows.append(ex)
else:
# Single Dataset
for ex in ds:
rows.append(ex)
except Exception:
# As a very last resort, try iterating directly
try:
for ex in ds:
rows.append(ex)
except Exception:
return []
return rows
def _detect_group_key(example: dict) -> str | None:
candidates = (
"group",
"sft_group",
"exp_group",
"setting",
"task",
"model",
)
for k in candidates:
if k in example:
return k
return None
def _detect_size_key(example: dict) -> str | None:
candidates = (
"sft_data_size",
"data_size",
"n",
"N",
"examples",
"num_examples",
"train_examples",
)
for k in candidates:
if k in example:
return k
return None
def _detect_loss_key(example: dict) -> str | None:
candidates = (
"sft_loss",
"loss",
"final_loss",
)
for k in candidates:
if k in example:
return k
return None
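# The detectors above assume rows shaped roughly like (keys and values illustrative):
#   {"group": "exp_a", "sft_data_size": 1000, "sft_loss": 2.31}
# Close synonyms of the group/size/loss keys are also accepted via the candidate lists.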
def _fit_params_from_data() -> Dict[str, Dict[str, float]]:
rows = _load_all_rows_from_disk(DATA_PATH)
if not rows:
# Fallback defaults if no data
return {"GLOBAL": {"L_inf": 0.0, "A": 1.0, "alpha": 0.5}}
# Detect key names
g_key = _detect_group_key(rows[0]) or "group"
n_key = _detect_size_key(rows[0]) or "sft_data_size"
l_key = _detect_loss_key(rows[0]) or "sft_loss"
groups: Dict[str, Tuple[List[float], List[float]]] = {}
allN: List[float] = []
allL: List[float] = []
for ex in rows:
if n_key not in ex or l_key not in ex:
continue
g = str(ex.get(g_key, "GLOBAL"))
try:
n = float(ex[n_key])
l = float(ex[l_key])
except Exception:
continue
if not (math.isfinite(n) and math.isfinite(l)) or n <= 0:
continue
allN.append(n)
allL.append(l)
if g not in groups:
groups[g] = ([], [])
groups[g][0].append(n)
groups[g][1].append(l)
params: Dict[str, Dict[str, float]] = {}
# Global fit (useful fallback)
L_inf_g, A_g, alpha_g = _fit_group(allN, allL)
params["GLOBAL"] = {"L_inf": L_inf_g, "A": A_g, "alpha": alpha_g}
# Per-group fits
for g, (Ns, Ls) in groups.items():
L_inf, A, alpha = _fit_group(Ns, Ls)
params[g] = {"L_inf": L_inf, "A": A, "alpha": alpha}
return params
# Eagerly fit at import time for reproducibility and speed at inference.
try:
PARAMS = _fit_params_from_data()
except Exception:
# Robust fallback
PARAMS = {"GLOBAL": {"L_inf": 0.0, "A": 1.0, "alpha": 0.5}}
def _get_param_set(group: str) -> Dict[str, float]:
# Exact group match
if group in PARAMS:
return PARAMS[group]
# Case-insensitive match
for g in PARAMS:
if g.lower() == group.lower():
return PARAMS[g]
# Fallback to GLOBAL
return PARAMS.get("GLOBAL", {"L_inf": 0.0, "A": 1.0, "alpha": 0.5})
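# E.g. a request for group "Foo" first tries an exact match, then a
# case-insensitive match ("foo"), and otherwise falls back to the GLOBAL fit.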
def _extract_size_from_input(d: dict) -> float:
# Support several synonymous keys for convenience.
for k in ("sft_data_size", "data_size", "n", "N", "examples", "num_examples", "train_examples"):
if k in d:
try:
v = float(d[k])
if math.isfinite(v) and v > 0:
return v
except Exception:
continue
    # If missing/invalid, fall back to a tiny positive value so N ** (-alpha)
    # stays defined; the resulting prediction will simply be very large.
return 1e-12
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values.
group: The name of the experimental group for which to make predictions.
The functional form of the law must be the same for all groups,
but the constant parameters/coefficients can differ per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s).
Scaling law implemented:
sft_loss(N) = L_inf + A * N^(-alpha)
"""
params = _get_param_set(group)
L_inf = float(params.get("L_inf", 0.0))
A = float(params.get("A", 1.0))
alpha = float(params.get("alpha", 0.5))
outputs: List[dict[str, float]] = []
for row in input_data:
N = _extract_size_from_input(row)
# Numerical guard
N = max(float(N), 1e-12)
pred = float(L_inf + A * (N ** (-alpha)))
outputs.append({"sft_loss": pred})
return outputs
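# Minimal smoke test when run directly (a sketch; assumes PARAMS was populated,
# possibly only with the GLOBAL fallback when no data is available).
if __name__ == "__main__":
    sample = [{"sft_data_size": 1_000}, {"sft_data_size": 100_000}]
    print(law(sample, group="GLOBAL"))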