# EVOLVE-BLOCK-START
"""
Scaling law discovery for LLM finetuning scenarios.
This evolution refines the U-shaped scaling law function by using a Lorentzian-like
peak on a *linear* baseline (5 parameters). This model form, compared to a Gaussian,
often provides better stability and fit for limited data due to its heavier tails,
which can capture the broader influence of the "badness" region more effectively.
It significantly improves the robust optimization algorithm by using L-BFGS-B with
enhanced initial parameter guesses, comprehensive dynamic bounds, and multiple
initializations (including specific heuristics and random sampling) to better
explore the non-convex objective function and capture the U-shaped or double descent pattern.
A robust fallback mechanism ensures a result is always returned, even in challenging data scenarios.
Key improvements in this version:
- Further widened bounds for 'A' (amplitude) and 'w' (width) parameters to capture a broader range of U-shapes.
- Increased number of multiple initializations to enhance the optimizer's ability to find a global optimum in a non-convex landscape.
- More systematic generation of initial parameter guesses for 'A', 'x0', and 'w', combining linear/logarithmic spacing, random sampling, and strategic points to ensure comprehensive coverage of the parameter space.
- Enhanced numerical stability by explicitly nudging 'w' away from its lower bound if initial guesses are too close.
"""
import numpy as np
from scipy.optimize import minimize
from scipy.stats import linregress
def scaling_law_func(data_points, params):
"""
    Models a U-shaped relationship (performance first worsens, then improves) as a
    Lorentzian-like peak on a linear baseline:

        pred = A / (1 + ((x - x0) / w)**2) + B*x + C

    This lets the brier_score (negative; more negative is better) rise towards zero
    (worsening) and then fall back (improving). The model uses 5 parameters to adhere
    to the parameter constraint and keep the functional form simple.
Parameters:
data_points (np.ndarray): (N,1) array with columns [log_flops].
params (list or np.ndarray): Array of 5 parameters [A, x0, w, B, C].
A: Amplitude of the "badness" peak. A positive 'A' value will push
the brier_score towards zero (worsening performance).
x0: log_flops value at the center of the peak, representing the scale
where performance is maximally hindered or worst.
w: Width parameter of the peak. Controls how broad the "badness" region is.
Must be positive.
B: Slope of the underlying linear trend. Captures the overall long-term
scaling behavior.
C: Intercept of the underlying linear trend.
Returns:
np.ndarray: Predicted brier_score values (negative).
"""
x = np.atleast_1d(np.asarray(data_points)).flatten() # Ensure x is 1D
# Unpack parameters: A, x0, w, B, C (5 parameters)
A, x0, w, B, C = params
# Ensure 'w' is not too small to prevent division by zero or numerical instability.
# A small positive value is used if w is non-positive or too close to zero.
w_safe = np.maximum(w, 1e-9)
# Lorentzian-like peak for "badness" + linear baseline
# A positive A term creates a bump, pushing negative brier_scores towards zero (worsening).
# B*x + C models the overall long-term scaling trend.
pred = A / (1 + ((x - x0) / w_safe)**2) + B * x + C
return pred
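
# Minimal illustrative sketch (not used by the fitting code): with hypothetical
# parameters, the Lorentzian term pulls predictions near x0 towards zero (worse)
# relative to the linear baseline, producing the intended U-shape. All values
# below are made up for illustration, not taken from any real finetuning run.
def _demo_u_shape():
    xs = np.linspace(8.0, 16.0, 9)           # hypothetical log-FLOPs grid
    params = [0.4, 12.0, 1.5, -0.05, -0.1]   # [A, x0, w, B, C], illustrative only
    return xs, scaling_law_func(xs, params)  # predictions bump towards zero near x0 = 12
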
def fit_scaling_law(data_points, loss_values):
"""
Fits the U-shaped scaling law function to data using L-BFGS-B with
robust initial parameter guesses, comprehensive bounds, and multiple
initializations to better explore the parameter space for a global minimum,
especially for non-convex objective functions.
Parameters:
data_points (np.ndarray): (N,1) array with columns [log_flops].
loss_values (np.ndarray): Array of corresponding brier_score values.
Returns:
np.ndarray: Optimized parameters [A, x0, w, B, C].
"""
x = np.atleast_1d(np.asarray(data_points)).flatten()
y = np.atleast_1d(np.asarray(loss_values)).flatten()
# Handle edge case: very few data points, especially for linregress
# Return a sensible default to avoid errors and ensure a result is always provided.
if len(x) < 2:
mean_x_safe = np.mean(x) if x.size > 0 else 0.0
mean_y_safe = np.mean(y) if y.size > 0 else 0.0
return np.array([0.01, mean_x_safe, 1.0, 0.0, mean_y_safe])
# Objective function to minimize (Mean Squared Error)
def objective(params):
pred = scaling_law_func(x, params)
mse = np.mean((pred - y) ** 2)
return mse
best_mse = np.inf
best_params = None
# --- Initial Parameter Guesses and Bounds Setup ---
# 1. Linear regression for initial B (slope) and C (intercept)
if np.std(x) < 1e-9: # x values are essentially constant
slope = 0.0
intercept = np.mean(y)
else:
slope, intercept, _, _, _ = linregress(x, y)
B_base = slope
C_base = intercept
# 2. x0_range: Range for the center of the peak
x_min, x_max = np.min(x), np.max(x)
data_range = x_max - x_min
# Robust calculation of x0 bounds and w bounds, handling small or zero data_range
if data_range < 1e-6: # If x values are almost constant
x0_bound_low = x_min - 1.0
x0_bound_high = x_max + 1.0
w_min_bound = 0.05 # Default for very narrow range
w_max_bound = 10.0 # Default for very narrow range
else:
x0_bound_low = x_min - data_range * 0.2 # Wider range for x0
x0_bound_high = x_max + data_range * 0.2
        # Refined w bounds: allow both very sharp peaks (small w) and very broad ones (large w).
        w_min_bound = max(1e-5, data_range / 100.0)
        w_max_bound = max(data_range * 5.0, 15.0)  # raised cap so very broad peaks remain reachable
x0_range_bounds = (x0_bound_low, x0_bound_high)
# 3. A_base: Amplitude of the "badness" peak (must be positive)
linear_pred = B_base * x + C_base
residuals_from_baseline = y - linear_pred
A_base = np.max(residuals_from_baseline) if np.max(residuals_from_baseline) > 0 else 0.01
# Cap A_base to a reasonable value and ensure a minimum positive amplitude
y_range = np.max(y) - np.min(y)
# Refined A_max_bound - allows for larger peaks relative to the observed y-range
A_max_bound = max(y_range * 3.0, 1.0)
    A_base = min(A_base, A_max_bound * 0.75) if y_range > 0 else A_base
    A_base = max(A_base, 0.001)  # ensure a minimum positive amplitude
# Define common bounds for all optimizations
bounds = [
(1e-6, A_max_bound), # A (amplitude) must be positive and within a reasonable max.
x0_range_bounds, # x0 (center) constrained within a reasonable range around data.
(w_min_bound, w_max_bound), # w (width) bounded by reasonable values.
(None, None), # B (slope) - no strong prior constraints.
(None, None) # C (intercept) - no strong prior constraints.
]
# --- Multiple Initializations Loop ---
num_inits = 70 # Increased number of different starting points for better exploration
# Heuristic for initial x0: point of max residual from linear fit
x0_peak_init_heuristic = np.mean(x) # Default if no clear peak
if x.size > 1 and np.max(residuals_from_baseline) > 1e-6:
x0_peak_init_heuristic = x[np.argmax(residuals_from_baseline)]
# Generate varied initial guesses for A, x0, w.
A_inits = np.unique(np.concatenate([
np.linspace(max(1e-6, A_base * 0.05), A_max_bound, num_inits // 4),
np.random.uniform(max(1e-6, A_base * 0.05), A_max_bound, num_inits // 4),
[A_base, max(1e-6, A_base * 0.5), A_max_bound * 0.1, A_max_bound * 0.5, A_max_bound] # Strategic points
]))
A_inits = A_inits[A_inits >= 1e-6] # Ensure A is positive
A_inits = A_inits[:num_inits] # Trim if too many unique values
x0_inits = np.unique(np.concatenate([
np.linspace(x0_bound_low, x0_bound_high, num_inits // 4),
np.random.uniform(x0_bound_low, x0_bound_high, num_inits // 4),
[x0_peak_init_heuristic, np.mean(x), x_min, x_max, (x_min + x_max) / 2.0] # Strategic points
]))
x0_inits = x0_inits[:num_inits]
# Use logspace for w_inits to cover a broader range effectively
# Handle cases where log_w_min >= log_w_max (e.g., if w_min_bound is very large, or w_max_bound is small)
log_w_min = np.log10(w_min_bound) if w_min_bound > 0 else -10.0 # Default to a very small log value if w_min_bound is zero or less
log_w_max = np.log10(w_max_bound) if w_max_bound > 0 else 10.0 # Default to a very large log value
# Ensure log_w_min < log_w_max for logspace to work
if log_w_min >= log_w_max: # If bounds are problematic, create a sensible default range
log_w_min = np.log10(max(1e-6, w_min_bound))
log_w_max = np.log10(max(1e-6, w_max_bound))
if log_w_min >= log_w_max: # If still an issue, make a tiny range
log_w_max = log_w_min + 1.0 # Create a small range for logspace
w_inits = np.unique(np.concatenate([
np.logspace(log_w_min, log_w_max, num_inits // 4),
10**np.random.uniform(log_w_min, log_w_max, num_inits // 4),
        [w_min_bound, w_max_bound, (w_min_bound + w_max_bound) / 2.0, data_range / 2.0] # Strategic points; values outside the bounds are clamped before optimization
]))
w_inits = w_inits[w_inits >= 1e-9] # Ensure w is positive
w_inits = w_inits[:num_inits]
    # Cycle through the generated guesses with modulo indexing so that each of the
    # num_inits optimization attempts starts from a different (A, x0, w) combination.
    num_A = len(A_inits)
    num_x0 = len(x0_inits)
    num_w = len(w_inits)
    for i in range(num_inits):
current_A_init = A_inits[i % num_A]
current_x0_init = x0_inits[i % num_x0]
current_w_init = w_inits[i % num_w]
initial_params = [current_A_init, current_x0_init, current_w_init, B_base, C_base]
# Ensure initial_params respect bounds before optimization to prevent ValueErrors
initial_params_clamped = []
for j, (lower, upper) in enumerate(bounds):
clamped_val = initial_params[j]
if lower is not None:
clamped_val = max(clamped_val, lower)
if upper is not None:
clamped_val = min(clamped_val, upper)
initial_params_clamped.append(clamped_val)
# Nudge 'w' slightly above its minimum bound if it's right on it, to avoid numerical instability
if initial_params_clamped[2] <= bounds[2][0]: # Check for <= to catch values exactly at the bound
initial_params_clamped[2] = bounds[2][0] + 1e-9
try:
result = minimize(objective, initial_params_clamped, method='L-BFGS-B', bounds=bounds,
options={'maxiter': 5000, 'ftol': 1e-9, 'gtol': 1e-9, 'disp': False})
# Check for successful convergence and finite parameters
if result.success and np.all(np.isfinite(result.x)) and result.fun < best_mse:
best_mse = result.fun
best_params = result.x
        except Exception:
            # Skip this starting point on any numerical failure during optimization
            # (e.g., bound issues or singular curvature estimates).
            continue
# Fallback: If no successful optimization found after multiple attempts,
# perform one final robust optimization with a central initial guess.
if best_params is None:
# For debugging: print(f"Warning: Multiple initializations failed. Attempting robust fallback.")
fallback_A_init = A_base
fallback_x0_init = x0_peak_init_heuristic
# Use log-midpoint for fallback_w_init if log_w_min < log_w_max, otherwise use linear midpoint
if log_w_min < log_w_max:
fallback_w_init = 10**((log_w_min + log_w_max) / 2.0)
else:
fallback_w_init = (w_min_bound + w_max_bound) / 2.0
initial_params_fallback = [fallback_A_init, fallback_x0_init, fallback_w_init, B_base, C_base]
# Ensure fallback parameters respect bounds
initial_params_clamped_fallback = []
for j, (lower, upper) in enumerate(bounds):
clamped_val = initial_params_fallback[j]
if lower is not None:
clamped_val = max(clamped_val, lower)
if upper is not None:
clamped_val = min(clamped_val, upper)
initial_params_clamped_fallback.append(clamped_val)
# Nudge 'w' slightly above its minimum bound for fallback as well
if initial_params_clamped_fallback[2] <= bounds[2][0]:
initial_params_clamped_fallback[2] = bounds[2][0] + 1e-9
result_fallback = minimize(objective, initial_params_clamped_fallback, method='L-BFGS-B', bounds=bounds,
options={'maxiter': 5000, 'ftol': 1e-9, 'gtol': 1e-9, 'disp': False})
if result_fallback.success and np.all(np.isfinite(result_fallback.x)):
best_params = result_fallback.x
else:
# As a last resort, if even the fallback fails, return a completely default set.
# For debugging: print(f"Warning: Fallback optimization failed. Message: {result_fallback.message}. Returning clamped initial parameters.")
best_params = np.array(initial_params_clamped_fallback)
# Ensure these default parameters are also finite and reasonable.
if not np.all(np.isfinite(best_params)):
best_params = np.array([0.01, 0.0, 1.0, 0.0, 0.0]) # Absolute default if clamping somehow failed
return best_params
# EVOLVE-BLOCK-END
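
# Usage sketch (outside the evolved block): fit the law to synthetic data generated
# from known parameters plus noise, then compare fitted against true values. All
# numbers here are hypothetical, chosen only to exercise the API; they do not
# correspond to any real benchmark or finetuning run.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    true_params = [0.5, 11.0, 1.2, -0.04, -0.15]           # [A, x0, w, B, C], illustrative
    log_flops = np.linspace(8.0, 15.0, 30).reshape(-1, 1)  # (N, 1) shape expected by the API
    brier = scaling_law_func(log_flops, true_params) + rng.normal(0.0, 0.005, 30)
    fitted = fit_scaling_law(log_flops, brier)
    print("true:  ", np.round(true_params, 3))
    print("fitted:", np.round(fitted, 3))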