from __future__ import annotations
from typing import Dict, Iterable, List, Tuple
import math

# We attempt to learn group-specific coefficients from /app/data at import time.
# The functional form is shared across groups:
# loss ≈ L0_g + s_g * (num_params ** a_g) * (parallel_size ** b_g)
#
# Where:
# - L0_g is an irreducible loss floor for group g
# - s_g is a scale factor
# - a_g < 0 captures improvement with model size
# - b_g < 0 captures improvement with the degree of parallelism (ensemble/aggregation)
#
# Coefficients are estimated by choosing L0 via a small grid search and
# fitting log(loss - L0) = log(s) + a*log(num_params) + b*log(parallel_size)
# with ordinary least squares. If the dataset is unavailable, we fall back to
# conservative defaults.
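#
# Worked example with illustrative (not fitted) coefficients: for
# L0 = 1.5, s = 400, a = -0.3, b = -0.1, a model with num_params = 1e8
# and parallel_size = 4 predicts
#   loss ≈ 1.5 + 400 * (1e8 ** -0.3) * (4 ** -0.1)
#        ≈ 1.5 + 400 * 0.00398 * 0.871
#        ≈ 2.89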
# Global, group-keyed coefficients: group -> (L0, s, a, b)
_COEFFS: Dict[str, Tuple[float, float, float, float]] = {}
# Fallback/global coefficients across all groups
_GLOBAL_COEFFS: Tuple[float, float, float, float] | None = None
_EPS = 1e-12
_DATA_PATH = "/app/data"


def _safe_log(x: Iterable[float]) -> List[float]:
    return [math.log(max(v, _EPS)) for v in x]


def _lstsq(X: List[List[float]], y: List[float]) -> Tuple[List[float], float]:
"""
Minimal OLS using normal equations with 3 features (intercept, x1, x2).
Returns (beta, sse) where beta = [b0, b1, b2].
"""
# Build normal equations: (X^T X) beta = X^T y
# X: n x 3
n = len(y)
if n == 0:
return [0.0, 0.0, 0.0], float("inf")
s00 = s01 = s02 = s11 = s12 = s22 = 0.0
t0 = t1 = t2 = 0.0
for i in range(n):
xi0, xi1, xi2 = X[i]
yi = y[i]
s00 += xi0 * xi0
s01 += xi0 * xi1
s02 += xi0 * xi2
s11 += xi1 * xi1
s12 += xi1 * xi2
s22 += xi2 * xi2
t0 += xi0 * yi
t1 += xi1 * yi
t2 += xi2 * yi
# Solve 3x3 system via Cramer's rule for robustness without numpy
# Matrix:
# [s00 s01 s02] [b0] = [t0]
# [s01 s11 s12] [b1] [t1]
# [s02 s12 s22] [b2] [t2]
def det3(a00, a01, a02, a10, a11, a12, a20, a21, a22) -> float:
return (
a00 * (a11 * a22 - a12 * a21)
- a01 * (a10 * a22 - a12 * a20)
+ a02 * (a10 * a21 - a11 * a20)
)
D = det3(s00, s01, s02, s01, s11, s12, s02, s12, s22)
if abs(D) < 1e-18:
# Degenerate; return zeros and high SSE
return [0.0, 0.0, 0.0], float("inf")
D0 = det3(t0, s01, s02, t1, s11, s12, t2, s12, s22)
D1 = det3(s00, t0, s02, s01, t1, s12, s02, t2, s22)
D2 = det3(s00, s01, t0, s01, s11, t1, s02, s12, t2)
b0, b1, b2 = D0 / D, D1 / D, D2 / D
    # Report the SSE of the fit in the same (log) space in which it was
    # computed, so fits can be compared consistently as a diagnostic; the
    # caller evaluates SSE in the original linear space, including L0.
    sse = 0.0
    for i in range(n):
        pred_log = b0 * X[i][0] + b1 * X[i][1] + b2 * X[i][2]
        e = y[i] - pred_log
        sse += e * e
    return [b0, b1, b2], sse
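# Quick sanity check for _lstsq (illustrative; not executed at import):
#
#     X = [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
#     y = [2.0, 3.0, 4.0, 5.0]     # exactly y = 2 + 1*x1 + 2*x2
#     beta, sse = _lstsq(X, y)     # beta ≈ [2.0, 1.0, 2.0], sse ≈ 0.0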
def _fit_group(
y: List[float], n_params: List[float], p_size: List[float]
) -> Tuple[float, float, float, float]:
"""
Fit parameters (L0, s, a, b) for one group using grid search over L0 and OLS in log-space.
"""
# Sanity: ensure strictly positive features
n_params = [max(v, _EPS) for v in n_params]
p_size = [max(v, _EPS) for v in p_size]
y = [float(v) for v in y]
y_min = min(y)
y_max = max(y)
if not math.isfinite(y_min) or not math.isfinite(y_max):
        return (0.0, 1.0, -0.25, -0.5)
# Grid L0 below the minimum observed loss
span = max(y_max - y_min, 1e-6)
# 41 candidates from (y_min - 0.5*span) up to (y_min - 1e-6)
grid = [
(y_min - 0.5 * span) + i * (0.5 * span - 1e-6) / 40.0 for i in range(41)
]
best = None # (lin_sse, L0, b0, b1, b2)
x1 = _safe_log(n_params)
x2 = _safe_log(p_size)
for L0 in grid:
# Ensure y - L0 > 0
diff = [max(val - L0, _EPS) for val in y]
# Prepare OLS in log space: log(diff) = b0*1 + b1*log(n) + b2*log(p)
z = [math.log(d) for d in diff]
X = [[1.0, x1[i], x2[i]] for i in range(len(z))]
beta, _ = _lstsq(X, z)
b0, b1, b2 = beta
# Evaluate SSE in original space
sse = 0.0
for i in range(len(y)):
pred = L0 + math.exp(b0 + b1 * x1[i] + b2 * x2[i])
e = y[i] - pred
sse += e * e
if (best is None) or (sse < best[0]):
best = (sse, L0, b0, b1, b2)
if best is None:
        return (0.0, 1.0, -0.25, -0.5)
_, L0, b0, a, b = best
s = math.exp(b0)
return (L0, s, a, b)
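# Illustrative recovery test for _fit_group (not executed at import): on
# noise-free synthetic data, the grid search should land near the true L0
# and the OLS step near the true exponents.
#
#     n = [1e6, 1e7, 1e8, 1e6, 1e7, 1e8]
#     p = [1.0, 1.0, 1.0, 4.0, 4.0, 4.0]
#     y = [2.0 + 5.0 * ni ** -0.1 * pi ** -0.5 for ni, pi in zip(n, p)]
#     L0, s, a, b = _fit_group(y, n, p)   # roughly (2.0, 5.0, -0.1, -0.5)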
def _attempt_learn_coeffs() -> None:
global _COEFFS, _GLOBAL_COEFFS
try:
        from datasets import load_from_disk, DatasetDict, concatenate_datasets  # type: ignore
except Exception:
# Datasets library is unavailable; use defaults
_COEFFS = {}
_GLOBAL_COEFFS = (0.0, 1.0, -0.25, -0.5)
return
try:
ds = load_from_disk(_DATA_PATH)
except Exception:
# Dataset not present; defaults
_COEFFS = {}
_GLOBAL_COEFFS = (0.0, 1.0, -0.25, -0.5)
return
# Flatten to a single dataset
if isinstance(ds, DatasetDict):
        parts = list(ds.values())
try:
flat = concatenate_datasets(parts)
except Exception:
# Fallback: use the first split
flat = parts[0]
else:
flat = ds # type: ignore[assignment]
# Determine group field
cand_group_fields = ["group", "group_name", "dataset", "split"]
group_field = None
for k in cand_group_fields:
if k in flat.column_names:
group_field = k
break
# Required fields
required = ["num_params", "parallel_size", "loss"]
for r in required:
if r not in flat.column_names:
            # Required column missing; fall back to defaults
_COEFFS = {}
_GLOBAL_COEFFS = (0.0, 1.0, -0.25, -0.5)
return
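    # A row is expected to look like this (illustrative values; the group
    # column is whichever candidate field above matched, if any):
    #     {"group": "exp_a", "num_params": 1.2e7, "parallel_size": 8, "loss": 3.41}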
# Collect per-group data
by_group: Dict[str, Dict[str, List[float]]] = {}
for ex in flat:
g = str(ex[group_field]) if group_field is not None else "default"
d = by_group.setdefault(g, {"y": [], "n": [], "p": []})
try:
n = float(ex["num_params"])
p = float(ex["parallel_size"])
y = float(ex["loss"])
except Exception:
# Skip malformed rows
continue
if not (math.isfinite(n) and math.isfinite(p) and math.isfinite(y)):
continue
d["y"].append(y)
d["n"].append(n)
d["p"].append(p)
# Fit global coefficients
all_y: List[float] = []
all_n: List[float] = []
all_p: List[float] = []
for g, d in by_group.items():
all_y.extend(d["y"])
all_n.extend(d["n"])
all_p.extend(d["p"])
if len(all_y) >= 3:
_GLOBAL_COEFFS = _fit_group(all_y, all_n, all_p)
else:
_GLOBAL_COEFFS = (0.0, 1.0, -0.25, -0.5)
# Fit each group
coeffs: Dict[str, Tuple[float, float, float, float]] = {}
for g, d in by_group.items():
if len(d["y"]) >= 3:
coeffs[g] = _fit_group(d["y"], d["n"], d["p"])
else:
coeffs[g] = _GLOBAL_COEFFS # fallback
    _COEFFS = coeffs


_attempt_learn_coeffs()


def _predict_one(
n_params: float, p_size: float, coeffs: Tuple[float, float, float, float]
) -> float:
n_params = max(float(n_params), _EPS)
p_size = max(float(p_size), _EPS)
L0, s, a, b = coeffs
return L0 + s * (n_params ** a) * (p_size ** b)
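# For example, _predict_one(1e8, 4, (1.5, 400.0, -0.3, -0.1)) evaluates to
# 1.5 + 400 * (1e8 ** -0.3) * (4 ** -0.1) ≈ 2.89 (illustrative coefficients).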
def _coeffs_for_group(group: str) -> Tuple[float, float, float, float]:
if group in _COEFFS:
return _COEFFS[group]
if _GLOBAL_COEFFS is not None:
return _GLOBAL_COEFFS
# Ultimate fallback
    return (0.0, 1.0, -0.25, -0.5)


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values.
group: The name of the experimental group for which to make predictions.
The functional form of the law is shared, coefficients vary per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s).
"""
coeffs = _coeffs_for_group(group)
out: List[Dict[str, float]] = []
for row in input_data:
n = float(row.get("num_params", 0.0))
p = float(row.get("parallel_size", 0.0))
pred = _predict_one(n, p, coeffs)
out.append({"loss": float(pred)})
return out
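# Example call (group names depend on the dataset on disk; "default" is used
# when no group column is found):
#
#     law([{"num_params": 1e8, "parallel_size": 4}], group="default")
#     # -> [{"loss": <predicted loss for that configuration>}]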
def _write_explain(path: str = "/app/explain.md") -> None:
"""
Utility to write a human-readable explanation file with fitted parameters.
"""
lines: List[str] = []
lines.append("# Parallel Scaling Law for Language Modeling Loss")
lines.append("")
lines.append("We model the final loss as a shared functional form across groups:")
lines.append("")
lines.append("loss_hat = L0_g + s_g * num_params^{a_g} * parallel_size^{b_g}")
lines.append("")
lines.append("Interpretation:")
lines.append("- L0_g: irreducible loss floor for group g")
lines.append("- a_g < 0: larger models reduce loss via a power law")
lines.append("- b_g < 0: aggregating parallel outputs reduces loss (akin to ensembling)")
lines.append("")
lines.append("Fitting procedure:")
lines.append("- Grid search over L0 below min(loss) for numerical stability.")
lines.append("- For each L0, fit log(loss - L0) = log(s) + a*log(num_params) + b*log(parallel_size)")
lines.append("- Choose the L0 and coefficients minimizing squared error in the original space.")
lines.append("")
if _GLOBAL_COEFFS is not None:
L0, s, a, b = _GLOBAL_COEFFS
lines.append("Global coefficients (all groups pooled):")
lines.append(f"- L0 = {L0:.6g}, s = {s:.6g}, a = {a:.6g}, b = {b:.6g}")
lines.append("")
if _COEFFS:
lines.append("Per-group fitted coefficients:")
for g, (L0, s, a, b) in sorted(_COEFFS.items(), key=lambda kv: str(kv[0])):
lines.append(f"- {g}: L0 = {L0:.6g}, s = {s:.6g}, a = {a:.6g}, b = {b:.6g}")
lines.append("")
lines.append("Notes:")
lines.append("- The same functional form is used for every group; only the constants differ.")
lines.append("- The exponent b often trends near -0.5, consistent with variance reduction from aggregating parallel outputs.")
lines.append("- The model is intentionally simple to support extrapolation.")
content = "\n".join(lines) + "\n"
try:
with open(path, "w", encoding="utf-8") as f:
f.write(content)
except Exception:
# Best-effort; ignore write errors
pass
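# Manual usage sketch (hypothetical; nothing below runs at import):
#
#     _write_explain()   # refresh /app/explain.md with the current coefficients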