# EVOLVE-BLOCK-START
import numpy as np
from scipy.optimize import least_squares
def scaling_law_func(data_points, params):
"""
Predicts LM loss based on learning rate, batch size, data size, and non-embedding parameter size.
The model is of the form:
        Loss = L_0 + c_lr1 * lr^e_lr1 + c_lr2 * lr^e_lr2 + c_bsz * bsz^e_bsz
               + c_data * data_size^e_data + c_params * non_embedding_param_size^e_params
This model uses two learning rate terms to capture a U-shaped or more complex relationship,
where one term typically models the benefit of increasing LR (e_lr1 < 0) and the other
the detriment (e_lr2 > 0).
Args:
data_points (np.ndarray): (N, 4) array with columns [lr, bsz, data_size, non_embedding_param_size].
params (np.ndarray): Array of model parameters. Can be (P,) for a single model or (T, P) for multiple.
Expected P=11: [L_0, c_lr1, e_lr1, c_lr2, e_lr2, c_bsz, e_bsz, c_data, e_data, c_params, e_params].
Returns:
        np.ndarray: Predicted LM loss values. Shape (N,) if params is (P,), or (N, T) if params is (T, P).
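    Example (illustrative only; the values below are placeholders, not real training runs):
        X = np.array([[1e-3, 256, 1e9, 1e8],
                      [3e-4, 512, 2e9, 3e8]])       # (N=2, 4) inputs
        theta = np.ones(11)                          # one parameter set, shape (P,) = (11,)
        scaling_law_func(X, theta)                   # returns shape (2,)
        scaling_law_func(X, np.ones((5, 11)))        # T=5 parameter sets -> shape (2, 5)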
"""
X = np.atleast_2d(np.asarray(data_points)) # (N, F)
# Ensure all inputs are positive before log/power to prevent numerical issues
X = np.maximum(X, 1e-10)
params_arr = np.asarray(params)
    # Support both a single parameter vector of shape (P,) and a batch of T parameter sets of shape (T, P)
if params_arr.ndim == 1:
params_arr = params_arr[None, :] # Make it (1, P)
T, P = params_arr.shape # T: number of parameter sets, P: number of parameters per set
# Expected number of parameters for this specific model structure
# 1 (L0) + 4 (lr terms) + 2 (bsz) + 2 (data_size) + 2 (params_size) = 11
EXPECTED_P = 11
if P != EXPECTED_P:
if P > EXPECTED_P:
# If more parameters are passed than expected, use only the first EXPECTED_P
params_arr = params_arr[:, :EXPECTED_P]
P = EXPECTED_P
else:
# If fewer parameters are passed, it's an error in model definition or parameter passing.
raise ValueError(f"Expected {EXPECTED_P} parameters per set for the scaling law model, but received {P}. "
"Please check the number of parameters defined in the model structure.")
# Extract parameters for each parameter set (T sets)
L0_arr = params_arr[:, 0]
c_lr1_arr, e_lr1_arr = params_arr[:, 1], params_arr[:, 2]
c_lr2_arr, e_lr2_arr = params_arr[:, 3], params_arr[:, 4]
c_bsz_arr, e_bsz_arr = params_arr[:, 5], params_arr[:, 6]
c_data_arr, e_data_arr = params_arr[:, 7], params_arr[:, 8]
c_params_arr, e_params_arr = params_arr[:, 9], params_arr[:, 10]
    # Clamp coefficients to a small positive floor so each power-law term stays
    # non-negative, guarding against floating-point issues or optimizer edge cases.
c_lr1_arr = np.maximum(c_lr1_arr, 1e-10)
c_lr2_arr = np.maximum(c_lr2_arr, 1e-10)
c_bsz_arr = np.maximum(c_bsz_arr, 1e-10)
c_data_arr = np.maximum(c_data_arr, 1e-10)
c_params_arr = np.maximum(c_params_arr, 1e-10)
# Calculate individual contributions using log-space for numerical stability,
# then exponentiate. This is generally preferred over direct X**e for robustness
# with arbitrary real exponents and very small base values.
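    # (For x > 0, x**e == exp(e * log(x)); since X was floored at 1e-10 above, log(X) is
    # always finite and this identity applies to every feature column.)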
# X[:, feature_idx][:, None] makes it (N, 1) for broadcasting against (1, T) parameter arrays
log_X_lr = np.log(X[:, 0][:, None])
log_X_bsz = np.log(X[:, 1][:, None])
log_X_data = np.log(X[:, 2][:, None])
log_X_params = np.log(X[:, 3][:, None])
# Learning rate terms (U-shaped contribution)
term_lr1 = c_lr1_arr[None, :] * np.exp(e_lr1_arr[None, :] * log_X_lr)
term_lr2 = c_lr2_arr[None, :] * np.exp(e_lr2_arr[None, :] * log_X_lr)
term_lr = term_lr1 + term_lr2
# Other terms
term_bsz = c_bsz_arr[None, :] * np.exp(e_bsz_arr[None, :] * log_X_bsz)
term_data = c_data_arr[None, :] * np.exp(e_data_arr[None, :] * log_X_data)
term_params = c_params_arr[None, :] * np.exp(e_params_arr[None, :] * log_X_params)
# Sum all contributions
pred = L0_arr[None, :] + term_lr + term_bsz + term_data + term_params
# Ensure predictions are non-negative, as loss cannot be negative.
# This also helps clip any numerically unstable negative predictions that might arise.
pred = np.maximum(pred, 0.0)
# If only one set of parameters was passed (T=1), return a 1D array (N,)
return pred[:, 0] if T == 1 else pred
def fit_scaling_law(data_points, loss_values):
"""
Fits the scaling law function to the given data points and loss values.
Args:
data_points (np.ndarray): (N, 4) array with columns [lr, bsz, data_size, non_embedding_param_size].
loss_values (np.ndarray): (N,) array of corresponding lm loss values.
Returns:
np.ndarray: Optimized parameters (P,) for the scaling law function.
[L_0, c_lr1, e_lr1, c_lr2, e_lr2, c_bsz, e_bsz, c_data, e_data, c_params, e_params].
"""
X = np.atleast_2d(np.asarray(data_points))
y = np.asarray(loss_values)
def residuals(params, X, y):
pred = scaling_law_func(X, params)
res = pred - y
        # Robustly handle NaN/Inf predictions: replace any non-finite residual with a
        # large finite penalty so the optimizer is strongly pushed away from these regions.
        res = np.where(np.isfinite(res), res, 1e10)
return res
    # Total number of parameters: 1 (L0) + 4 (two LR terms) + 2 (bsz) + 2 (data) + 2 (params) = 11
P = 11
# Initial guess for parameters: [L_0, c_lr1, e_lr1, c_lr2, e_lr2, c_bsz, e_bsz, c_data, e_data, c_params, e_params]
# These initial guesses are informed by typical LLM scaling laws and data ranges,
# with a focus on capturing the U-shaped learning rate behavior.
initial_params = np.array([
np.min(y) * 0.9, # L_0: Irreducible loss, slightly below min observed loss
1e-4, -1.0, # c_lr1, e_lr1: For the decreasing loss part with increasing LR (e.g., 1/LR)
1e3, 1.0, # c_lr2, e_lr2: For the increasing loss part with increasing LR (e.g., LR)
0.1, 0.1, # c_bsz, e_bsz: Small effect, potentially slightly positive exponent for batch size
10.0, -0.1, # c_data, e_data: Data typically reduces loss (negative exponent)
5.0, -0.1 # c_params, e_params: Parameters typically reduce loss (negative exponent)
])
# Refined bounds for parameters to guide the optimizer and ensure physical realism.
# Coefficients (c_i) are generally positive. Exponents (e_i) are constrained based on expected effects.
lower_bounds = np.array([
0.0, # L_0: Irreducible loss must be non-negative
1e-10, -3.0, # c_lr1 (positive), e_lr1 (negative for 1/lr effect)
1e-10, 0.01, # c_lr2 (positive), e_lr2 (positive for lr effect)
1e-10, -1.0, # c_bsz (positive), e_bsz (can be negative or positive, but not too extreme)
1e-10, -1.0, # c_data (positive), e_data (negative or zero, increasing data should not increase loss)
1e-10, -1.0 # c_params (positive), e_params (negative or zero, increasing params should not increase loss)
])
upper_bounds = np.array([
        np.max(y) * 1.5, # L_0: capped at 1.5x the maximum observed loss
        1e2, -0.01, # c_lr1; e_lr1 capped at -0.01 so it stays negative
        1e5, 3.0, # c_lr2; e_lr2 capped at 3.0 (its positivity is enforced by the lower bound of 0.01)
1e3, 1.0, # c_bsz, e_bsz
1e4, 0.0, # c_data, e_data (<= 0)
1e4, 0.0 # c_params, e_params (<= 0)
])
    # Use the 'trf' (Trust Region Reflective) method, which handles bounds well and is robust
    # for non-linear least squares. verbose=0 suppresses convergence messages, max_nfev=5000
    # allows plenty of function evaluations for this complex landscape, and ftol/xtol are set
    # explicitly to 1e-8 for tight convergence tolerances.
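    # Note: scipy's least_squares also supports robust losses (e.g., loss='soft_l1' or
    # loss='huber') that down-weight large residuals; the default squared loss is kept here.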
result = least_squares(residuals, initial_params, args=(X, y),
bounds=(lower_bounds, upper_bounds),
method='trf', verbose=0, max_nfev=5000, ftol=1e-8, xtol=1e-8)
if result.success:
return result.x
else:
# If optimization fails, return the initial_params as a robust fallback.
print("Warning: least_squares optimization failed. Returning initial parameters.")
return initial_params
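# Minimal usage sketch (illustrative only): fit the law on synthetic data and predict.
# The data ranges and "true" parameters below are placeholders, not from any real training run.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n = 64
    demo_X = np.column_stack([
        10 ** rng.uniform(-4, -2, n),   # learning rate
        2.0 ** rng.integers(6, 10, n),  # batch size
        10 ** rng.uniform(8, 10, n),    # data size (tokens)
        10 ** rng.uniform(7, 9, n),     # non-embedding parameter count
    ])
    # Synthetic "ground truth" losses drawn from the model family itself, plus small noise.
    true_params = np.array([2.0, 1e-4, -0.5, 50.0, 1.0, 0.5, -0.1, 20.0, -0.2, 10.0, -0.2])
    demo_y = scaling_law_func(demo_X, true_params) + rng.normal(0, 0.01, n)
    fitted = fit_scaling_law(demo_X, demo_y)
    preds = scaling_law_func(demo_X, fitted)
    print("Fitted parameters:", np.round(fitted, 4))
    print("RMSE on synthetic data:", float(np.sqrt(np.mean((preds - demo_y) ** 2))))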
# EVOLVE-BLOCK-END