"""
Scaling law predictor for MoE architectures.
This module exposes a single function:
law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]
It implements the discovered functional form:
loss_validation ≈ L_inf[group]
+ A[group] * dense_parameter_count^(-alpha[group])
+ B[group] * num_experts^(-beta[group])
Per-group coefficients are loaded from /app/params.json. If that file does not
exist, the module will fit parameters from the dataset at /app/data, persist
them to /app/params.json, and also write a detailed explanation to /app/explain.md.
You can force (re)fitting and regenerating /app/explain.md by running:
python /app/law.py
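Example (the group name and numeric values below are illustrative only):
    law([{"num_experts": 8, "dense_parameter_count": 1.2e8}], group="baseline")
    # -> [{"loss_validation": <predicted float>}]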
"""
from __future__ import annotations
import json
import math
import os
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, List, Tuple
import numpy as np
PARAMS_PATH = "/app/params.json"
DATASET_PATH = "/app/data"
EXPLAIN_PATH = "/app/explain.md"
@dataclass
class Coeffs:
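    """Per-group coefficients of the additive power law:
    loss = L_inf + A * dense_parameter_count^(-alpha) + B * num_experts^(-beta)
    """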
L_inf: float
A: float
alpha: float
B: float
beta: float
def predict(self, num_experts: float, dense_parameter_count: float) -> float:
# Numerical safety: clamp to tiny positive to avoid 0**negative.
e = max(float(num_experts), 1e-12)
p = max(float(dense_parameter_count), 1e-12)
return float(self.L_inf + self.A * p ** (-self.alpha) + self.B * e ** (-self.beta))
def _load_params() -> Dict[str, Coeffs] | None:
if not os.path.exists(PARAMS_PATH):
return None
with open(PARAMS_PATH, "r", encoding="utf-8") as f:
raw = json.load(f)
coeffs: Dict[str, Coeffs] = {}
for g, c in raw.get("coefficients", {}).items():
coeffs[g] = Coeffs(
L_inf=float(c["L_inf"]),
A=float(c["A"]),
alpha=float(c["alpha"]),
B=float(c["B"]),
beta=float(c["beta"]),
)
return coeffs
def _save_params_and_explain(
coeffs: Dict[str, Coeffs],
fit_meta: Dict[str, Dict[str, float]],
columns: Dict[str, str],
) -> None:
# Save params.json
payload = {
"formula": "loss = L_inf + A * dense_parameter_count^(-alpha) + B * num_experts^(-beta)",
"fitted_on": datetime.utcnow().isoformat() + "Z",
"dataset_path": DATASET_PATH,
"columns": columns,
"coefficients": {
g: {
"L_inf": c.L_inf,
"A": c.A,
"alpha": c.alpha,
"B": c.B,
"beta": c.beta,
**({"r2": fit_meta[g]["r2"], "mse": fit_meta[g]["mse"]} if g in fit_meta else {}),
}
for g, c in coeffs.items()
},
}
with open(PARAMS_PATH, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2, sort_keys=True)
# Write explain.md with details and fitted parameters
lines: List[str] = []
lines.append("# Discovered Scaling Law for MoE Validation Loss")
lines.append("")
lines.append("This document describes the fitted scaling law relating Mixture-of-Experts (MoE) architecture choices to the final validation loss.")
lines.append("")
lines.append("## Functional form")
lines.append("")
lines.append("We model the validation loss as an additive, saturating power-law in the dense (non-expert) parameters and the number of experts:")
lines.append("")
lines.append("loss_validation = L_inf + A * dense_parameter_count^(-alpha) + B * num_experts^(-beta)")
lines.append("")
lines.append("- L_inf: irreducible loss floor (as capacity → ∞).")
lines.append("- A, alpha: magnitude and decay exponent for dense parameters.")
lines.append("- B, beta: magnitude and decay exponent for the number of experts.")
lines.append("")
lines.append("Rationale: In MoE models, increasing experts primarily expands conditional capacity, while dense parameters govern shared representation power. Empirically, both exhibit diminishing returns well-captured by power-laws; the additive form separates their contributions around a floor L_inf.")
lines.append("")
lines.append("## Fitting methodology")
lines.append("")
lines.append("For each experimental group, we:")
lines.append("- Performed a grid search over exponents alpha and beta on [0.05, 2.0].")
lines.append("- For each (alpha, beta), solved a linear least-squares fit for (L_inf, A, B) in y ≈ c + A * P^(-alpha) + B * E^(-beta).")
lines.append("- Selected the (alpha, beta) that minimized MSE, and reported the corresponding (L_inf, A, B).")
lines.append("")
lines.append("This approach avoids non-convex optimization while capturing the main curvature via exponents.")
lines.append("")
lines.append("## Fitted parameters by group")
lines.append("")
lines.append("| Group | L_inf | A | alpha | B | beta | R^2 | MSE |")
lines.append("|---|---:|---:|---:|---:|---:|---:|---:|")
for g, c in coeffs.items():
r2 = fit_meta.get(g, {}).get("r2", float("nan"))
mse = fit_meta.get(g, {}).get("mse", float("nan"))
lines.append(
f"| {g} | {c.L_inf:.6g} | {c.A:.6g} | {c.alpha:.6g} | {c.B:.6g} | {c.beta:.6g} | {r2:.4f} | {mse:.6g} |"
)
lines.append("")
lines.append("## Column mapping")
lines.append("")
lines.append(f"- num_experts: `{columns.get('num_experts')}`")
lines.append(f"- dense_parameter_count: `{columns.get('dense_parameter_count')}`")
lines.append(f"- loss_validation: `{columns.get('loss_validation')}`")
lines.append(f"- group: `{columns.get('group')}`")
lines.append("")
lines.append("To reproduce or update these values, run:")
lines.append("")
lines.append("```bash")
lines.append("python /app/law.py")
lines.append("```")
with open(EXPLAIN_PATH, "w", encoding="utf-8") as f:
f.write("\n".join(lines))
def _choose_split(ds) -> Tuple[object, str]:
# HuggingFace datasets: handle DatasetDict or Dataset
try:
from datasets import DatasetDict # type: ignore
except Exception:
DatasetDict = None # type: ignore
if DatasetDict is not None and isinstance(ds, DatasetDict):
if "train" in ds:
return ds["train"], "train"
# Fallback to the first available split
first_key = next(iter(ds.keys()))
return ds[first_key], first_key
return ds, "dataset"
def _detect_columns(cols: List[str]) -> Dict[str, str]:
# Try a set of likely column names for each required variable.
def pick(cands: List[str]) -> str | None:
for c in cands:
if c in cols:
return c
return None
experts = pick(["num_experts", "experts", "n_experts", "moe_num_experts"])
dense = pick(
[
"dense_parameter_count",
"dense_params",
"dense_parameters",
"non_expert_parameters",
"dense_parameter_count_total",
]
)
loss = pick(["loss_validation", "val_loss", "validation_loss", "valid_loss"])
group = pick(["group", "group_name", "group_id", "dataset_group", "task"])
missing = [name for name, v in [("num_experts", experts), ("dense_parameter_count", dense), ("loss_validation", loss)] if v is None]
if missing:
raise RuntimeError(f"Required columns not found in dataset: {', '.join(missing)}. Available: {cols}")
return {
"num_experts": experts, # type: ignore[arg-type]
"dense_parameter_count": dense, # type: ignore[arg-type]
"loss_validation": loss, # type: ignore[arg-type]
"group": group or "__ALL__",
}
def _to_numpy_column(dataset, key: str) -> np.ndarray:
    # Convert a single dataset column to a float numpy array.
    return np.asarray(dataset[key], dtype=float)
def _fit_group(E: np.ndarray, P: np.ndarray, y: np.ndarray) -> Tuple[Coeffs, float, float]:
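    """Fit (L_inf, A, alpha, B, beta) for one group.

    The exponents enter nonlinearly, so they are grid-searched; for each
    candidate (alpha, beta) the model is linear in (L_inf, A, B), i.e.
    y ≈ F @ theta with F = [1, P^(-alpha), E^(-beta)], and theta is
    recovered in closed form by least squares.
    """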
# Grid over exponents, solve linear least squares for (c, A, B) for each.
alphas = np.linspace(0.05, 2.0, 40)
betas = np.linspace(0.05, 2.0, 40)
best_loss = math.inf
best_params = None # type: ignore
    # Features are the raw power terms P^(-alpha) and E^(-beta); no log transform is used.
    for alpha in alphas:
        # fP depends only on alpha, so it is computed once per outer iteration.
        fP = np.power(np.clip(P, 1e-12, None), -alpha)
        for beta in betas:
            fE = np.power(np.clip(E, 1e-12, None), -beta)
# Design matrix with intercept
F = np.stack([np.ones_like(fP), fP, fE], axis=1)
# Solve least squares: y ≈ F @ theta, theta = [c, A, B]
theta, residuals, rank, s = np.linalg.lstsq(F, y, rcond=None)
yhat = F @ theta
mse = float(np.mean((y - yhat) ** 2))
if mse < best_loss:
best_loss = mse
best_params = (theta[0], theta[1], float(alpha), theta[2], float(beta))
assert best_params is not None
c, A, alpha, B, beta = [float(v) for v in best_params]
# Compute R^2
yhat = (c + A * np.power(np.clip(P, 1e-12, None), -alpha) + B * np.power(np.clip(E, 1e-12, None), -beta))
ss_res = float(np.sum((y - yhat) ** 2))
ss_tot = float(np.sum((y - np.mean(y)) ** 2))
r2 = 1.0 - (ss_res / ss_tot if ss_tot > 0 else 0.0)
mse = float(np.mean((y - yhat) ** 2))
return Coeffs(L_inf=c, A=A, alpha=alpha, B=B, beta=beta), r2, mse
def _fit_from_dataset() -> Tuple[Dict[str, Coeffs], Dict[str, Dict[str, float]], Dict[str, str]]:
try:
from datasets import load_from_disk # type: ignore
except Exception as e:
raise RuntimeError(
"The 'datasets' package is required to fit parameters. "
"Install it via: pip install datasets"
) from e
ds = load_from_disk(DATASET_PATH)
d, split = _choose_split(ds)
cols = list(d.column_names)
mapping = _detect_columns(cols)
# Extract columns
E_all = _to_numpy_column(d, mapping["num_experts"])
P_all = _to_numpy_column(d, mapping["dense_parameter_count"])
y_all = _to_numpy_column(d, mapping["loss_validation"])
# Groups
group_key = mapping["group"]
if group_key == "__ALL__":
groups = {"__ALL__": np.arange(len(y_all))}
else:
        group_raw = d[group_key]
# Map unique group names to indices
uniq = {}
for i, g in enumerate(group_raw):
uniq.setdefault(str(g), []).append(i)
groups = {k: np.asarray(v, dtype=int) for k, v in uniq.items()}
coeffs: Dict[str, Coeffs] = {}
meta: Dict[str, Dict[str, float]] = {}
for g, idx in groups.items():
E = E_all[idx]
P = P_all[idx]
y = y_all[idx]
c, r2, mse = _fit_group(E, P, y)
coeffs[g] = c
meta[g] = {"r2": float(r2), "mse": float(mse)}
return coeffs, meta, mapping
def _fit_and_save() -> Dict[str, Coeffs]:
coeffs, meta, mapping = _fit_from_dataset()
_save_params_and_explain(coeffs, meta, mapping)
return coeffs
def _fallback_coeffs(coeffs: Dict[str, Coeffs]) -> Coeffs:
# Median of parameters across groups as a robust fallback.
arr = np.array([[c.L_inf, c.A, c.alpha, c.B, c.beta] for c in coeffs.values()], dtype=float)
if arr.size == 0:
# Reasonable generic defaults (will be poor but defined)
return Coeffs(L_inf=0.5, A=1.0, alpha=0.5, B=1.0, beta=0.5)
med = np.median(arr, axis=0)
return Coeffs(L_inf=float(med[0]), A=float(med[1]), alpha=float(med[2]), B=float(med[3]), beta=float(med[4]))
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing:
- 'num_experts': float
- 'dense_parameter_count': float
group: The name of the experimental group for which to make predictions.
The functional form is the same for all groups; coefficients differ.
Returns:
A list of dictionaries with a single key:
- 'loss_validation': predicted validation loss (float)
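    Example (hypothetical group name and values; groups missing from
    params.json fall back to median coefficients across fitted groups):
        law(
            [
                {"num_experts": 16, "dense_parameter_count": 3.5e8},
                {"num_experts": 64, "dense_parameter_count": 3.5e8},
            ],
            group="expert_sweep",
        )
        # -> [{"loss_validation": ...}, {"loss_validation": ...}]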
"""
coeffs = _load_params()
if coeffs is None:
# Fit from dataset and persist params + explanation
coeffs = _fit_and_save()
# Choose coefficients for the requested group or a robust fallback
c = coeffs.get(group)
if c is None:
c = _fallback_coeffs(coeffs)
out: List[Dict[str, float]] = []
for row in input_data:
ne = float(row.get("num_experts", 0.0))
dp = float(row.get("dense_parameter_count", 0.0))
out.append({"loss_validation": c.predict(ne, dp)})
return out
if __name__ == "__main__":
# (Re)fit parameters from /app/data and regenerate /app/explain.md
coeffs = _fit_and_save()
print(f"Fitted coefficients for {len(coeffs)} group(s) and wrote:")
print(f" - {PARAMS_PATH}")
print(f" - {EXPLAIN_PATH}")