# EVOLVE-BLOCK-START
import numpy as np
from scipy.optimize import minimize
# Reference scales used by _normalize to make the raw inputs dimensionless:
# column 1 (params) is divided by _P0, column 2 (tokens) by _D0 and
# column 0 (unique_tokens) by _U0.
_P0 = 1.1e9
_D0 = 1.0e12
_U0 = 5.0e8
# Lower clamp applied to every normalized feature so its logarithm is finite.
_EPS = 1e-12
def _normalize(X):
    """Split a feature matrix into dimensionless, lower-clamped columns.

    Columns 0, 1 and 2 of ``X`` are divided by ``_U0``, ``_P0`` and ``_D0``
    respectively, and every quotient is clamped from below at ``_EPS``.

    Args:
        X: array indexable as ``X[:, j]`` for ``j`` in 0..2 (callers in this
            file pass a 2-D float array with three columns).

    Returns:
        Tuple ``(U, P, D)`` holding the normalized columns 0, 1 and 2.
    """
    reference_scales = (_U0, _P0, _D0)
    U, P, D = (
        np.clip(X[:, col] / scale, _EPS, None)
        for col, scale in enumerate(reference_scales)
    )
    return U, P, D
def scaling_law_func(data_points, params):
    """Evaluate the additive three-term power-law loss model.

    For each data row and each parameter set the prediction is::

        L0 + cP * P**(-aP) + cD * D**(-aD) + cU * U**(-aU)

    where ``U``, ``P`` and ``D`` are the normalized features returned by
    ``_normalize`` and any negative exponent is replaced by zero.

    Args:
        data_points: array-like with 3 columns
            ``[unique_tokens, params, tokens]``; a single length-3 vector is
            treated as one row.
        params: 7 values ``[L0, cP, aP, cD, aD, cU, aU]``, or a 2-D
            array-like with one such row per parameter set.

    Returns:
        A 1-D array with one prediction per data row when a single parameter
        set is given; otherwise a 2-D array with one column per parameter set.

    Raises:
        ValueError: if ``data_points`` does not have 3 columns or ``params``
            does not have 7 entries per set.
    """
    features = np.atleast_2d(np.asarray(data_points, dtype=np.float64))
    if features.shape[1] != 3:
        raise ValueError("data_points must have 3 columns: [unique_tokens, params, tokens]")
    uniq, size, tokens = _normalize(features)

    theta = np.atleast_2d(np.asarray(params, dtype=np.float64))
    if theta.shape[1] != 7:
        raise ValueError("params must have 7 elements: [L0,cP,aP,cD,aD,cU,aU]")
    # Each unpacked name is a 1-D array with one entry per parameter set.
    offset, amp_p, exp_p, amp_d, exp_d, amp_u, exp_u = theta.T

    # Terms are added in the order P, D, U so the floating-point summation
    # matches the closed-form expression in the docstring.
    pred = offset[None, :]
    terms = ((amp_p, exp_p, size), (amp_d, exp_d, tokens), (amp_u, exp_u, uniq))
    for amplitude, exponent, feature in terms:
        decay = np.clip(exponent, 0.0, None)[None, :]
        log_feature = np.log(feature)[:, None]
        pred = pred + amplitude[None, :] * np.exp(-decay * log_feature)

    if theta.shape[0] == 1:
        return pred[:, 0]
    return pred
def fit_scaling_law(data_points, loss_values):
    """Fit the seven-parameter additive power-law model to observed losses.

    The model is ``L0 + cP * P**(-aP) + cD * D**(-aD) + cU * U**(-aU)`` on the
    normalized features returned by ``_normalize``. Each target column is
    fitted independently in two stages:

    1. Multi-start L-BFGS-B over ``(L0, aP, aD, aU)``; for every trial point
       the amplitudes ``(cP, cD, cU)`` come from a ridge-regularized
       least-squares solve, clipped to ``[1e-12, 100]``.
    2. A joint L-BFGS-B refinement of all seven parameters, started from the
       best stage-1 point.

    The objective in both stages is the mean pseudo-Huber loss of the
    residuals plus a small L2 penalty on amplitudes and exponents. An
    optimizer iterate is kept whenever its objective value is finite and no
    worse than the alternative, regardless of the solver's ``success`` flag.

    Args:
        data_points: array-like with 3 columns
            ``[unique_tokens, params, tokens]``.
        loss_values: array-like of shape ``(n_points,)`` or
            ``(n_points, n_targets)``.

    Returns:
        Array ``[L0, cP, aP, cD, aD, cU, aU]`` of shape ``(7,)`` for a single
        target, otherwise an array of shape ``(n_targets, 7)``.

    Raises:
        ValueError: if ``data_points`` does not have 3 columns, if
            ``loss_values`` is not 1-D/2-D or its row count differs from
            ``data_points``, if the inputs contain NaN (or non-finite
            losses), or if there are no data rows to fit.
    """
    X = np.atleast_2d(np.asarray(data_points, dtype=np.float64))
    y = np.asarray(loss_values, dtype=np.float64)
    if X.shape[1] != 3:
        raise ValueError("data_points must have 3 columns: [unique_tokens, params, tokens]")
    if y.ndim not in (1, 2):
        raise ValueError("loss_values must be 1-D (n_points,) or 2-D (n_points, n_targets)")
    y2d = y[:, None] if y.ndim == 1 else y
    if y2d.shape[0] != X.shape[0]:
        raise ValueError(
            f"loss_values has {y2d.shape[0]} rows but data_points has {X.shape[0]}"
        )
    # Non-finite inputs make every objective value NaN/inf, so no start could
    # ever be selected; fail early with a clear message instead.
    if np.isnan(X).any() or not np.isfinite(y2d).all():
        raise ValueError("data_points must not contain NaN and loss_values must be finite")

    U, P, D = _normalize(X)
    lp, ld, lu = np.log(P), np.log(D), np.log(U)

    l0_bounds = (0.0, 10.0)
    amp_bounds = (1e-12, 100.0)
    exp_bounds = (0.02, 2.5)
    bounds_phi = [l0_bounds, exp_bounds, exp_bounds, exp_bounds]
    bounds_full = [l0_bounds, amp_bounds, exp_bounds, amp_bounds,
                   exp_bounds, amp_bounds, exp_bounds]
    lo_phi, hi_phi = np.array(bounds_phi).T
    lo_full, hi_full = np.array(bounds_full).T

    # Numerical failures that are treated as "this attempt did not help".
    recoverable = (ValueError, ArithmeticError, np.linalg.LinAlgError)

    def pseudo_huber(r, d=0.25):
        # Quadratic near zero, asymptotically linear for |r| >> d.
        return d * d * (np.sqrt(1.0 + (r / d) ** 2) - 1.0)

    def penalized_loss(y_col, L0, cP, aP, cD, aD, cU, aU):
        pred = (L0
                + cP * np.exp(-aP * lp)
                + cD * np.exp(-aD * ld)
                + cU * np.exp(-aU * lu))
        reg = 1e-6 * (cP * cP + cD * cD + cU * cU + aP * aP + aD * aD + aU * aU)
        return np.mean(pseudo_huber(pred - y_col)) + reg

    def ridge_amplitudes(y_col, L0, aP, aD, aU):
        # Closed-form ridge solve for the three amplitudes at fixed offset and
        # exponents, then clipped into the amplitude bounds.
        Phi = np.stack([np.exp(-aP * lp), np.exp(-aD * ld), np.exp(-aU * lu)], axis=1)
        AtA = Phi.T @ Phi + 1e-3 * np.eye(3)
        Atb = Phi.T @ (y_col - L0)
        return np.clip(np.linalg.solve(AtA, Atb), amp_bounds[0], amp_bounds[1])

    def obj_phi(phi, y_col):
        # Stage-1 objective: amplitudes are eliminated via ridge_amplitudes.
        L0, aP, aD, aU = phi
        aP = max(aP, 0.0)
        aD = max(aD, 0.0)
        aU = max(aU, 0.0)
        cP, cD, cU = ridge_amplitudes(y_col, L0, aP, aD, aU)
        return penalized_loss(y_col, L0, cP, aP, cD, aD, cU, aU)

    def obj_full(th, y_col):
        # Stage-2 objective over all seven parameters, clamped into bounds.
        return penalized_loss(y_col, *np.clip(th, lo_full, hi_full))

    def safe_value(objective, x, y_col):
        # Objective value as a float; +inf when it raises or is not finite, so
        # such points can never win a comparison.
        try:
            value = float(objective(x, y_col))
        except recoverable:
            return np.inf
        return value if np.isfinite(value) else np.inf

    def make_inits(y_col):
        # Deterministic grid around min(loss) plus six seeded random starts.
        y_min = float(np.min(y_col))
        inits = []
        for L0 in [max(y_min - 0.1, 0.0), y_min, min(y_min + 0.2, 10.0)]:
            for aP, aD, aU in [(0.5, 0.5, 0.5), (0.8, 0.6, 0.4),
                               (0.3, 0.9, 0.4), (1.0, 0.4, 0.3)]:
                inits.append(np.array([L0, aP, aD, aU], dtype=np.float64))
        rng = np.random.default_rng(123)
        for _ in range(6):
            inits.append(np.array([
                float(np.clip(y_min + 0.2 * rng.uniform(-1, 1), 0.0, 10.0)),
                rng.uniform(0.05, 1.5),
                rng.uniform(0.05, 1.5),
                rng.uniform(0.05, 1.5),
            ], dtype=np.float64))
        # Keep every start inside the box, so a start used as a fallback never
        # violates the declared bounds (matters when min(loss) is outside
        # [0, 10]).
        return [np.clip(init, lo_phi, hi_phi) for init in inits]

    def fit_column(y_col):
        starts = make_inits(y_col)
        # The first start is the fallback if no candidate has a finite value.
        best_val, best_phi = np.inf, starts[0]
        for init in starts:
            candidates = []
            try:
                res = minimize(obj_phi, init, args=(y_col,), method="L-BFGS-B",
                               bounds=bounds_phi, options=dict(maxiter=500, ftol=1e-9))
            except recoverable:
                pass  # the start point itself is still considered below
            else:
                # Listed first so the optimizer's iterate wins ties. It is
                # judged by its objective value, not by res.success, because
                # an iteration-limit or line-search stop can still improve.
                candidates.append(res.x)
            candidates.append(init)
            for cand in candidates:
                val = safe_value(obj_phi, cand, y_col)
                if val < best_val:
                    best_val, best_phi = val, cand

        L0, aP, aD, aU = best_phi
        cP, cD, cU = ridge_amplitudes(y_col, L0, aP, aD, aU)
        theta = np.array([L0, cP, aP, cD, aD, cU, aU], dtype=np.float64)
        try:
            res2 = minimize(obj_full, theta, args=(y_col,), method="L-BFGS-B",
                            bounds=bounds_full, options=dict(maxiter=300, ftol=1e-9))
        except recoverable:
            return theta  # refinement is best-effort; keep the stage-1 fit
        refined_val = safe_value(obj_full, res2.x, y_col)
        # Accept the refinement only if it is finite and does not make the
        # objective worse than the stage-1 solution.
        if np.isfinite(refined_val) and refined_val <= safe_value(obj_full, theta, y_col):
            theta = res2.x
        return theta

    T = y2d.shape[1]
    params_all = np.zeros((T, 7), dtype=np.float64)
    for t in range(T):
        params_all[t, :] = fit_column(y2d[:, t])
    return params_all[0] if T == 1 else params_all
# EVOLVE-BLOCK-END