from __future__ import annotations
import json
import math
import os
import re
from dataclasses import dataclass
from typing import Dict, Iterable, List, Tuple
import numpy as np
# Import the optional `datasets` dependency defensively; fall back to None if it is unavailable.
try:
from datasets import load_from_disk, Dataset, DatasetDict # type: ignore
except Exception: # pragma: no cover - optional dependency
load_from_disk = None # type: ignore
Dataset = None # type: ignore
DatasetDict = None # type: ignore
@dataclass
class FittedModel:
# theta: shape (n_targets, n_features), where n_features = 1 (intercept) + n_proportions
theta: np.ndarray
proportion_keys: List[str] # ordered list of feature keys used during fitting
target_keys: List[str] # ordered list of target keys used during fitting
# Global store of fitted coefficients per group.
_FITTED_BY_GROUP: Dict[str, FittedModel] = {}
# Numerical stabilizers
_EPS_P: float = 1e-12
_EPS_L: float = 1e-12
_RIDGE: float = 1e-3
def _safe_log(x: np.ndarray, eps: float) -> np.ndarray:
return np.log(np.clip(x, eps, None))
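# _safe_log clamps its argument from below before taking the log, e.g. (illustrative values):
#   _safe_log(np.array([0.0, 0.5]), 1e-12) -> [log(1e-12), log(0.5)] ≈ [-27.63, -0.69]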
def _detect_keys(columns: Iterable[str]) -> Tuple[List[str], List[str], str | None]:
cols = list(columns)
# Proportion keys like "proportion_domain_1" ... "proportion_domain_5"
prop_keys = sorted(
[c for c in cols if re.fullmatch(r"proportion_domain_\d+", c)],
key=lambda k: int(k.rsplit("_", 1)[1]),
)
# Target keys like "loss_domain_1" ... "loss_domain_5"
tgt_keys = sorted(
[c for c in cols if re.fullmatch(r"loss_domain_\d+", c)],
key=lambda k: int(k.rsplit("_", 1)[1]),
)
# Group column
group_col: str | None = None
if "group" in cols:
group_col = "group"
else:
# Fallback to any column ending with '_group' or named 'Group'
for cand in cols:
            if cand.lower().endswith("_group") or cand == "Group":
group_col = cand
break
return prop_keys, tgt_keys, group_col
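# Illustrative example (hypothetical column set): for columns
#   ["group", "proportion_domain_2", "proportion_domain_1", "loss_domain_1", "loss_domain_2"]
# _detect_keys returns
#   (["proportion_domain_1", "proportion_domain_2"], ["loss_domain_1", "loss_domain_2"], "group").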
def _dataset_to_rows(ds_obj) -> List[Dict[str, float]]:
# Convert Dataset or DatasetDict to a list of dict rows
rows: List[Dict[str, float]] = []
if DatasetDict is not None and isinstance(ds_obj, DatasetDict):
for split in ds_obj.values():
rows.extend(_dataset_to_rows(split))
return rows
# ds_obj is a Dataset or something iterable over dicts
try:
# Iterating a datasets.Dataset yields dicts row-wise efficiently
for row in ds_obj: # type: ignore
rows.append(row)
except Exception:
# Fallback: try to_dict
try:
data_dict = ds_obj.to_dict() # type: ignore
n = len(next(iter(data_dict.values())))
for i in range(n):
rows.append({k: v[i] for k, v in data_dict.items()})
except Exception:
pass
return rows
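# Each converted row is expected to be a flat mapping such as (hypothetical values):
#   {"group": "A", "proportion_domain_1": 0.2, ..., "loss_domain_1": 3.1, ...}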
def _fit_group(rows: List[Dict[str, float]], proportion_keys: List[str], target_keys: List[str]) -> FittedModel:
# Build X (design) and Y (targets)
X_list: List[List[float]] = []
Y_lists: List[List[float]] = [[] for _ in target_keys]
for r in rows:
try:
p_vec = np.array([float(r[k]) for k in proportion_keys], dtype=float)
if np.any(~np.isfinite(p_vec)):
continue
# Require all targets present and finite
y_vals = []
valid = True
for tk in target_keys:
val = float(r[tk])
if not (np.isfinite(val) and val > 0):
valid = False
break
y_vals.append(val)
if not valid:
continue
except Exception:
continue
x = [1.0]
x.extend(_safe_log(p_vec, _EPS_P).tolist())
X_list.append(x)
for i, y in enumerate(y_vals):
Y_lists[i].append(float(y))
if not X_list or any(len(yc) == 0 for yc in Y_lists):
        # Fallback: all-zero coefficients keep the model usable (every predicted loss is exp(0) = 1.0).
n_features = 1 + len(proportion_keys)
theta = np.zeros((len(target_keys), n_features), dtype=float)
# Intercepts default to log(1.0) = 0
return FittedModel(theta=theta, proportion_keys=proportion_keys, target_keys=target_keys)
X = np.asarray(X_list, dtype=float) # shape (n_samples, n_features)
n_features = X.shape[1]
XtX = X.T @ X
reg = np.eye(n_features, dtype=float)
reg[0, 0] = 0.0 # do not regularize intercept
XtX_reg = XtX + _RIDGE * reg
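    # Closed-form ridge solution per target: beta = (X^T X + lambda * R)^{-1} X^T y, where R is
    # the identity with a zero in the intercept position so the intercept is not shrunk.
    # Each target column is solved independently in the loop below.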
theta_rows: List[np.ndarray] = []
Xt = X.T
for y_vals in Y_lists:
y = _safe_log(np.asarray(y_vals, dtype=float), _EPS_L)
Xty = Xt @ y
try:
beta = np.linalg.solve(XtX_reg, Xty)
except np.linalg.LinAlgError:
beta = np.linalg.lstsq(XtX_reg, Xty, rcond=None)[0]
theta_rows.append(beta)
theta = np.vstack(theta_rows) # (n_targets, n_features)
return FittedModel(theta=theta, proportion_keys=proportion_keys, target_keys=target_keys)
def _fit_all_groups() -> None:
global _FITTED_BY_GROUP
if _FITTED_BY_GROUP:
return # already fit
# Attempt to load the dataset from disk
rows: List[Dict[str, float]] = []
if load_from_disk is not None:
try:
ds = load_from_disk("/app/data")
rows = _dataset_to_rows(ds)
except Exception:
rows = []
    # If no rows could be loaded, leave the list empty and fall back to a default model below.
all_columns = set()
for r in rows:
all_columns.update(r.keys())
prop_keys, tgt_keys, group_col = _detect_keys(all_columns)
    # Ensure we have the expected domain columns; if detection found none, fall back to the defaults.
if not prop_keys:
# Default to proportion_domain_1..5 if not present
prop_keys = [f"proportion_domain_{i}" for i in range(1, 6)]
if not tgt_keys:
tgt_keys = [f"loss_domain_{i}" for i in range(1, 6)]
if not rows:
# Build a default "ALL" model
_FITTED_BY_GROUP["ALL"] = _fit_group([], prop_keys, tgt_keys)
_write_explain_md(_FITTED_BY_GROUP)
return
# Partition rows by group (or single ALL group)
grouped: Dict[str, List[Dict[str, float]]] = {}
if group_col is None:
grouped["ALL"] = rows
else:
for r in rows:
g = r.get(group_col, "ALL")
grouped.setdefault(str(g), []).append(r)
# Fit per group
for g, gr_rows in grouped.items():
_FITTED_BY_GROUP[g] = _fit_group(gr_rows, prop_keys, tgt_keys)
# Also fit a global ALL group across everything for fallback
if "ALL" not in _FITTED_BY_GROUP:
_FITTED_BY_GROUP["ALL"] = _fit_group(rows, prop_keys, tgt_keys)
# Write explanation file including fitted parameters
_write_explain_md(_FITTED_BY_GROUP)
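# After fitting, _FITTED_BY_GROUP maps each group name (plus a global "ALL" fallback) to its
# FittedModel; law() looks up the requested group and falls back to "ALL" when the group is unknown.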
def _write_explain_md(fitted: Dict[str, FittedModel]) -> None:
# Prepare a deterministic JSON-like dump of parameters
payload = {}
for g, fm in fitted.items():
payload[g] = {
"proportion_keys": fm.proportion_keys,
"target_keys": fm.target_keys,
"theta": np.asarray(fm.theta, dtype=float).round(8).tolist(),
"model": "log-linear power law: log(loss_i) = theta[i,0] + sum_j theta[i,j] * log(p_j + eps); eps = %.0e"
% _EPS_P,
"ridge": _RIDGE,
}
section = [
"<!-- PARAMS START -->",
"Fitted parameter tensors by group (JSON):",
"",
"```json",
json.dumps(payload, indent=2),
"```",
"<!-- PARAMS END -->",
]
section_text = "\n".join(section) + "\n"
path = "/app/explain.md"
# If explain.md exists, replace the PARAM section; else, create a full file.
try:
if os.path.exists(path):
with open(path, "r", encoding="utf-8") as f:
content = f.read()
new_content: str
if "<!-- PARAMS START -->" in content and "<!-- PARAMS END -->" in content:
new_content = re.sub(
r"<!-- PARAMS START -->.*?<!-- PARAMS END -->",
section_text,
content,
flags=re.DOTALL,
)
else:
new_content = content.rstrip() + "\n\n" + section_text
with open(path, "w", encoding="utf-8") as f:
f.write(new_content)
else:
with open(path, "w", encoding="utf-8") as f:
f.write(_default_explain_md_header().rstrip() + "\n\n" + section_text)
except Exception:
# Best-effort; ignore file writing errors.
pass
def _default_explain_md_header() -> str:
return """# Scaling law for domain mixture in language model pre-training
This document describes the discovered scaling law that predicts per-domain validation loss from the domain mixture proportions used during pre-training.
## Functional form (shared across groups)
We model each domain's validation loss as a multiplicative power-law function of the mixture proportions:
- For domain i in {1..5}, with mixture proportions p_j for j in {1..5}:
  loss_i = A_i * Π_j max(p_j, ε)^{β_{i,j}}
Equivalently in log-space (which is what we fit):
- log(loss_i) = θ_{i,0} + Σ_{j=1..5} θ_{i,j} * log(max(p_j, ε))
where A_i = exp(θ_{i,0}) and β_{i,j} = θ_{i,j}.
We clamp each proportion from below at ε = 1e-12 so the logarithm remains finite when a proportion is zero.
We estimate parameters with ridge-regularized least squares on the log-transformed variables, with L2 regularization λ = 1e-3 applied to non-intercept weights.
This form captures:
- Diminishing returns (via negative exponents).
- Cross-domain transfer (exponents β_{i,j} coupling domains).
- A simple multiplicative structure: rescaling all proportions by a common factor multiplies each predicted loss by a constant factor.
## Methodology
- Load the dataset at /app/data using datasets.load_from_disk().
- Identify input features: proportion_domain_1..proportion_domain_5.
- Identify targets: loss_domain_1..loss_domain_5.
- If a 'group' column is present, fit a separate parameter set per group; otherwise, fit a single ALL group.
- Optimize θ for each target independently using ridge regression in log space.
- Use the fitted θ to predict losses for new inputs by exponentiating the linear predictor.
## Fitted parameters per group
The exact fitted coefficients depend on the dataset available at runtime.
The block below is automatically populated by /app/law.py when the module is imported and the model is fit.
"""
# Fit on import so that law() can immediately use the parameters and /app/explain.md is populated.
_fit_all_groups()
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values.
group: The name of the experimental group for which to make predictions.
The functional form of the law must be the same for all groups,
but the constant parameters/coefficients can differ per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s).
"""
# Ensure models are fit (idempotent)
_fit_all_groups()
# Choose fitted coefficients for the requested group; fallback to "ALL"
fm = _FITTED_BY_GROUP.get(group) or _FITTED_BY_GROUP.get("ALL")
if fm is None:
# Last-resort fallback with default keys and zero coefficients
proportion_keys = [f"proportion_domain_{i}" for i in range(1, 6)]
target_keys = [f"loss_domain_{i}" for i in range(1, 6)]
fm = FittedModel(theta=np.zeros((len(target_keys), 1 + len(proportion_keys))), # type: ignore
proportion_keys=proportion_keys,
target_keys=target_keys)
# Build predictions
out: List[Dict[str, float]] = []
prop_keys = fm.proportion_keys
theta = np.asarray(fm.theta, dtype=float)
# If incoming dicts have a different set of proportion keys, try to realign
# to the canonical order based on numeric suffix.
def canonicalize_keys(keys: List[str]) -> List[str]:
return sorted(keys, key=lambda k: int(k.rsplit("_", 1)[1]) if re.fullmatch(r".*_\d+", k) else math.inf)
for row in input_data:
        # If the row has all canonical keys, use them; otherwise align candidate keys by
        # numeric suffix. On a length mismatch fall back to the fitted keys so the feature
        # vector always matches theta's shape (missing keys default to 0.0 below).
        if not all(k in row for k in prop_keys):
            candidate = [k for k in row.keys() if re.fullmatch(r"proportion_domain_\d+", k)]
            if len(candidate) == len(prop_keys):
                prop_keys_runtime = canonicalize_keys(candidate)
            else:
                prop_keys_runtime = prop_keys  # fall back
        else:
            prop_keys_runtime = prop_keys
p_vec = np.array([float(row.get(k, 0.0)) for k in prop_keys_runtime], dtype=float)
x = np.empty(1 + p_vec.size, dtype=float)
x[0] = 1.0
x[1:] = _safe_log(p_vec, _EPS_P)
# Predict each target independently in log-space then exponentiate.
yhat_log = theta @ x # shape (n_targets,)
yhat = np.exp(yhat_log)
# Produce outputs with canonical target keys (loss_domain_1..5)
pred: Dict[str, float] = {}
# Map predictions to fm.target_keys order; also ensure we output exactly loss_domain_1..5
for idx, tk in enumerate(fm.target_keys):
pred[tk] = float(yhat[idx])
        # If any expected loss_domain_i is missing (e.g. the dataset used different target
        # names), fill the canonical keys from the prediction vector.
        for i in range(1, 6):
            key = f"loss_domain_{i}"
            if key not in pred and yhat.size > 0:
                pred[key] = float(yhat[min(i - 1, yhat.size - 1)])
out.append(pred)
return out
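if __name__ == "__main__":
    # Minimal usage sketch with hypothetical inputs: five mixture proportions summing to 1.0.
    # The "ALL" group always exists; unknown group names also fall back to "ALL" inside law().
    demo_input = [{f"proportion_domain_{i}": p for i, p in zip(range(1, 6), [0.4, 0.3, 0.15, 0.1, 0.05])}]
    print(law(demo_input, group="ALL"))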