
MoE Scaling Law

Agent: SLDAgent
Model: Gemini 3 Pro Preview
Best R²: 0.958443
Mean R²: 0.957975
Min R²: 0.957176
Runs: 5

All Runs (sorted by R²)

Best Run 2 R² = 0.958443
Python
# EVOLVE-BLOCK-START
"""
Scaling law discovery for LLM finetuning scenarios
Improved program using Variable Projection with Regularized NNLS.
Model: L = c0 + c1 * N^-a1 + c2 * N^-a2 * E^-beta
Inputs N and E are normalized to [0, 1] range to improve numerical conditioning 
of the optimization problem. The linear coefficients are solved using regularized 
Non-Negative Least Squares to handle collinearity and prevent overfitting.
"""
import numpy as np
from scipy.optimize import least_squares, nnls

def scaling_law_func(data_points, params):
    """
    Predicts validation loss.
    Model: L = c0 + c1 * N^(-a1) + c2 * N^(-a2) * E^(-beta)
    
    N is normalized by 1e9.
    E is normalized by 64.0.
    
    Params: [c0, c1, a1, c2, a2, beta]
    """
    X = np.atleast_2d(np.asarray(data_points))
    # X[:, 0] is num_experts (E)
    # X[:, 1] is dense_parameter_count (N)
    
    E = X[:, 0]
    N = X[:, 1]
    
    # Normalize inputs for numerical stability
    # N ranges ~1e8-8e8 -> 0.1-0.8
    # E ranges 1-64 -> 0.015-1.0
    N_norm = N / 1e9
    E_norm = E / 64.0
    
    params = np.asarray(params)
    if params.ndim == 1:
        params = params[None, :]
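    # Accept a single parameter vector or a batch of T vectors; broadcasting
    # below evaluates all T parameter settings against every data point at once.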
    
    # Extract parameters [c0, c1, a1, c2, a2, beta]
    c0   = params[:, 0][:, None]
    c1   = params[:, 1][:, None]
    a1   = params[:, 2][:, None]
    c2   = params[:, 3][:, None]
    a2   = params[:, 4][:, None]
    beta = params[:, 5][:, None]
    
    # Safe calculations
    N_safe = np.maximum(N_norm, 1e-10)
    E_safe = np.maximum(E_norm, 1e-10)
    
    log_N = np.log(N_safe)
    log_E = np.log(E_safe)
    
    # Term 1: N^(-a1)
    # Shape: (T, N_samples)
    term1 = np.exp(-a1 * log_N[None, :])
    
    # Term 2: N^(-a2) * E^(-beta)
    term2 = np.exp(-a2 * log_N[None, :] - beta * log_E[None, :])
    
    # Combine
    pred = c0 + c1 * term1 + c2 * term2
    
    # Return shape (N_samples, T) or (N_samples,)
    pred = pred.T
    if pred.shape[1] == 1:
        return pred[:, 0]
    return pred


def fit_scaling_law(data_points, loss_values):
    """
    Fits the scaling law using Variable Projection (VarPro) with regularized NNLS.
    Optimizes exponents [a1, a2, beta] using Trust Region Reflective (TRF) algorithm.
    Optimizes coefficients [c0, c1, c2] using regularized Non-Negative Least Squares.
    """
    X = np.atleast_2d(np.asarray(data_points))
    y = np.asarray(loss_values)
    if y.ndim == 1:
        y = y[:, None]
        
    n_samples = X.shape[0]
    E = X[:, 0]
    N = X[:, 1]
    
    # Normalize inputs
    N_norm = N / 1e9
    E_norm = E / 64.0
    
    # Precompute logs
    log_N = np.log(np.maximum(N_norm, 1e-10))
    log_E = np.log(np.maximum(E_norm, 1e-10))
    
    results = []
    
    # Regularization strength for NNLS
    # Prevents overfitting to noise and handles collinearity
    l2_reg = 1e-5
    sqrt_lam = np.sqrt(l2_reg)
    
    for i in range(y.shape[1]):
        y_curr = y[:, i]
        
        # Augmented target for regularization
        y_aug = np.concatenate([y_curr, np.zeros(3)])
        
        def solve_inner(exponents, return_coeffs=False):
            a1, a2, beta = exponents
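            # VarPro inner step: with the exponents held fixed, the model is
            # linear in [c0, c1, c2], so those coefficients are recovered
            # exactly below and the outer optimizer searches only the exponents.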
            
            # Basis functions
            # b0 = 1
            # b1 = N^-a1
            # b2 = N^-a2 * E^-beta
            b1 = np.exp(-a1 * log_N)
            b2 = np.exp(-a2 * log_N - beta * log_E)
            
            # Design matrix (n_samples, 3)
            A = np.vstack([np.ones(n_samples), b1, b2]).T
            
            # Column scaling for numerical conditioning
            # This balances the magnitude of bias (1.0) and power terms
            norms = np.linalg.norm(A, axis=0)
            norms[norms < 1e-10] = 1.0
            A_scaled = A / norms
            
            # Regularized NNLS
            # min ||A_scaled * c' - y||^2 + lambda ||c'||^2
            # Equivalent to solving augmented system
            reg_block = np.eye(3) * sqrt_lam
            A_aug = np.vstack([A_scaled, reg_block])
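            # Stacking sqrt(lambda) * I beneath A_scaled makes
            # ||A_aug @ c - y_aug||^2 = ||A_scaled @ c - y||^2 + lambda * ||c||^2,
            # so one NNLS call on the augmented system solves the ridge-penalized fit.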
            
            c_scaled, _ = nnls(A_aug, y_aug)
            coeffs = c_scaled / norms
            
            if return_coeffs:
                return coeffs
            
            # Return residuals of the AUGMENTED problem
            # This ensures the outer optimizer sees the regularization cost
            return A_aug @ c_scaled - y_aug

        # Grid search for initialization
        # Helps to find a good basin of attraction
        best_loss = np.inf
        best_exp = [0.5, 0.5, 0.2]
        
        # Grid points covering typical scaling regimes
        grid_a = [0.2, 0.6, 1.2]
        grid_b = [0.0, 0.3, 0.8]
        
        for ga1 in grid_a:
            for ga2 in grid_a:
                for gb in grid_b:
                    try:
                        res = solve_inner([ga1, ga2, gb])
                        loss = np.sum(res**2)
                        if loss < best_loss:
                            best_loss = loss
                            best_exp = [ga1, ga2, gb]
                    except Exception:
                        pass
            
        # Refine with least_squares (TRF)
        try:
            res_opt = least_squares(
                solve_inner, 
                x0=best_exp, 
                bounds=([0.0, 0.0, 0.0], [5.0, 5.0, 5.0]),
                method='trf',
                loss='linear',
                ftol=1e-7, xtol=1e-7, max_nfev=150
            )
            final_exps = res_opt.x
        except Exception:
            final_exps = best_exp
            
        final_coeffs = solve_inner(final_exps, return_coeffs=True)
        
        # Pack parameters: [c0, c1, a1, c2, a2, beta]
        params = np.array([
            final_coeffs[0], final_coeffs[1], final_exps[0],
            final_coeffs[2], final_exps[1], final_exps[2]
        ])
        results.append(params)
        
    return np.array(results) if len(results) > 1 else results[0]
# EVOLVE-BLOCK-END
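
A minimal usage sketch (not part of the submitted program): the expert counts, parameter counts, noise level, and "true" parameter values below are invented for illustration, and it assumes scaling_law_func and fit_scaling_law from the block above are in scope, with data_points columns ordered [num_experts, dense_parameter_count] as the code expects.

Python
import numpy as np

rng = np.random.default_rng(0)
E = rng.choice([1, 2, 4, 8, 16, 32, 64], size=40).astype(float)  # num_experts
N = rng.uniform(1e8, 8e8, size=40)                               # dense parameter count
X = np.column_stack([E, N])

# Losses drawn from the same functional form plus small noise;
# the [c0, c1, a1, c2, a2, beta] values here are made up for the demo.
true_params = np.array([1.8, 0.5, 0.4, 0.3, 0.6, 0.25])
y = scaling_law_func(X, true_params) + rng.normal(0.0, 0.01, size=40)

fitted = fit_scaling_law(X, y)
pred = scaling_law_func(X, fitted)
r2 = 1.0 - np.sum((y - pred) ** 2) / np.sum((y - y.mean()) ** 2)
print("fitted params:", np.round(fitted, 4))
print("R^2:", round(r2, 4))
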
#2 Run 1 R² = 0.958086
#3 Run 4 R² = 0.958086
#4 Run 5 R² = 0.958086
#5 Run 3 R² = 0.957176