# EVOLVE-BLOCK-START
"""
Physics-informed scaling law with improved hyperparameter modeling
Key improvements:
1. Chinchilla power laws with data-driven exponent fitting
2. Learning rate modeled via μP-inspired optimal scaling (lr_opt ~ N^lr_exp, fitted negative exponent)
3. Batch size effect via gradient noise scale theory with a critical batch size
4. Reduced to 9 parameters for better generalization
5. Enhanced optimizer with adaptive bounds and multi-stage refinement
"""
import numpy as np
from scipy.optimize import minimize, differential_evolution
def scaling_law_func(data_points, params):
    """
    Scaling law: L = L_inf + A/N^alpha + B/D^beta + lr_penalty + bsz_penalty

    data_points: array of shape (n, 4) with columns [lr, bsz, D, N]
    params: shape (9,) for a single fit, or (T, 9) for T simultaneous fits
    Returns predicted losses of shape (n,), or (n, T) when params is (T, 9).

    Key physics:
    - lr_penalty: quadratic (log-space) deviation from a μP-style optimal LR
    - bsz_penalty: gradient noise model with a critical batch size
    """
    X = np.atleast_2d(np.asarray(data_points, dtype=np.float64))
    params = np.asarray(params, dtype=np.float64)
    if params.ndim == 1:
        params = params[None, :]
    # Extract and normalize features with safety bounds
    lr = np.clip(X[:, 0], 1e-10, 1.0)
    bsz = np.clip(X[:, 1], 1.0, 1e8)
    D = np.clip(X[:, 2], 1e6, 1e15)
    N = np.clip(X[:, 3], 1e6, 1e12)
    # Unpack parameters (9 total - balanced complexity)
    L_inf = params[:, 0:1].T    # Irreducible loss
    A = params[:, 1:2].T        # Model size coefficient
    alpha = params[:, 2:3].T    # Model size exponent
    B = params[:, 3:4].T        # Data coefficient
    beta = params[:, 4:5].T     # Data exponent
    gamma = params[:, 5:6].T    # LR penalty scale
    lr_exp = params[:, 6:7].T   # LR-N coupling (μP theory)
    delta = params[:, 7:8].T    # BSZ penalty scale
    bsz_exp = params[:, 8:9].T  # BSZ-D coupling
    # Core Chinchilla-style power laws
    model_term = A / np.power(N[:, None], alpha)
    data_term = B / np.power(D[:, None], beta)
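    # Broadcasting note: A, alpha, etc. have shape (1, T) and N[:, None] has
    # shape (n, 1), so each term evaluates to (n, T): one column per fit.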
    # Learning rate penalty with μP-inspired scaling
    # Optimal LR shrinks with model size: lr_opt ~ N^lr_exp with lr_exp < 0
    # (the exponent is fitted rather than fixed); 0.005 is an empirical base LR
    lr_opt = 0.005 * np.power(N[:, None], lr_exp)
    lr_ratio = lr[:, None] / np.clip(lr_opt, 1e-10, 1.0)
    # Symmetric quadratic penalty in log-space; the small quartic term
    # strengthens the penalty far from the optimum
    log_lr_ratio = np.log(lr_ratio)
    lr_penalty = gamma * (log_lr_ratio ** 2 + 0.05 * log_lr_ratio ** 4)
    # Batch size penalty with gradient noise theory
    # Critical batch size grows with data: B_crit ~ D^κ
    # Below critical: strong noise penalty; above: mild inefficiency
    bsz_crit = 128.0 * np.power(D[:, None] / 1e10, bsz_exp)
    bsz_ratio = bsz[:, None] / np.clip(bsz_crit, 8.0, 1e7)
    # Asymmetric penalty function
    # Small batches (ratio < 1): severe gradient noise
    # Large batches (ratio > 1): mild diminishing returns
    log_bsz_ratio = np.log(bsz_ratio)
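    # Note: np.where evaluates both branches; both stay finite here because
    # bsz_ratio is strictly positive after the clipping above.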
    bsz_penalty = delta * np.where(
        bsz_ratio < 1.0,
        # Strong penalty for small batches: noise dominates
        0.5 * (1.0 / bsz_ratio - 1.0) + 0.3 * log_bsz_ratio ** 2,
        # Mild penalty for large batches: diminishing returns
        0.1 * log_bsz_ratio + 0.05 * log_bsz_ratio ** 2
    )
    # Total prediction
    pred = L_inf + model_term + data_term + lr_penalty + bsz_penalty
    return pred[:, 0] if pred.shape[1] == 1 else pred
def fit_scaling_law(data_points, loss_values):
    """
    Three-stage fit: global search (DE) → L-BFGS-B refinement → Powell polish.
    Returns fitted parameters of shape (9,) for 1-D loss_values, else (T, 9).
    """
    X = np.atleast_2d(np.asarray(data_points, dtype=np.float64))
    y = np.asarray(loss_values, dtype=np.float64)
    if y.ndim == 1:
        y = y[:, None]
    T = y.shape[1]
    n_params = 9
    # Compute data statistics for adaptive bounds
    loss_min, loss_max = np.min(y), np.max(y)
    loss_range = loss_max - loss_min
    loss_std = np.std(y)
    loss_median = np.median(y)
    # Percentile-based bounds for robustness
    loss_p10 = np.percentile(y, 10)
    loss_p90 = np.percentile(y, 90)
    # Theory-informed parameter bounds
    bounds = [
        (loss_min - 0.4, loss_p10 + 0.1),  # L_inf: near achievable minimum
        (0.005, loss_range * 150),         # A: wide range for model term
        (0.08, 0.65),                      # alpha: 0.3-0.5 typical, allow broader
        (0.005, loss_range * 150),         # B: wide range for data term
        (0.08, 0.65),                      # beta: similar to alpha
        (0.0, loss_std * 8),               # gamma: LR penalty strength
        (-1.2, -0.05),                     # lr_exp: negative (μP theory)
        (0.0, loss_std * 6),               # delta: BSZ penalty strength
        (0.0, 0.3),                        # bsz_exp: positive (larger D → larger B_crit)
    ]
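    # The per-parameter bounds above are tiled T times below (bounds * T) so
    # that each target column in y gets its own independent parameter block.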
    def objective(flat_params):
        params = flat_params.reshape(T, n_params)
        try:
            pred = scaling_law_func(X, params)
            if pred.ndim == 1:
                pred = pred[:, None]
            # Robust loss: Huber-style combination
            residuals = pred - y
            abs_residuals = np.abs(residuals)
            # MSE for small errors, MAE for large (outlier robustness)
            huber_delta = 0.5 * loss_std
            huber_loss = np.where(
                abs_residuals <= huber_delta,
                0.5 * residuals ** 2,
                huber_delta * (abs_residuals - 0.5 * huber_delta)
            )
            main_loss = np.mean(huber_loss)
            # Regularization: prefer Chinchilla-like exponents
            # (averaged over the T fits so the objective stays scalar)
            reg_alpha = 0.015 * np.mean((params[:, 2] - 0.38) ** 2)
            reg_beta = 0.015 * np.mean((params[:, 4] - 0.38) ** 2)
            # Mild parameter magnitude regularization
            reg_l2 = 1e-9 * np.sum(params ** 2)
            return main_loss + reg_alpha + reg_beta + reg_l2
        except Exception:
            return 1e16
    # Smart initialization based on low-loss samples; used as an extra
    # starting point in the Stage 2 refinement below
    low_loss_mask = y < np.percentile(y, 25)
    L_inf_init = np.mean(y[low_loss_mask]) - 0.15 if np.any(low_loss_mask) else loss_min
    init_params = np.array([
        np.clip(L_inf_init, loss_min - 0.3, loss_p10),
        loss_range * 12,  # A
        0.38,             # alpha (Chinchilla-like default)
        loss_range * 10,  # B
        0.38,             # beta
        0.4,              # gamma
        -0.6,             # lr_exp (μP-like)
        0.25,             # delta
        0.15,             # bsz_exp
    ])
    # Stage 1: Differential evolution with enhanced settings
    result_de = differential_evolution(
        objective,
        bounds=bounds * T,
        maxiter=600,
        popsize=30,
        seed=42,
        atol=1e-11,
        tol=1e-11,
        workers=1,
        strategy='best1bin',
        mutation=(0.3, 1.3),
        recombination=0.85,
        polish=False,
        init='sobol'  # Better space coverage than latinhypercube
    )
    best_params = result_de.x
    best_score = result_de.fun
    # Stage 2: L-BFGS-B refinement with multiple restarts
    for attempt in range(3):
        try:
            if attempt == 0:
                start_point = best_params
            elif attempt == 1:
                # Theory-informed smart initialization (tiled across targets)
                start_point = np.tile(init_params, T)
            else:
                # Small random perturbation of the incumbent for exploration
                start_point = best_params + np.random.randn(len(best_params)) * 0.005
            result_lbfgs = minimize(
                objective,
                start_point,
                method='L-BFGS-B',
                bounds=bounds * T,
                options={'maxiter': 2500, 'ftol': 1e-15, 'gtol': 1e-13}
            )
            if result_lbfgs.success and result_lbfgs.fun < best_score:
                best_params = result_lbfgs.x
                best_score = result_lbfgs.fun
        except Exception:
            continue
    # Stage 3: Powell for final polish (runs unconstrained, so verify bounds)
    try:
        result_powell = minimize(
            objective,
            best_params,
            method='Powell',
            options={'maxiter': 1500, 'ftol': 1e-13, 'xtol': 1e-13}
        )
        if result_powell.success and result_powell.fun < best_score:
            # Accept only if every fit's parameters stay within bounds
            params_check = result_powell.x.reshape(T, n_params)
            within_bounds = all(
                bounds[i][0] <= params_check[t, i] <= bounds[i][1]
                for t in range(T) for i in range(n_params)
            )
            if within_bounds:
                best_params = result_powell.x
                best_score = result_powell.fun
    except Exception:
        pass
    params_opt = best_params.reshape(T, n_params)
    return params_opt[0] if T == 1 else params_opt
# EVOLVE-BLOCK-END
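
# Illustrative usage sketch (kept outside the evolve block): fits the scaling
# law to synthetic runs and reports train RMSE. The synthetic losses, the
# Chinchilla-style coefficients, and the 20-tokens-per-parameter heuristic are
# assumptions for demonstration only; real data_points must follow the
# documented column order [lr, bsz, D, N].
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_runs = 48
    N = 10 ** rng.uniform(7, 10, n_runs)                  # model parameters
    D = 20.0 * N * 10 ** rng.uniform(-0.5, 0.5, n_runs)   # training tokens
    lr = 10 ** rng.uniform(-4, -2.3, n_runs)              # learning rates
    bsz = np.round(2 ** rng.uniform(6, 12, n_runs))       # batch sizes
    X = np.column_stack([lr, bsz, D, N])
    # Synthetic "ground truth" losses from a simple power law plus noise
    y = 1.7 + 400.0 / N ** 0.34 + 800.0 / D ** 0.28 + rng.normal(0, 0.01, n_runs)
    fitted = fit_scaling_law(X, y)
    pred = scaling_law_func(X, fitted)
    rmse = float(np.sqrt(np.mean((pred - y) ** 2)))
    print("fitted params:", np.round(fitted, 4))
    print(f"train RMSE: {rmse:.4f}")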