from __future__ import annotations
import json
import math
import os
from typing import Dict, List, Tuple
import numpy as np
try:
    from datasets import load_from_disk, DatasetDict  # type: ignore
except Exception:  # datasets may not be available in some contexts
    load_from_disk = None
    DatasetDict = None
# Paths (as specified by the task)
_DATA_PATH = "/app/data"
_COEFFS_PATH = "/app/coefficients.json"
# Small epsilon to avoid divide-by-zero and log domain issues
_EPS = 1e-12
# Global cache of fitted coefficients per group
# Each value is a dict with keys: L_inf, C, alpha, beta, gamma
_COEFFS: Dict[str, Dict[str, float]] = {}
def _safe_log(x: np.ndarray) -> np.ndarray:
"""Numerically safe natural log."""
return np.log(np.clip(x, _EPS, None))
def _get_groups_from_dataset(ds_obj) -> Dict[str, List[Tuple[float, float, float, float]]]:
"""
Extracts and groups data from a HuggingFace dataset object.
Returns:
Mapping: group_name -> list of tuples (vocab_size, non_vocab_parameters, num_characters, unigram_normalized_loss)
"""
groups: Dict[str, List[Tuple[float, float, float, float]]] = {}
def _add_example(ex: dict):
try:
V = float(ex["vocab_size"])
Pnv = float(ex["non_vocab_parameters"])
Nch = float(ex["num_characters"])
L = float(ex["unigram_normalized_loss"])
except Exception:
return # skip rows with missing/invalid fields
# Group name (default to "ALL" if not provided)
g = ex.get("group", "ALL")
if not isinstance(g, str):
g = str(g)
groups.setdefault(g, []).append((V, Pnv, Nch, L))
    # Handle both Dataset and DatasetDict using the module-level import; a
    # DatasetDict maps split names to one Dataset per split.
    if DatasetDict is not None and isinstance(ds_obj, DatasetDict):
for split in ds_obj.values():
for ex in split:
_add_example(ex)
else:
# Treat as a single split dataset or a generic iterable of dicts
for ex in ds_obj:
_add_example(ex)
return groups
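# Illustrative example (hypothetical field values): a row such as
#   {"group": "fineweb", "vocab_size": 32768, "non_vocab_parameters": 8.5e7,
#    "num_characters": 2.0e10, "unigram_normalized_loss": 0.61}
# would be appended to groups["fineweb"] as (32768.0, 8.5e7, 2.0e10, 0.61).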
def _fit_group(records: List[Tuple[float, float, float, float]]) -> Dict[str, float]:
"""
Fit parameters for one group using a multiplicative power-law with a loss floor:
L_hat = L_inf + C * V^{-alpha} * Pnv^{-beta} * Nch^{-gamma}
where V = vocab_size, Pnv = non_vocab_parameters, Nch = num_characters.
    We estimate L_inf via a 1D grid search; for each candidate we perform a
    linear regression on:
        log(L - L_inf) = log C - alpha log V - beta log Pnv - gamma log Nch
    Candidates are compared by SSE in the original (non-log) space over all
    points, so candidates that drop points near the floor stay comparable.
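    For example (illustrative numbers only): with candidate L_inf = 0.5 and an
    observed row (V, Pnv, Nch, L) = (1e4, 1e8, 1e10, 0.9), the regression
    target is log(0.9 - 0.5) ~ -0.916 against regressors log(1e4) ~ 9.21,
    log(1e8) ~ 18.42, and log(1e10) ~ 23.03.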
"""
arr = np.array(records, dtype=float)
if arr.ndim != 2 or arr.shape[1] != 4:
# Fallback defaults if data malformed
return {"L_inf": 0.0, "C": 1.0, "alpha": 0.2, "beta": 0.2, "gamma": 0.2}
V = np.clip(arr[:, 0], _EPS, None)
P = np.clip(arr[:, 1], _EPS, None)
N = np.clip(arr[:, 2], _EPS, None)
L = np.clip(arr[:, 3], _EPS, None)
# Filter to rows with all finite values
mask = np.isfinite(V) & np.isfinite(P) & np.isfinite(N) & np.isfinite(L)
V, P, N, L = V[mask], P[mask], N[mask], L[mask]
    if V.size < 5:
        # Not enough data; use reasonable defaults
        l_floor = float(max(0.0, np.min(L) * 0.5)) if L.size else 0.0
        return {"L_inf": l_floor, "C": 1.0, "alpha": 0.2, "beta": 0.2, "gamma": 0.2}
min_L = float(np.min(L))
# Candidate grid for L_inf between 0 and just below min(L)
upper = max(0.0, min_L * 0.99)
if upper <= 0:
grid = np.array([0.0], dtype=float)
else:
        # Uniform coverage of [0, upper] plus log-spaced points that are dense
        # near zero, to stabilize the search
grid = np.unique(np.concatenate([
np.linspace(0.0, upper, num=50, dtype=float),
np.geomspace(max(_EPS, upper / 1e6), upper, num=50, dtype=float)
]))
grid = grid[(grid >= 0.0) & (grid < min_L)]
best = {
"sse": math.inf,
"L_inf": 0.0,
"C": 1.0,
"alpha": 0.2,
"beta": 0.2,
"gamma": 0.2,
}
x1 = _safe_log(V)
x2 = _safe_log(P)
x3 = _safe_log(N)
    # The design matrix (with intercept) is rebuilt for each grid candidate
    for L_inf_cand in grid:
        # Exclude points where L - L_inf <= 0 (the log target is undefined there)
        valid = L > (L_inf_cand + _EPS)
        if np.count_nonzero(valid) < 4:
            continue
        y = _safe_log(L[valid] - L_inf_cand)
        X = np.column_stack([
            np.ones_like(y),
            x1[valid],
            x2[valid],
            x3[valid],
        ])
        # Linear least squares fit in log space
        try:
            coeffs, *_ = np.linalg.lstsq(X, y, rcond=None)
        except Exception:
            continue
        # Map the linear solution back to model parameters
        logC, a1, a2, a3 = coeffs.tolist()
        C = float(np.exp(logC))
        alpha, beta, gamma = float(-a1), float(-a2), float(-a3)
        # Sanity constraints to avoid pathological exponents
        if not all(map(np.isfinite, (C, alpha, beta, gamma))) or C <= 0:
            continue
        # Clip exponents to a reasonable range
        alpha = float(np.clip(alpha, -4.0, 4.0))
        beta = float(np.clip(beta, -4.0, 4.0))
        gamma = float(np.clip(gamma, -4.0, 4.0))
        # Score candidates in the original (non-log) space over ALL points, so
        # candidates that exclude points near the floor are not unfairly favored
        # by a log-space SSE computed on fewer residuals
        L_hat = L_inf_cand + C * V ** (-alpha) * P ** (-beta) * N ** (-gamma)
        sse = float(np.sum((L - L_hat) ** 2))
        if np.isfinite(sse) and sse < best["sse"]:
            best.update({
                "sse": sse,
                "L_inf": float(L_inf_cand),
                "C": C,
                "alpha": alpha,
                "beta": beta,
                "gamma": gamma,
            })
    # If the grid search never produced a usable fit (e.g., degenerate data),
    # fall back to L_inf = 0
    if not math.isfinite(best["sse"]):
        L_inf_cand = 0.0
        valid = L > (L_inf_cand + _EPS)
        fitted = False
        if np.count_nonzero(valid) >= 4:
            y = _safe_log(L[valid] - L_inf_cand)
            X = np.column_stack([np.ones_like(y), x1[valid], x2[valid], x3[valid]])
            try:
                coeffs, *_ = np.linalg.lstsq(X, y, rcond=None)
                logC, a1, a2, a3 = coeffs.tolist()
                best.update({
                    "sse": 0.0,
                    "L_inf": 0.0,
                    "C": float(np.exp(logC)),
                    "alpha": float(-a1),
                    "beta": float(-a2),
                    "gamma": float(-a3),
                })
                fitted = True
            except Exception:
                pass  # fall through to the defaults below
        if not fitted:
            # Last resort defaults
            best.update({
                "sse": 0.0,
                "L_inf": float(max(0.0, min_L * 0.5)),
                "C": 1.0,
                "alpha": 0.2,
                "beta": 0.2,
                "gamma": 0.2,
            })
# Drop SSE from output
return {k: float(v) for k, v in best.items() if k != "sse"}
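# A minimal self-check sketch (hypothetical helper, not called anywhere in this
# module): on noiseless synthetic data drawn from the assumed functional form,
# _fit_group should roughly recover the generating exponents.
def _demo_fit_group() -> Dict[str, float]:
    rng = np.random.default_rng(0)
    V = rng.uniform(1e3, 1e5, size=64)     # vocab_size
    Pnv = rng.uniform(1e7, 1e9, size=64)   # non_vocab_parameters
    Nch = rng.uniform(1e9, 1e11, size=64)  # num_characters
    L = 0.5 + 3.0 * V ** -0.3 * Pnv ** -0.1 * Nch ** -0.05
    return _fit_group(list(zip(V, Pnv, Nch, L)))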
def _fit_all_groups() -> Dict[str, Dict[str, float]]:
"""
Load the dataset from disk and fit coefficients per experimental group.
Also fits an 'ALL' aggregate group as a fallback.
"""
coeffs: Dict[str, Dict[str, float]] = {}
if load_from_disk is None:
return coeffs
if not os.path.isdir(_DATA_PATH):
return coeffs
try:
ds_obj = load_from_disk(_DATA_PATH)
except Exception:
return coeffs
groups = _get_groups_from_dataset(ds_obj)
# Fit per group
for g, recs in groups.items():
if recs:
coeffs[g] = _fit_group(recs)
# Also fit ALL (aggregate) if not already present
if "ALL" not in coeffs:
all_recs: List[Tuple[float, float, float, float]] = []
for recs in groups.values():
all_recs.extend(recs)
if all_recs:
coeffs["ALL"] = _fit_group(all_recs)
# Persist for transparency and reproducibility
try:
with open(_COEFFS_PATH, "w", encoding="utf-8") as f:
json.dump(coeffs, f, indent=2, sort_keys=True)
except Exception:
pass
return coeffs
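# Shape of the persisted coefficients JSON (illustrative values only):
# {
#   "ALL":     {"C": 2.1, "L_inf": 0.42, "alpha": 0.15, "beta": 0.21, "gamma": 0.08},
#   "group_a": {"C": 1.9, "L_inf": 0.40, "alpha": 0.17, "beta": 0.19, "gamma": 0.07}
# }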
def _load_or_fit_coeffs() -> Dict[str, Dict[str, float]]:
"""
Load coefficients from JSON if available; otherwise fit from the dataset.
"""
# Try to load precomputed coefficients
if os.path.isfile(_COEFFS_PATH):
try:
with open(_COEFFS_PATH, "r", encoding="utf-8") as f:
data = json.load(f)
# Ensure floats
out: Dict[str, Dict[str, float]] = {}
for g, d in data.items():
out[g] = {
"L_inf": float(d["L_inf"]),
"C": float(d["C"]),
"alpha": float(d["alpha"]),
"beta": float(d["beta"]),
"gamma": float(d["gamma"]),
}
return out
except Exception:
pass
# Otherwise fit now
return _fit_all_groups()
# Initialize coefficients at import time for immediate availability
_COEFFS = _load_or_fit_coeffs()
def _predict_one(row: Dict[str, float], coefs: Dict[str, float]) -> float:
"""Compute prediction for one input row given fitted coefficients."""
V = float(row.get("vocab_size", 0.0))
Pnv = float(row.get("non_vocab_parameters", 0.0))
Nch = float(row.get("num_characters", 0.0))
# Safety clamps
V = V if np.isfinite(V) and V > 0 else _EPS
Pnv = Pnv if np.isfinite(Pnv) and Pnv > 0 else _EPS
Nch = Nch if np.isfinite(Nch) and Nch > 0 else _EPS
L_inf = float(coefs.get("L_inf", 0.0))
C = float(coefs.get("C", 1.0))
alpha = float(coefs.get("alpha", 0.2))
beta = float(coefs.get("beta", 0.2))
gamma = float(coefs.get("gamma", 0.2))
# L_hat = L_inf + C * V^{-alpha} * Pnv^{-beta} * Nch^{-gamma}
try:
term = C * (V ** (-alpha)) * (Pnv ** (-beta)) * (Nch ** (-gamma))
pred = L_inf + term
except Exception:
pred = L_inf + C # worst-case fallback
# Ensure non-negative prediction
return float(max(0.0, pred))
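# Worked example (illustrative coefficients, not fitted values): with L_inf=0.4,
# C=2.0, alpha=0.2, beta=0.1, gamma=0.05 and inputs V=32768, Pnv=1e8, Nch=1e10,
# the power-law term is 2.0 * 32768**-0.2 * 1e8**-0.1 * 1e10**-0.05
# ~ 2.0 * 0.125 * 0.158 * 0.316 ~ 0.0125, giving a prediction of about 0.41.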
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values. Required keys:
- 'vocab_size'
- 'non_vocab_parameters'
- 'num_characters'
group: The name of the experimental group for which to make predictions.
The functional form is identical across groups; parameters differ.
Returns:
A list of dictionaries, one per input item, each containing:
{'unigram_normalized_loss': predicted_value}
"""
# Choose coefficients for the requested group, with fallbacks
coefs = _COEFFS.get(group)
if coefs is None:
coefs = _COEFFS.get("ALL")
if coefs is None:
# Final hardcoded fallback if fitting/loading failed
coefs = {"L_inf": 0.0, "C": 1.0, "alpha": 0.2, "beta": 0.2, "gamma": 0.2}
outputs: List[Dict[str, float]] = []
for row in input_data:
y = _predict_one(row, coefs)
outputs.append({"unigram_normalized_loss": y})
return outputs
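# Minimal usage sketch, assuming coefficients were fitted or loaded at import
# time; "ALL" is the aggregate fallback group fitted above.
if __name__ == "__main__":
    demo_inputs = [
        {"vocab_size": 32768, "non_vocab_parameters": 8.5e7, "num_characters": 2.0e10},
        {"vocab_size": 65536, "non_vocab_parameters": 8.5e7, "num_characters": 2.0e10},
    ]
    print(law(demo_inputs, group="ALL"))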