# EVOLVE-BLOCK-START
"""
U-shaped (double-descent) scaling law with 6 parameters and robust, density-balanced fitting.
Model (6 params):
pred(x) = b0 + A_up * sigmoid((x - c1)/s_l) - A_dn * sigmoid((x - c2)/s_r)
where
A_up = softplus(A_up_raw) >= 0 (early degradation amplitude)
A_dn = softplus(A_dn_raw) >= 0 (later improvement amplitude)
c1 = c0 - 0.5 * d, c2 = c0 + 0.5 * d (ordered centers)
s = softplus(s_raw) + eps (base width > 0)
d = s * (1 + softplus(k_raw)) (separation >= s)
s_l = s * g, s_r = s / g (asymmetry with g = (1 + 0.5*softplus(k_raw))**0.5)
The single k_raw parameter controls both separation and a mild left/right width asymmetry,
capturing skewed U-shapes without increasing parameter count beyond 6.
Fitting improvements:
- Robust Huber loss with inverse-density weights along x (log_flops) to avoid bias from uneven sampling.
- Analytic re-centering of b0 at each evaluation (weighted mean residual given other params) for stability.
- Gentle priors: keep amplitudes/widths moderate, ensure meaningful separation, encourage ultimate improvement.
- Multi-start seeding (grid + jitter), L-BFGS-B refinement, least-squares polish with robust soft_l1 loss,
and a final Powell polish.
"""
import numpy as np
from scipy.optimize import minimize, least_squares


def _softplus(z):
    # Numerically stable softplus: log(1 + exp(z)) computed without overflow for large |z|.
    z = np.asarray(z, dtype=float)
    return np.log1p(np.exp(-np.abs(z))) + np.maximum(z, 0.0)


def _sigmoid(z):
    # Logistic sigmoid; inputs are clipped to keep exp() well-behaved.
    z = np.clip(z, -60.0, 60.0)
    return 1.0 / (1.0 + np.exp(-z))
def scaling_law_func(data_points, params):
    """Evaluate the 6-parameter U-shaped law.

    Accepts a single parameter vector (returns shape (N,)) or a stack of P
    parameter vectors (returns shape (N, P)) via broadcasting.
    """
    X = np.atleast_2d(np.asarray(data_points, dtype=float))
    x = X[:, 0][:, None]  # (N, 1)
    p = np.asarray(params, dtype=float)
    if p.ndim == 1:
        p = p[None, :]
    # params: [b0, A_up_raw, A_dn_raw, c0, s_raw, k_raw]
    b0 = p[:, 0][None, :]
    A_up = _softplus(p[:, 1])[None, :]
    A_dn = _softplus(p[:, 2])[None, :]
    c0 = p[:, 3][None, :]
    s = _softplus(p[:, 4])[None, :] + 1e-3
    # The single factor k_raw controls both separation and mild asymmetry
    ksp = _softplus(p[:, 5])[None, :]  # >= 0
    d = s * (1.0 + ksp)  # separation >= s
    # Asymmetry factor g = sqrt(1 + 0.5*ksp) >= 1: mild skew flexibility without extra params
    g = np.sqrt(1.0 + 0.5 * ksp)
    s_l = s * g
    s_r = s / np.maximum(g, 1e-6)
    c1 = c0 - 0.5 * d
    c2 = c0 + 0.5 * d
    z1 = (x - c1) / s_l
    z2 = (x - c2) / s_r
    pred = b0 + A_up * _sigmoid(z1) - A_dn * _sigmoid(z2)
    return pred[:, 0] if pred.shape[1] == 1 else pred
def fit_scaling_law(data_points, loss_values):
    X = np.atleast_2d(np.asarray(data_points, dtype=float))
    y = np.asarray(loss_values, dtype=float)
    N, F = X.shape
    assert F == 1, "Expected single feature: log_flops"
    y2d = y[:, None] if y.ndim == 1 else y
    T = y2d.shape[1]
    x = X[:, 0].astype(float)
    x_min, x_max = float(np.min(x)), float(np.max(x))
    xr = max(float(x_max - x_min), 1e-6)

    # Inverse-density weights over x via quantile binning
    Q = min(20, max(8, N // 25))
    qs = np.linspace(0.0, 1.0, Q + 1)
    edges = np.quantile(x, qs)
    # Nudge the outer edges so the extreme points fall strictly inside the bins
    edges[0] -= 1e-9
    edges[-1] += 1e-9
    bin_idx = np.clip(np.searchsorted(edges, x, side='right') - 1, 0, Q - 1)
    counts = np.maximum(1, np.bincount(bin_idx, minlength=Q))
    w = 1.0 / counts[bin_idx]
    w = (N * w) / np.sum(w)  # normalize weights to sum to N

    def huber(res, delta):
        a = np.abs(res)
        return np.where(a <= delta, 0.5 * res**2, delta * (a - 0.5 * delta))

    def inv_softplus(v):
        v = np.clip(v, 1e-6, 50.0)
        return np.log(np.expm1(v))
    # Objective utilities
    b0_bounds = (-6.0, 0.5)

    def decode_pred_no_b0(pvec):
        # Compute the model prediction with b0 forced to 0 to allow analytic centering of b0
        pv = np.array(pvec, dtype=float)
        pv0 = pv.copy()
        pv0[0] = 0.0
        return scaling_law_func(X, pv0)

    def b0_optimal(pvec, yi):
        # Weighted mean residual (approximately MSE-optimal b0 under the weights)
        f0 = decode_pred_no_b0(pvec)
        num = np.sum(w * (yi - f0))
        den = np.sum(w)
        b0hat = num / max(den, 1e-9)
        return float(np.clip(b0hat, b0_bounds[0], b0_bounds[1]))

    def add_b0(pvec, b0hat):
        pv = np.array(pvec, dtype=float).copy()
        pv[0] = b0hat
        return pv

    def objective(pvec, yi):
        # Analytic b0 centering for stability
        b0hat = b0_optimal(pvec, yi)
        p_used = add_b0(pvec, b0hat)
        pred = scaling_law_func(X, p_used)
        r = pred - yi
        # Robust Huber delta via the MAD of the targets
        mad = np.median(np.abs(yi - np.median(yi)))
        delta = 1.4826 * mad if mad > 1e-8 else 0.1
        loss = np.mean(w * huber(r, delta))
        # Gentle priors/regularization
        A_up = _softplus(pvec[1])
        A_dn = _softplus(pvec[2])
        s = _softplus(pvec[4]) + 1e-3
        ksp = _softplus(pvec[5])
        sep = s * (1.0 + ksp)
        # Encourage meaningful separation and improvement >= degradation
        reg = 3e-4 * (A_up**2 + A_dn**2 + s**2)
        reg += 1.5e-4 * _softplus(A_up - A_dn)    # push A_dn >= A_up
        reg += 2.0e-4 * _softplus(0.3 * s - sep)  # push sep >= 0.3*s
        # Keep transitions inside the observed x-range
        c0 = pvec[3]
        c1 = c0 - 0.5 * sep
        c2 = c0 + 0.5 * sep
        reg += 1.0e-4 * (_softplus(x_min - c1) + _softplus(c2 - x_max))
        # Discourage positive predictions (the brier target should be negative); also penalize a positive asymptote
        tail = b0hat + A_up - A_dn
        reg += 1.0e-5 * (np.mean(np.maximum(pred, 0.0)) + _softplus(tail))
        return loss + reg
    # Residual function for least_squares polishing (soft_l1 robust)
    def residuals(pvec, yi):
        b0hat = b0_optimal(pvec, yi)
        p_used = add_b0(pvec, b0hat)
        pred = scaling_law_func(X, p_used)
        res = np.sqrt(w) * (pred - yi)
        # Append small regularization terms as residuals
        A_up = _softplus(pvec[1])
        A_dn = _softplus(pvec[2])
        s = _softplus(pvec[4]) + 1e-3
        ksp = _softplus(pvec[5])
        sep = s * (1.0 + ksp)
        c0 = pvec[3]
        c1 = c0 - 0.5 * sep
        c2 = c0 + 0.5 * sep
        reg_terms = np.array([
            1e-2 * (A_up - A_dn),                  # encourage A_dn >= A_up
            1e-2 * max(0.3 * s - sep, 0.0),        # ensure separation
            5e-3 * max(x_min - c1, 0.0),
            5e-3 * max(c2 - x_max, 0.0),
            2e-3 * max(b0hat + A_up - A_dn, 0.0),  # positive tail penalty
        ], dtype=float)
        return np.concatenate([res, reg_terms])

    # Parameter bounds
    bnds = [
        b0_bounds,                    # b0
        (-8.0, 8.0),                  # A_up_raw
        (-8.0, 8.0),                  # A_dn_raw
        (x_min - 0.2, x_max + 0.2),   # c0
        (-8.0, 8.0),                  # s_raw
        (-8.0, 8.0),                  # k_raw
    ]
    low = np.array([b[0] for b in bnds], dtype=float)
    high = np.array([b[1] for b in bnds], dtype=float)
    # Smoothing helper to locate the peak of the U-shape
    order = np.argsort(x)
    x_sorted = x[order]

    def smooth(vals, w_frac=0.06):
        k = max(5, int(w_frac * N))
        v = np.asarray(vals, dtype=float)[order]
        pad = np.pad(v, (k // 2, k - 1 - k // 2), mode='edge')
        ker = np.ones(k, dtype=float) / k
        return np.convolve(pad, ker, mode='valid')

    q25, q40, q50, q60, q75 = np.quantile(x, [0.25, 0.40, 0.50, 0.60, 0.75])
    rng = np.random.default_rng(123)
    params_opt = np.zeros((T, 6), dtype=float)
    for ti in range(T):
        yi = y2d[:, ti]
        # Initialization
        sm = smooth(yi, 0.06)
        idx_peak = int(np.clip(np.argmax(sm), 0, N - 1))
        c0_init = float(x_sorted[idx_peak]) if np.isfinite(idx_peak) else float(q50)
        y_low = float(np.mean(yi[x <= q25])) if np.any(x <= q25) else float(np.mean(yi))
        y_high = float(np.mean(yi[x >= q75])) if np.any(x >= q75) else float(np.mean(yi))
        y_peak = float(sm[idx_peak]) if np.isfinite(idx_peak) else float(np.median(yi))
        b0_init = float(np.median(yi[x <= q25])) if np.any(x <= q25) else float(np.median(yi))
        Aup0 = max(0.02, y_peak - y_low)
        Adn0 = max(0.02, y_peak - y_high)
        s0 = max(0.08, 0.20 * xr)
        d0 = max(0.25 * xr, 1.1 * s0)
        base = np.array([
            b0_init,
            inv_softplus(0.8 * Aup0),
            inv_softplus(0.8 * Adn0),
            c0_init,
            inv_softplus(s0),
            inv_softplus(max(1.0, d0 / s0 - 1.0)),
        ], dtype=float)
        # Seed pool: grid over c0, s and amplitude scales + jitter
        seeds = []
        for cg in (c0_init, float(q40), float(q60)):
            for sg in (0.18 * xr, 0.28 * xr, 0.38 * xr):
                for asc in (0.7, 1.0, 1.3):
                    seeds.append(np.array([
                        b0_init,
                        inv_softplus(asc * Aup0),
                        inv_softplus(asc * Adn0),
                        cg,
                        inv_softplus(max(0.05, sg)),
                        inv_softplus(max(1.0, d0 / max(0.05, sg) - 1.0)),
                    ], dtype=float))
        seeds.append(base)
        for _ in range(8):
            jitter = np.array([
                rng.normal(0, 0.05),
                rng.normal(0, 0.25),
                rng.normal(0, 0.25),
                rng.normal(0, 0.15 * xr),
                rng.normal(0, 0.25),
                rng.normal(0, 0.25),
            ], dtype=float)
            seeds.append(base + jitter)
        # Score seeds quickly and keep only the best starts for refinement
        scored = []
        for seed in seeds:
            seed = np.clip(seed, low, high)
            try:
                scored.append(objective(seed, yi))
            except Exception:
                scored.append(np.inf)
        top_idx = np.argsort(scored)[:min(10, len(seeds))]
        best_p, best_val = None, np.inf
        for idx in top_idx:
            init = np.clip(seeds[idx], low, high)
            # L-BFGS-B refinement
            res = minimize(objective, init, args=(yi,), method='L-BFGS-B', bounds=bnds,
                           options={'maxiter': 800, 'ftol': 1e-9})
            cand_p = res.x if res.success else init
            val = objective(cand_p, yi)
            if val < best_val:
                best_val, best_p = val, cand_p
        # Robust least-squares polish
        try:
            ls = least_squares(lambda pv: residuals(pv, yi), best_p, method='trf',
                               bounds=(low, high), loss='soft_l1', f_scale=1.0,
                               max_nfev=800, xtol=1e-9, ftol=1e-9)
            if ls.success:
                best_p = ls.x
        except Exception:
            pass
        # Final Powell polish
        res2 = minimize(objective, best_p, args=(yi,), method='Powell',
                        options={'maxiter': 500, 'ftol': 1e-7})
        if res2.success and res2.fun <= objective(best_p, yi):
            best_p = res2.x
        # Set the final b0 analytically
        b0hat = b0_optimal(best_p, yi)
        best_p = add_b0(best_p, b0hat)
        params_opt[ti] = best_p
    return params_opt[0] if T == 1 else params_opt
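
# Minimal usage sketch (hedged illustration, not part of the fitting pipeline): the grid,
# noise level, and "true" parameter values below are assumptions chosen only to show how
# scaling_law_func and fit_scaling_law are called together on a single log_flops feature.
if __name__ == "__main__":
    rng_demo = np.random.default_rng(0)
    # Synthetic log-FLOPs grid and an assumed ground-truth parameter vector
    # [b0, A_up_raw, A_dn_raw, c0, s_raw, k_raw].
    log_flops = np.sort(rng_demo.uniform(12.0, 22.0, size=200))[:, None]
    true_params = np.array([-1.0, 0.3, 1.2, 17.0, 0.5, 0.8])
    y_clean = scaling_law_func(log_flops, true_params)
    y_obs = y_clean + rng_demo.normal(0.0, 0.02, size=y_clean.shape)
    # Fit and report the in-sample RMSE of the recovered curve.
    fitted = fit_scaling_law(log_flops, y_obs)
    rmse = float(np.sqrt(np.mean((scaling_law_func(log_flops, fitted) - y_obs) ** 2)))
    print("fitted params:", np.round(fitted, 3))
    print("synthetic-data RMSE: %.4f" % rmse)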
# EVOLVE-BLOCK-END