
U-shaped Scaling Law

Agent: SLDAgent
Model: Gemini 2.5 Flash
Best R²: 0.925978
Mean R²: 0.925247
Min R²: 0.924689
Runs: 5

All Runs (sorted by R²)

Best Run 5 R² = 0.925978
Python
# EVOLVE-BLOCK-START
"""
Scaling law discovery for LLM finetuning scenarios.
This evolution refines the U-shaped scaling law by modeling a Lorentzian-like
peak on a *linear* baseline (5 parameters). Compared to a Gaussian bump, this form's
heavier tails tend to fit limited data more stably and capture the broader influence
of the "badness" region.
The fitting routine uses L-BFGS-B with improved initial parameter guesses,
comprehensive dynamic bounds, and many initializations (specific heuristics plus
random sampling) to explore the non-convex objective and capture the U-shaped or
double-descent pattern. A robust fallback mechanism ensures a result is always
returned, even in challenging data scenarios.

Key improvements in this version:
- Further widened bounds for 'A' (amplitude) and 'w' (width) parameters to capture a broader range of U-shapes.
- Increased number of multiple initializations to enhance the optimizer's ability to find a global optimum in a non-convex landscape.
- More systematic generation of initial parameter guesses for 'A', 'x0', and 'w', combining linear/logarithmic spacing, random sampling, and strategic points to ensure comprehensive coverage of the parameter space.
- Enhanced numerical stability by explicitly nudging 'w' away from its lower bound if initial guesses are too close.
"""
import numpy as np
from scipy.optimize import minimize
from scipy.stats import linregress

def scaling_law_func(data_points, params):
    """
    Models a U-shaped relationship (performance worsens, then improves) using a
    Lorentzian-like peak on a linear baseline. This allows the brier_score
    (negative; more negative is better) to increase (worsen) and then decrease
    (improve) as log_flops grows.

    The model uses 5 parameters to stay within the parameter-count constraint while keeping the functional form simple.

    Parameters:
    data_points (np.ndarray): (N,1) array with columns [log_flops].
    params (list or np.ndarray): Array of 5 parameters [A, x0, w, B, C].
        A: Amplitude of the "badness" peak. A positive 'A' value will push
           the brier_score towards zero (worsening performance).
        x0: log_flops value at the center of the peak, representing the scale
            where performance is maximally hindered or worst.
        w: Width parameter of the peak. Controls how broad the "badness" region is.
           Must be positive.
        B: Slope of the underlying linear trend. Captures the overall long-term
           scaling behavior.
        C: Intercept of the underlying linear trend.

    Returns:
    np.ndarray: Predicted brier_score values (negative).
    """
    x = np.atleast_1d(np.asarray(data_points)).flatten() # Ensure x is 1D

    # Unpack parameters: A, x0, w, B, C (5 parameters)
    A, x0, w, B, C = params

    # Ensure 'w' is not too small to prevent division by zero or numerical instability.
    # A small positive value is used if w is non-positive or too close to zero.
    w_safe = np.maximum(w, 1e-9)
    
    # Lorentzian-like peak for "badness" + linear baseline
    # A positive A term creates a bump, pushing negative brier_scores towards zero (worsening).
    # B*x + C models the overall long-term scaling trend.
    pred = A / (1 + ((x - x0) / w_safe)**2) + B * x + C

    return pred

def fit_scaling_law(data_points, loss_values):
    """
    Fits the U-shaped scaling law function to data using L-BFGS-B with
    robust initial parameter guesses, comprehensive bounds, and multiple
    initializations to better explore the parameter space for a global minimum,
    especially for non-convex objective functions.

    Parameters:
    data_points (np.ndarray): (N,1) array with columns [log_flops].
    loss_values (np.ndarray): Array of corresponding brier_score values.

    Returns:
    np.ndarray: Optimized parameters [A, x0, w, B, C].
    """
    x = np.atleast_1d(np.asarray(data_points)).flatten()
    y = np.atleast_1d(np.asarray(loss_values)).flatten()

    # Handle edge case: very few data points, especially for linregress
    # Return a sensible default to avoid errors and ensure a result is always provided.
    if len(x) < 2:
        mean_x_safe = np.mean(x) if x.size > 0 else 0.0
        mean_y_safe = np.mean(y) if y.size > 0 else 0.0
        return np.array([0.01, mean_x_safe, 1.0, 0.0, mean_y_safe])

    # Objective function to minimize (Mean Squared Error)
    def objective(params):
        pred = scaling_law_func(x, params)
        mse = np.mean((pred - y) ** 2)
        return mse

    best_mse = np.inf
    best_params = None

    # --- Initial Parameter Guesses and Bounds Setup ---
    # 1. Linear regression for initial B (slope) and C (intercept)
    if np.std(x) < 1e-9: # x values are essentially constant
        slope = 0.0
        intercept = np.mean(y)
    else:
        slope, intercept, _, _, _ = linregress(x, y)
    B_base = slope
    C_base = intercept

    # 2. x0_range: Range for the center of the peak
    x_min, x_max = np.min(x), np.max(x)
    data_range = x_max - x_min
    
    # Robust calculation of x0 bounds and w bounds, handling small or zero data_range
    if data_range < 1e-6: # If x values are almost constant
        x0_bound_low = x_min - 1.0
        x0_bound_high = x_max + 1.0
        w_min_bound = 0.05 # Default for very narrow range
        w_max_bound = 10.0 # Default for very narrow range
    else:
        x0_bound_low = x_min - data_range * 0.2 # Wider range for x0
        x0_bound_high = x_max + data_range * 0.2
        # Refined w bounds for better exploration: allow for sharper and broader peaks
        # Allowing for very sharp peaks (small w) and very broad ones (large w)
        w_min_bound = max(1e-5, data_range / 100.0) 
        w_max_bound = max(5.0, data_range * 5.0, 15.0) # Increased cap for w_max
    
    x0_range_bounds = (x0_bound_low, x0_bound_high)

    # 3. A_base: Amplitude of the "badness" peak (must be positive)
    linear_pred = B_base * x + C_base
    residuals_from_baseline = y - linear_pred
    A_base = np.max(residuals_from_baseline) if np.max(residuals_from_baseline) > 0 else 0.01

    # Cap A_base to a reasonable value and ensure a minimum positive amplitude
    y_range = np.max(y) - np.min(y)
    # Refined A_max_bound - allows for larger peaks relative to the observed y-range
    A_max_bound = max(y_range * 3.0, 1.0) 
    A_base = min(A_base, A_max_bound * 0.75) if y_range > 0 else A_base
    if A_base < 0.001: A_base = 0.001 # Ensure a minimum positive amplitude

    # Define common bounds for all optimizations
    bounds = [
        (1e-6, A_max_bound),   # A (amplitude) must be positive and within a reasonable max.
        x0_range_bounds,       # x0 (center) constrained within a reasonable range around data.
        (w_min_bound, w_max_bound), # w (width) bounded by reasonable values.
        (None, None),          # B (slope) - no strong prior constraints.
        (None, None)           # C (intercept) - no strong prior constraints.
    ]
    
    # --- Multiple Initializations Loop ---
    num_inits = 70 # Increased number of different starting points for better exploration

    # Heuristic for initial x0: point of max residual from linear fit
    x0_peak_init_heuristic = np.mean(x) # Default if no clear peak
    if x.size > 1 and np.max(residuals_from_baseline) > 1e-6:
        x0_peak_init_heuristic = x[np.argmax(residuals_from_baseline)]

    # Generate varied initial guesses for A, x0, w.
    A_inits = np.unique(np.concatenate([
        np.linspace(max(1e-6, A_base * 0.05), A_max_bound, num_inits // 4),
        np.random.uniform(max(1e-6, A_base * 0.05), A_max_bound, num_inits // 4),
        [A_base, max(1e-6, A_base * 0.5), A_max_bound * 0.1, A_max_bound * 0.5, A_max_bound] # Strategic points
    ]))
    A_inits = A_inits[A_inits >= 1e-6] # Ensure A is positive
    A_inits = A_inits[:num_inits] # Trim if too many unique values

    x0_inits = np.unique(np.concatenate([
        np.linspace(x0_bound_low, x0_bound_high, num_inits // 4),
        np.random.uniform(x0_bound_low, x0_bound_high, num_inits // 4),
        [x0_peak_init_heuristic, np.mean(x), x_min, x_max, (x_min + x_max) / 2.0] # Strategic points
    ]))
    x0_inits = x0_inits[:num_inits]

    # Use logspace for w_inits to cover a broader range effectively
    # Handle cases where log_w_min >= log_w_max (e.g., if w_min_bound is very large, or w_max_bound is small)
    log_w_min = np.log10(w_min_bound) if w_min_bound > 0 else -10.0 # Default to a very small log value if w_min_bound is zero or less
    log_w_max = np.log10(w_max_bound) if w_max_bound > 0 else 10.0 # Default to a very large log value
    
    # Ensure log_w_min < log_w_max for logspace to work
    if log_w_min >= log_w_max: # If bounds are problematic, create a sensible default range
        log_w_min = np.log10(max(1e-6, w_min_bound))
        log_w_max = np.log10(max(1e-6, w_max_bound))
        if log_w_min >= log_w_max: # If still an issue, make a tiny range
            log_w_max = log_w_min + 1.0 # Create a small range for logspace

    w_inits = np.unique(np.concatenate([
        np.logspace(log_w_min, log_w_max, num_inits // 4),
        10**np.random.uniform(log_w_min, log_w_max, num_inits // 4),
        [w_min_bound, w_max_bound, (w_min_bound + w_max_bound) / 2.0, data_range / 2.0] # Strategic points; out-of-bound values are clamped to the bounds before optimization
    ]))
    w_inits = w_inits[w_inits >= 1e-9] # Ensure w is positive
    w_inits = w_inits[:num_inits]


    # Iterate through initial parameter combinations.
    # A single loop with modulo indexing cycles through the A/x0/w candidate lists,
    # so every candidate is visited and exactly num_inits optimization attempts are made.
    num_A = len(A_inits)
    num_x0 = len(x0_inits)
    num_w = len(w_inits)

    actual_inits_to_try = num_inits # Use num_inits as the target for actual optimization runs

    for i in range(actual_inits_to_try):
        current_A_init = A_inits[i % num_A]
        current_x0_init = x0_inits[i % num_x0]
        current_w_init = w_inits[i % num_w]

        initial_params = [current_A_init, current_x0_init, current_w_init, B_base, C_base]
        
        # Ensure initial_params respect bounds before optimization to prevent ValueErrors
        initial_params_clamped = []
        for j, (lower, upper) in enumerate(bounds):
            clamped_val = initial_params[j]
            if lower is not None:
                clamped_val = max(clamped_val, lower)
            if upper is not None:
                clamped_val = min(clamped_val, upper)
            initial_params_clamped.append(clamped_val)
        
        # Nudge 'w' slightly above its minimum bound if it's right on it, to avoid numerical instability
        if initial_params_clamped[2] <= bounds[2][0]: # Check for <= to catch values exactly at the bound
            initial_params_clamped[2] = bounds[2][0] + 1e-9 

        try:
            result = minimize(objective, initial_params_clamped, method='L-BFGS-B', bounds=bounds,
                              options={'maxiter': 5000, 'ftol': 1e-9, 'gtol': 1e-9, 'disp': False})
            
            # Check for successful convergence and finite parameters
            if result.success and np.all(np.isfinite(result.x)) and result.fun < best_mse:
                best_mse = result.fun
                best_params = result.x
        except ValueError:
            # Catch potential errors from numerical issues during optimization (e.g., bounds violation if not clamped properly)
            continue
        except Exception:
            # Catch other potential exceptions during optimization (e.g., singular matrix)
            continue

    # Fallback: If no successful optimization found after multiple attempts,
    # perform one final robust optimization with a central initial guess.
    if best_params is None:
        # For debugging: print(f"Warning: Multiple initializations failed. Attempting robust fallback.")
        fallback_A_init = A_base
        fallback_x0_init = x0_peak_init_heuristic
        
        # Use log-midpoint for fallback_w_init if log_w_min < log_w_max, otherwise use linear midpoint
        if log_w_min < log_w_max:
            fallback_w_init = 10**((log_w_min + log_w_max) / 2.0)
        else: 
            fallback_w_init = (w_min_bound + w_max_bound) / 2.0

        initial_params_fallback = [fallback_A_init, fallback_x0_init, fallback_w_init, B_base, C_base]
        
        # Ensure fallback parameters respect bounds
        initial_params_clamped_fallback = []
        for j, (lower, upper) in enumerate(bounds):
            clamped_val = initial_params_fallback[j]
            if lower is not None:
                clamped_val = max(clamped_val, lower)
            if upper is not None:
                clamped_val = min(clamped_val, upper)
            initial_params_clamped_fallback.append(clamped_val)
        
        # Nudge 'w' slightly above its minimum bound for fallback as well
        if initial_params_clamped_fallback[2] <= bounds[2][0]:
            initial_params_clamped_fallback[2] = bounds[2][0] + 1e-9

        result_fallback = minimize(objective, initial_params_clamped_fallback, method='L-BFGS-B', bounds=bounds,
                                   options={'maxiter': 5000, 'ftol': 1e-9, 'gtol': 1e-9, 'disp': False})
        
        if result_fallback.success and np.all(np.isfinite(result_fallback.x)):
            best_params = result_fallback.x
        else:
            # As a last resort, if even the fallback fails, return a completely default set.
            # For debugging: print(f"Warning: Fallback optimization failed. Message: {result_fallback.message}. Returning clamped initial parameters.")
            best_params = np.array(initial_params_clamped_fallback)
            # Ensure these default parameters are also finite and reasonable.
            if not np.all(np.isfinite(best_params)):
                best_params = np.array([0.01, 0.0, 1.0, 0.0, 0.0]) # Absolute default if clamping somehow failed

    return best_params
# EVOLVE-BLOCK-END
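
A minimal usage sketch of the best-run code above (a hedged illustration, not part of the submitted solution): it assumes scaling_law_func and fit_scaling_law from the listing are in scope, and the synthetic log_flops range, noise level, and "true" parameters are invented for demonstration only.

Python
import numpy as np

# Synthetic U-shaped data: a linear improving trend plus a "badness" bump near log_flops ~ 20.
# All values here are illustrative; they are not the benchmark's finetuning data.
rng = np.random.default_rng(0)
log_flops = np.linspace(16.0, 24.0, 40).reshape(-1, 1)
true_params = [0.15, 20.0, 1.2, -0.04, 0.3]  # [A, x0, w, B, C]
brier = scaling_law_func(log_flops, true_params) + rng.normal(0.0, 0.005, size=40)

# Fit the law and score it with a coefficient of determination (R^2),
# analogous to the metric reported on the leaderboard.
fitted = fit_scaling_law(log_flops, brier)
pred = scaling_law_func(log_flops, fitted)
ss_res = np.sum((brier - pred) ** 2)
ss_tot = np.sum((brier - np.mean(brier)) ** 2)
print("fitted [A, x0, w, B, C]:", np.round(fitted, 4))
print("R^2:", 1.0 - ss_res / ss_tot)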
#2 Run 2 R² = 0.925499
#3 Run 1 R² = 0.925380
#4 Run 3 R² = 0.924689
#5 Run 4 R² = 0.924689