"""
Scaling law predictor for U-shaped scaling of Brier score vs. log_flops.
We model, for each experimental group g:
brier_score ≈ a_g * (log_flops - c_g) ** 2 + b_g
with a_g ≥ 0 enforcing a U-shaped curve that achieves its minimum at log_flops = c_g.
Parameters are estimated once from the dataset at /app/data (if available) using:
- A coarse-to-fine grid search over c_g (vertex) within the observed range of log_flops
- Closed-form ordinary least squares for (a_g, b_g) given c_g via the design [ (x - c_g)^2, 1 ]
- A positivity floor on a_g (clamped to at least 1e-12) to preserve the U-shape
If the dataset is not available at runtime, a conservative global default is used.
"""
from __future__ import annotations
from typing import Dict, List, Tuple, Iterable
import math
DATASET_PATH = "/app/data"
GLOBAL_KEY = "__GLOBAL__"
# Small positive curvature to ensure U-shaped behavior even in degenerate cases
_A_MIN = 1e-12
# Cache of fitted parameters per group: {group: {"a": float, "b": float, "c": float}}
_PARAMETERS: Dict[str, Dict[str, float]] = {}
# Cache of per-group sample sizes and fit quality (for optional introspection/printing)
_FIT_STATS: Dict[str, Dict[str, float]] = {}
def _iter_rows_from_hf(ds) -> Iterable[dict]:
    """Yield rows from a HuggingFace Dataset or DatasetDict in a safe, unified way."""
    if hasattr(ds, "keys") and callable(getattr(ds, "keys", None)):
        # Likely a DatasetDict: prefer the 'train' split if present,
        # otherwise iterate over every split.
        if "train" in ds:
            yield from ds["train"]
        else:
            for split in ds.keys():
                yield from ds[split]
    else:
        # A single Dataset: iterate rows directly.
        yield from ds
def _safe_float(v) -> float | None:
try:
f = float(v)
if math.isfinite(f):
return f
return None
except Exception:
return None
def _ols_two_feature(zs: List[float], ys: List[float]) -> Tuple[float, float]:
"""
Closed-form OLS for y ≈ a * z + b given feature z and intercept.
Returns (a, b). Uses numerically stable sums and handles degeneracies.
"""
n = len(zs)
if n == 0:
return (_A_MIN, 0.0)
sz = 0.0
szz = 0.0
sy = 0.0
syz = 0.0
for z, y in zip(zs, ys):
sz += z
szz += z * z
sy += y
syz += y * z
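    # Solve the 2x2 normal equations [[szz, sz], [sz, n]] @ (a, b) = (syz, sy)
    # by explicit inversion; note det = n * sum((z - mean(z))^2) >= 0.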
det = szz * n - sz * sz
if abs(det) > 0.0:
inv00 = n / det
inv01 = -sz / det
inv10 = -sz / det
inv11 = szz / det
a = inv00 * syz + inv01 * sy
b = inv10 * syz + inv11 * sy
else:
        # Degenerate (det == 0 means z has zero variance): the slope falls back
        # to 0 and the intercept to mean(y).
mean_z = sz / n if n else 0.0
mean_y = sy / n if n else 0.0
num = 0.0
den = 0.0
for z, y in zip(zs, ys):
dz = z - mean_z
num += (y - mean_y) * dz
den += dz * dz
a = (num / den) if den > 0.0 else 0.0
b = mean_y - a * mean_z
# Enforce non-negativity on a (U-shape opening upwards). If clamped, adjust b optimally.
if a < _A_MIN:
a = _A_MIN
# Optimal b given fixed a minimizes MSE: b = mean(y - a*z)
b = (sy - a * sz) / n if n else b
return (a, b)
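# Hedged sanity check for _ols_two_feature (synthetic data, not from /app/data):
# the exact linear relation y = 2*z + 1 should be recovered up to float rounding.
# >>> _ols_two_feature([0.0, 1.0, 2.0], [1.0, 3.0, 5.0])
# (2.0, 1.0)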
def _fit_u_shape(xs: List[float], ys: List[float]) -> Tuple[float, float, float, float]:
"""
Fit y ≈ a*(x - c)^2 + b with a ≥ 0 using:
- Grid search over c in [min(xs), max(xs)]
- Closed-form OLS for (a, b) given c
Returns (a, b, c, mse).
"""
n = len(xs)
if n == 0:
return (_A_MIN, 0.0, 0.0, float("inf"))
x_min = min(xs)
x_max = max(xs)
if not math.isfinite(x_min) or not math.isfinite(x_max):
return (_A_MIN, 0.0, 0.0, float("inf"))
# If all xs equal, set c to that value and fit a,b
if x_max == x_min:
c = x_min
zs = [(x - c) ** 2 for x in xs]
a, b = _ols_two_feature(zs, ys)
mse = sum((a * z + b - y) ** 2 for z, y in zip(zs, ys)) / n
return (a, b, c, mse)
    # Coarse-to-fine grid search for the vertex c:
    # scan progressively finer grids, narrowing the window around the best c.
    best = (float("inf"), _A_MIN, 0.0, (x_min + x_max) * 0.5)  # (mse, a, b, c)
    for num in (41, 81, 161):  # progressively finer grids
        step = (x_max - x_min) / (num - 1)
        candidates = [x_min + i * step for i in range(num)]
        for c in candidates:
            zs = [(x - c) ** 2 for x in xs]
            a, b = _ols_two_feature(zs, ys)
            mse = sum((a * z + b - y) ** 2 for z, y in zip(zs, ys)) / n
            # Compare against the best mse seen so far (not a stale snapshot),
            # so a later, worse candidate can never overwrite a better fit.
            if mse < best[0]:
                best = (mse, a, b, c)
        # Narrow the search window around the current best c for the next pass,
        # staying within the observed range of xs.
        best_c = best[3]
        span = (x_max - x_min) * 0.25
        x_min = max(best_c - span, min(xs))
        x_max = min(best_c + span, max(xs))
mse, a, b, c = best
# Final small local refinement around best c
local_span = (max(xs) - min(xs)) * 0.05
if local_span > 0:
local_candidates = [c + t * local_span for t in (-1.0, -0.5, 0.0, 0.5, 1.0)]
for c2 in local_candidates:
zs = [(x - c2) ** 2 for x in xs]
a2, b2 = _ols_two_feature(zs, ys)
mse2 = sum((a2 * z + b2 - y) ** 2 for z, y in zip(zs, ys)) / n
if mse2 < mse:
mse, a, b, c = mse2, a2, b2, c2
return (a, b, c, mse)
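# Hedged sanity check (synthetic data, not from /app/data): points generated
# from y = 0.5 * (x - 3.0)**2 + 0.1 should be recovered almost exactly, since
# c = 3.0 lies on the coarse grid over [1.0, 5.0].
# >>> xs = [1.0, 2.0, 3.0, 4.0, 5.0]
# >>> ys = [0.5 * (x - 3.0) ** 2 + 0.1 for x in xs]
# >>> a, b, c, mse = _fit_u_shape(xs, ys)
# >>> (round(a, 6), round(b, 6), round(c, 6))
# (0.5, 0.1, 3.0)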
def _fit_parameters_from_dataset() -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]]]:
"""
Load the dataset, fit parameters per group, and return:
(parameters, fit_stats)
"""
params: Dict[str, Dict[str, float]] = {}
stats: Dict[str, Dict[str, float]] = {}
try:
from datasets import load_from_disk # type: ignore
ds = load_from_disk(DATASET_PATH)
except Exception:
# Dataset not available; return empty and let caller handle defaults
return (params, stats)
# Collect per-group data
per_group_xs: Dict[str, List[float]] = {}
per_group_ys: Dict[str, List[float]] = {}
all_xs: List[float] = []
all_ys: List[float] = []
for row in _iter_rows_from_hf(ds):
x = _safe_float(row.get("log_flops"))
y = _safe_float(row.get("brier_score"))
if x is None or y is None:
continue
g = row.get("group")
# Fallbacks if 'group' column is absent
if g is None:
g = row.get("Group") or row.get("dataset") or row.get("family") or GLOBAL_KEY
g = str(g)
per_group_xs.setdefault(g, []).append(x)
per_group_ys.setdefault(g, []).append(y)
all_xs.append(x)
all_ys.append(y)
# Fit per group
for g, xs in per_group_xs.items():
ys = per_group_ys[g]
a, b, c, mse = _fit_u_shape(xs, ys)
params[g] = {"a": float(a), "b": float(b), "c": float(c)}
stats[g] = {"n": float(len(xs)), "mse": float(mse)}
# Fit global as fallback
if all_xs:
a, b, c, mse = _fit_u_shape(all_xs, all_ys)
params.setdefault(GLOBAL_KEY, {"a": float(a), "b": float(b), "c": float(c)})
stats.setdefault(GLOBAL_KEY, {"n": float(len(all_xs)), "mse": float(mse)})
return (params, stats)
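# Assumed row schema (hedged; inferred from the column lookups above, and the
# numeric values here are illustrative only):
#   {"log_flops": 22.5, "brier_score": 0.18, "group": "some_group"}
# The group column may instead be named "Group", "dataset", or "family".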
def _ensure_fitted() -> None:
"""Fit parameters once per process if not already fitted."""
global _PARAMETERS, _FIT_STATS
if _PARAMETERS:
return
params, stats = _fit_parameters_from_dataset()
if params:
_PARAMETERS = params
_FIT_STATS = stats
else:
# Dataset missing; set conservative defaults
_PARAMETERS = {
GLOBAL_KEY: {"a": 1e-3, "b": 0.2, "c": 0.0},
}
_FIT_STATS = {GLOBAL_KEY: {"n": 0.0, "mse": float("nan")}}
def _predict_one(x: float, p: Dict[str, float]) -> float:
a = p["a"]
b = p["b"]
c = p["c"]
return a * (x - c) ** 2 + b
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values.
group: The name of the experimental group for which to make predictions.
The functional form of the law must be the same for all groups,
but the constant parameters/coefficients can differ per group.
Returns:
A list of dictionaries, corresponding to the input_data list, with each
dictionary containing the predicted output variable(s).
"""
_ensure_fitted()
# Pick group-specific parameters if available; otherwise global fallback
params = _PARAMETERS.get(group)
if params is None:
params = _PARAMETERS.get(GLOBAL_KEY)
# As an extra guard, if even global is missing, synthesize a trivial fallback
if params is None:
params = {"a": 1e-3, "b": 0.2, "c": 0.0}
outputs: List[Dict[str, float]] = []
for row in input_data:
x = row.get("log_flops")
xf = _safe_float(x)
if xf is None:
# If log_flops is missing or invalid, return NaN to signal unusable input
outputs.append({"brier_score": float("nan")})
else:
yhat = _predict_one(xf, params)
outputs.append({"brier_score": float(yhat)})
return outputs
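# Hedged usage sketch ("my_group" is a hypothetical group name; outputs depend
# on the fitted parameters, so concrete values are not shown):
# >>> law([{"log_flops": 21.5}, {"log_flops": 23.0}], group="my_group")
# [{'brier_score': ...}, {'brier_score': ...}]
# Unknown groups fall back to the global fit; a missing or non-finite
# "log_flops" yields {"brier_score": nan}.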
if __name__ == "__main__":
# Optional CLI to inspect fitted parameters and fit quality per group.
_ensure_fitted()
# Pretty print results in a stable order
groups = sorted(_PARAMETERS.keys())
# Column header
print("group\tn\ta\tb\tc\tmse")
for g in groups:
p = _PARAMETERS[g]
s = _FIT_STATS.get(g, {})
print(
f"{g}\t{s.get('n', float('nan')):.0f}\t{p['a']:.6g}\t{p['b']:.6g}\t{p['c']:.6g}\t{s.get('mse', float('nan')):.6g}"
)