# EVOLVE-BLOCK-START
"""
Advanced MoE scaling law with coupled parameter-expert scaling.
Form: L = a * N^b * (1 + c * E^d) + e / sqrt(1 + E) + f
Captures multiplicative interaction between parameters and experts with efficiency gains,
plus routing overhead that decays with expert count.
"""
import numpy as np
from scipy.optimize import minimize, differential_evolution


def scaling_law_func(data_points, params):
    """
    Coupled MoE scaling law:

        L = a * N^b * (1 + c * E^d) + e / sqrt(1 + E) + f

    Components:
    - a * N^b * (1 + c * E^d): parameter scaling modulated by expert count;
      the (1 + c * E^d) factor captures how experts modify parameter efficiency
    - e / sqrt(1 + E): routing overhead with square-root decay in expert count
    - f: baseline irreducible loss

    Args:
        data_points: array of shape (n_points, 2) with columns
            [num_experts (E), dense_params (N)].
        params: parameter vector [a, b, c, d, e, f]; shorter vectors are
            zero-padded to length 6.

    Returns:
        Predicted loss for each row of data_points.
    """
    X = np.atleast_2d(np.asarray(data_points, dtype=np.float64))
    params = np.asarray(params, dtype=np.float64).flatten()
    if len(params) < 6:
        params = np.pad(params, (0, 6 - len(params)), constant_values=0.0)
    params = params[:6]

    num_experts = X[:, 0]
    dense_params = X[:, 1]
    a, b, c, d, e, f = params

    # Numerical safety: at least one expert, and a floor on the parameter count
    E = np.maximum(num_experts, 1.0)
    N = np.maximum(dense_params, 1e7)

    # Constrain exponents for numerical stability
    b_safe = np.clip(b, -0.8, 0.3)
    d_safe = np.clip(d, -1.0, 1.0)

    # Main scaling term with expert modulation
    base_scaling = np.abs(a) * np.power(N, b_safe)
    expert_modulation = 1.0 + c * np.power(E, d_safe)
    term1 = base_scaling * expert_modulation

    # Routing overhead with sqrt decay (faster than logarithmic, slower than power law)
    term2 = np.abs(e) / np.sqrt(1.0 + E)

    # Baseline loss
    term3 = f

    pred = term1 + term2 + term3
    return pred
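
# Worked example of the functional form (illustrative numbers only, not fitted
# values): with params [a, b, c, d, e, f] = [500, -0.30, -0.10, 0.40, 0.5, 1.2]
# and a point (E, N) = (8, 1e8):
#   term1 = 500 * 1e8**-0.30 * (1 - 0.10 * 8**0.40) ~= 500 * 0.00398 * 0.770 ~= 1.53
#   term2 = 0.5 / sqrt(1 + 8) ~= 0.167
#   term3 = 1.2
#   L ~= 2.90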


def fit_scaling_law(data_points, loss_values):
    """
    Fit the coupled MoE scaling law to (num_experts, dense_params) -> loss data.

    Three-stage optimization:
    1. Multi-start L-BFGS-B from several theory-informed initializations.
    2. Differential evolution (global search) if the local fits underperform.
    3. L-BFGS-B polish of the global-search result.

    Returns the best parameter vector [a, b, c, d, e, f] found.
    """
    X = np.atleast_2d(np.asarray(data_points, dtype=np.float64))
    y = np.asarray(loss_values, dtype=np.float64).flatten()

    # Summary statistics used to scale bounds and initializations
    y_mean = np.mean(y)
    y_min = np.min(y)
    y_max = np.max(y)
    y_std = np.std(y)
    loss_range = y_max - y_min

    def objective(params):
        try:
            pred = scaling_law_func(X, params)
            residuals = pred - y
            mse = np.mean(residuals ** 2)
            # Light L2 regularization to discourage extreme parameter values
            reg = 1e-10 * np.sum(params ** 2)
            return mse + reg
        except Exception:
            return 1e12

    # Carefully tuned bounds based on MoE theory
    bounds = [
        (1e-14, 1e14),              # a: main coefficient
        (-0.65, 0.15),              # b: parameter exponent (typically negative)
        (-2.0, 2.0),                # c: expert interaction coefficient
        (-0.8, 0.8),                # d: expert interaction exponent
        (0.0, loss_range * 3),      # e: routing overhead
        (y_min * 0.4, y_max * 1.1)  # f: baseline
    ]

    # Diverse initialization strategies informed by top performers
    initializations = [
        # Strategy 1: Strong negative expert interaction (experts improve efficiency)
        np.array([y_mean * 3e8, -0.27, -0.15, 0.4, loss_range * 0.6, y_min * 0.95]),
        # Strategy 2: Positive expert interaction (experts add overhead)
        np.array([y_mean * 8e8, -0.32, 0.25, 0.5, loss_range * 0.4, y_min * 1.0]),
        # Strategy 3: Minimal expert effect
        np.array([y_mean * 5e8, -0.25, -0.02, 0.1, loss_range * 0.5, y_mean * 0.9]),
        # Strategy 4: Strong parameter scaling, weak expert interaction
        np.array([y_mean * 1e9, -0.35, 0.05, 0.3, loss_range * 0.3, y_min * 0.92]),
        # Strategy 5: Balanced approach
        np.array([y_mean * 4e8, -0.28, -0.08, 0.35, loss_range * 0.55, y_mean * 0.88]),
        # Strategy 6: High routing overhead
        np.array([y_mean * 6e8, -0.24, -0.12, 0.6, loss_range * 0.9, y_min * 0.97]),
        # Strategy 7: Minimal overhead, strong expert interaction
        np.array([y_mean * 2e9, -0.30, -0.20, 0.45, loss_range * 0.2, y_min * 0.93]),
    ]

    best_params = None
    best_loss = float('inf')

    # Phase 1: Multi-start local optimization
    for init in initializations:
        try:
            result = minimize(
                objective,
                init,
                method='L-BFGS-B',
                bounds=bounds,
                options={
                    'maxiter': 5000,
                    'ftol': 1e-14,
                    'gtol': 1e-11,
                    'maxfun': 6000
                }
            )
            if result.fun < best_loss:
                best_loss = result.fun
                best_params = result.x
        except Exception:
            continue

    # Phase 2: Global search if local optimization underperforms
    if best_params is None or best_loss > 0.5 * y_std ** 2:
        try:
            result_de = differential_evolution(
                objective,
                bounds,
                maxiter=550,
                popsize=22,
                seed=42,
                atol=1e-12,
                tol=1e-12,
                workers=1,
                strategy='best1bin',
                init='latinhypercube'
            )
            if result_de.fun < best_loss:
                best_params = result_de.x
                best_loss = result_de.fun

            # Phase 3: Polish the best candidate with local optimization
            try:
                result_polish = minimize(
                    objective,
                    best_params,
                    method='L-BFGS-B',
                    bounds=bounds,
                    options={'maxiter': 3000, 'ftol': 1e-14, 'gtol': 1e-11}
                )
                if result_polish.success and result_polish.fun < best_loss:
                    best_params = result_polish.x
            except Exception:
                pass
        except Exception:
            pass

    # Fallback to a reasonable default if every optimization attempt failed
    if best_params is None:
        best_params = np.array([y_mean * 3e8, -0.27, -0.1, 0.4, loss_range * 0.5, y_min * 0.95])

    return best_params
# EVOLVE-BLOCK-END
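

# ---------------------------------------------------------------------------
# Minimal usage sketch (kept outside the evolved block, run only as a script).
# It is illustrative only: the expert/parameter grid, the noise level, and the
# "true" parameter vector below are assumptions chosen for the demo, not values
# from any real MoE training run.
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    # Synthetic grid of (num_experts, dense_params) configurations
    experts = np.array([1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0])
    dense = np.array([1e8, 3e8, 1e9, 3e9])
    grid = np.array([[E, N] for E in experts for N in dense])

    # Generate noisy losses from an illustrative parameter vector [a, b, c, d, e, f]
    true_params = np.array([500.0, -0.30, -0.10, 0.40, 0.5, 1.2])
    losses = scaling_law_func(grid, true_params) + rng.normal(scale=0.01, size=grid.shape[0])

    # Fit and report in-sample error
    fitted = fit_scaling_law(grid, losses)
    preds = scaling_law_func(grid, fitted)
    print("fitted params:", np.round(fitted, 4))
    print("fit MSE:", float(np.mean((preds - losses) ** 2)))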