from __future__ import annotations
import math
from typing import Dict, List
import numpy as np
_DATA_PATH = "/app/data"
_EPS = 1e-12 # numeric floor for logs
_MODELS = None # lazy-fit cache: {"per_group": {group: model}, "global": model}
def _safe_log(x: float) -> float:
return math.log(max(float(x), _EPS))
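# e.g. _safe_log(0.0) == math.log(_EPS); the clamp keeps log-features finite.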
def _detect_column(columns: List[str], preferred: str, alternatives: List[str]) -> str:
    """Return `preferred` if present in `columns`, otherwise the first matching alternative."""
    if preferred in columns:
return preferred
for alt in alternatives:
if alt in columns:
return alt
raise KeyError(
f"Expected column '{preferred}' not found. Available: {columns}. "
f"Tried alternatives: {alternatives}"
)
def _load_dataset_dicts() -> List[dict]:
"""
Loads the dataset from disk and returns it as a list of dicts.
Supports either a Dataset or a DatasetDict (uses 'train' split if present,
otherwise the first available split).
"""
try:
from datasets import load_from_disk
except Exception as e:
raise RuntimeError(
"Failed to import 'datasets'. Ensure the 'datasets' package is installed."
) from e
ds = load_from_disk(_DATA_PATH)
# Normalize to a Dataset instance
try:
# DatasetDict (mapping of splits)
if hasattr(ds, "keys"):
if "train" in ds:
ds = ds["train"]
else:
# Pick the first available split
first_key = next(iter(ds.keys()))
ds = ds[first_key]
except Exception:
pass
# Convert to Python list of records without requiring pandas
# HuggingFace Dataset supports to_dict() returning column-wise dict of lists
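    # Each resulting record looks like, e.g. (hypothetical values):
    #   {"lr": 3e-4, "bsz": 256, "data_size": 1.0e10,
    #    "non_embedding_param_size": 1.2e8, "lm_loss": 2.87, "group": "baseline"}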
cols = ds.column_names
coldict = ds.to_dict()
n = len(next(iter(coldict.values()))) if coldict else 0
records = []
for i in range(n):
rec = {c: coldict[c][i] for c in cols}
records.append(rec)
return records
def _fit_group_models() -> Dict[str, dict]:
"""
Fit a log-linear (power-law) model per group:
ln(L) = b0 + b1 ln(lr) + b2 ln(bsz) + b3 ln(data_size) + b4 ln(non_embedding_param_size)
Returns a dict with:
{
"per_group": { group_name: {"beta": np.array, "rmse": float, "n": int} },
"global": {"beta": np.array, "rmse": float, "n": int}
}
"""
data = _load_dataset_dicts()
if not data:
# No data found; return a sane fallback
beta_fallback = np.array([math.log(3.5), 0.0, 0.0, 0.0, 0.0], dtype=float)
return {
"per_group": {},
"global": {"beta": beta_fallback, "rmse": float("nan"), "n": 0},
}
# Detect columns
columns = list(data[0].keys())
    # Required numeric columns (exact names from the expected schema, with a few safe fallbacks)
col_lr = _detect_column(columns, "lr", ["learning_rate"])
col_bsz = _detect_column(columns, "bsz", ["batch_size"])
col_data = _detect_column(columns, "data_size", ["tokens", "num_tokens", "dataset_size"])
col_params = _detect_column(
columns, "non_embedding_param_size", ["non_embedding_params", "non_embedding_parameters"]
)
# Target
col_loss = _detect_column(columns, "lm_loss", ["loss", "val_loss", "final_loss"])
# Group column (optional)
group_col = None
for cand in ["group", "group_name", "exp_group", "dataset_group", "task_group"]:
if cand in columns:
group_col = cand
break
# Filter invalid rows and build per-group buckets
buckets: Dict[str, List[dict]] = {}
for row in data:
try:
lr = float(row[col_lr])
bsz = float(row[col_bsz])
size = float(row[col_data])
params = float(row[col_params])
loss = float(row[col_loss])
except Exception:
continue
if not (lr > 0 and bsz > 0 and size > 0 and params > 0 and loss > 0):
continue
g = str(row[group_col]) if group_col is not None else "default"
buckets.setdefault(g, []).append(
{ # keep only needed keys
"lr": lr,
"bsz": bsz,
"data_size": size,
"non_embedding_param_size": params,
"lm_loss": loss,
}
)
    # If no rows survived filtering, fall back to the constant model
if not buckets:
beta_fallback = np.array([math.log(3.5), 0.0, 0.0, 0.0, 0.0], dtype=float)
return {
"per_group": {},
"global": {"beta": beta_fallback, "rmse": float("nan"), "n": 0},
}
    def _fit(X: np.ndarray, y: np.ndarray) -> Dict[str, float | int | np.ndarray]:
# Solve least squares
beta, *_ = np.linalg.lstsq(X, y, rcond=None)
pred = X @ beta
rmse = float(math.sqrt(np.mean((pred - y) ** 2))) if y.size else float("nan")
return {"beta": beta, "rmse": rmse, "n": int(y.size)}
models_per_group: Dict[str, dict] = {}
# Pooled/global data
X_all, y_all = [], []
for g, rows in buckets.items():
# Build design matrix: [1, ln(lr), ln(bsz), ln(data_size), ln(non_embedding_param_size)]
Xg, yg = [], []
for r in rows:
Xg.append(
[
1.0,
_safe_log(r["lr"]),
_safe_log(r["bsz"]),
_safe_log(r["data_size"]),
_safe_log(r["non_embedding_param_size"]),
]
)
yg.append(_safe_log(r["lm_loss"]))
Xg = np.asarray(Xg, dtype=float)
yg = np.asarray(yg, dtype=float)
model_g = _fit(Xg, yg)
models_per_group[g] = model_g
X_all.append(Xg)
y_all.append(yg)
X_all = np.vstack(X_all) if X_all else np.zeros((0, 5))
y_all = np.concatenate(y_all) if y_all else np.zeros((0,))
model_global = _fit(X_all, y_all) if y_all.size else {
"beta": np.array([math.log(3.5), 0.0, 0.0, 0.0, 0.0], dtype=float),
"rmse": float("nan"),
"n": 0,
}
return {"per_group": models_per_group, "global": model_global}
def _ensure_models() -> None:
global _MODELS
if _MODELS is None:
try:
_MODELS = _fit_group_models()
except Exception:
# Absolute fallback to avoid hard failure during import/prediction
beta_fallback = np.array([math.log(3.5), 0.0, 0.0, 0.0, 0.0], dtype=float)
_MODELS = {"per_group": {}, "global": {"beta": beta_fallback, "rmse": float("nan"), "n": 0}}
def _predict_one(row: Dict[str, float], group: str) -> float:
"""
Predict lm_loss for a single row using the fitted model for `group`.
"""
_ensure_models()
assert _MODELS is not None
model = _MODELS["per_group"].get(group) or _MODELS["global"]
beta: np.ndarray = model["beta"]
# feature order must match training: [1, ln(lr), ln(bsz), ln(data_size), ln(non_embedding_param_size)]
try:
x = [
1.0,
_safe_log(row["lr"]),
_safe_log(row["bsz"]),
_safe_log(row["data_size"]),
_safe_log(row["non_embedding_param_size"]),
]
    except KeyError as e:
        missing = str(e).strip("'")
        raise KeyError(
            f"Missing required input variable '{missing}'. "
            "Expected keys: lr, bsz, data_size, non_embedding_param_size"
        ) from e
y_log = float(np.dot(beta, np.asarray(x, dtype=float)))
# Ensure strictly positive prediction
return max(math.exp(y_log), _EPS)
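# Usage sketch for _predict_one (hypothetical inputs; the value depends on the fit):
#   _predict_one({"lr": 3e-4, "bsz": 256, "data_size": 1e10,
#                 "non_embedding_param_size": 1.2e8}, "baseline")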
def get_fitted_params() -> Dict[str, dict]:
"""
Returns a dictionary with fitted parameters per group and global model.
{
"per_group": {
group: {
"beta": [b0, b1, b2, b3, b4],
"A": exp(b0),
"rmse": ...,
"n": ...
}, ...
},
"global": { ... }
}
"""
_ensure_models()
assert _MODELS is not None
# Add derived A = exp(b0)
def enrich(model: dict) -> dict:
beta = np.asarray(model["beta"], dtype=float)
return {
"beta": beta.tolist(),
"A": float(math.exp(beta[0])),
"rmse": float(model.get("rmse", float("nan"))),
"n": int(model.get("n", 0)),
}
per_group = {g: enrich(m) for g, m in _MODELS["per_group"].items()}
global_m = enrich(_MODELS["global"])
return {"per_group": per_group, "global": global_m}
def dump_explanation(path: str = "/app/explain.md") -> None:
"""
Writes a detailed explanation and the fitted parameters per group to a Markdown file.
"""
params = get_fitted_params()
lines = []
lines.append("# Scaling Law Explanation")
lines.append("")
lines.append("We model the final language modeling loss as a multiplicative power-law in the")
lines.append("training hyperparameters, fit in log-space using ordinary least squares (OLS):")
lines.append("")
lines.append(" ln(L) = b0 + b1 ln(lr) + b2 ln(bsz) + b3 ln(data_size) + b4 ln(non_embedding_param_size)")
lines.append("")
lines.append("Equivalently, in the original scale:")
lines.append("")
lines.append(" L = A * lr^b1 * bsz^b2 * data_size^b3 * non_embedding_param_size^b4, where A = exp(b0)")
lines.append("")
lines.append("We fit one set of coefficients per experimental group (same functional form for all groups).")
lines.append("")
lines.append("## Fitted Coefficients")
lines.append("")
lines.append("| Group | A (exp(b0)) | b1 (lr) | b2 (bsz) | b3 (data_size) | b4 (non_emb_params) | RMSE (ln L) | N |")
lines.append("|---|---:|---:|---:|---:|---:|---:|---:|")
def fmt(x: float) -> str:
if math.isnan(x):
return "NaN"
return f"{x:.6g}"
for g, m in sorted(params["per_group"].items()):
        _b0, b1, b2, b3, b4 = m["beta"]
A = m["A"]
lines.append(
f"| {g} | {fmt(A)} | {fmt(b1)} | {fmt(b2)} | {fmt(b3)} | {fmt(b4)} | {fmt(m['rmse'])} | {m['n']} |"
)
# Global/pool summary
gm = params["global"]
    _b0, b1, b2, b3, b4 = gm["beta"]
lines.append("")
lines.append("### Global (pooled) fit")
lines.append("")
lines.append(f"- A = {fmt(gm['A'])}")
lines.append(f"- b1 (lr) = {fmt(b1)}, b2 (bsz) = {fmt(b2)}, b3 (data_size) = {fmt(b3)}, b4 (non_emb_params) = {fmt(b4)}")
lines.append(f"- RMSE (ln L) = {fmt(gm['rmse'])}, N = {gm['n']}")
lines.append("")
lines.append("## Notes")
lines.append("- Rows with non-positive values are excluded prior to log transform.")
lines.append("- Predictions for unseen groups fall back to the global pooled model.")
lines.append("- This simple log-linear law is robust and extrapolates smoothly; if stronger")
lines.append(" curvature is present, augmenting with interaction terms is a straightforward extension.")
content = "\n".join(lines)
try:
with open(path, "w", encoding="utf-8") as f:
f.write(content)
except Exception:
# Best-effort; ignore write failures
pass
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values. Required keys:
- lr
- bsz
- data_size
- non_embedding_param_size
group: The name of the experimental group for which to make predictions.
The functional form of the law is the same for all groups,
but the constant parameters/coefficients can differ per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s): {"lm_loss": float}.
"""
preds = []
for row in input_data:
pred = _predict_one(row, group)
preds.append({"lm_loss": float(pred)})
return preds
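if __name__ == "__main__":
    # Smoke test with hypothetical hyperparameters; an unseen group name falls
    # back to the global pooled model.
    _example = [
        {"lr": 3e-4, "bsz": 256, "data_size": 1.0e10, "non_embedding_param_size": 1.2e8},
        {"lr": 1e-3, "bsz": 512, "data_size": 5.0e9, "non_embedding_param_size": 3.5e8},
    ]
    print(law(_example, group="default"))
    dump_explanation()  # best-effort write to /app/explain.md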