# EVOLVE-BLOCK-START
"""
Scaling law discovery for LLM finetuning scenarios.
Models the scaling behavior as a smooth transition between two linear regimes in log-flops space.
This functional form (Sigmoid-weighted Broken Stick) captures monotonic, U-shaped, and inverted U-shaped patterns.
Uses 6 parameters: [slope1, bias1, slope2, bias2, transition_point, sharpness].
Fitting centers the inputs and seeds the optimizer from several initial guesses (quadratic, split-segment, and single-line) so that U-shaped curves are detected robustly.
"""
import numpy as np
from scipy.optimize import minimize
def scaling_law_func(data_points, params):
    """Evaluate the sigmoid-weighted broken-stick scaling law.

    The prediction blends two linear regimes in log-flops space::

        sig(x)  = 1 / (1 + exp(-s * (x - m)))
        pred(x) = (1 - sig(x)) * (w1 * x + b1) + sig(x) * (w2 * x + b2)

    Args:
        data_points: log-flops values. Either a 2-D array of shape (N, k), of
            which only the first column is used, or a 1-D array of N values
            (one point per element). A scalar is treated as a single point.
        params: one parameter vector of shape (6,) ordered
            [w1, b1, w2, b2, m, s], or a (T, 6) array holding T parameter sets.

    Returns:
        Predictions of shape (N,) when a single parameter set is given
        (shape (6,) or (1, 6)), otherwise shape (N, T).
    """
    x = np.asarray(data_points, dtype=float)
    if x.ndim <= 1:
        # A flat array is N points with one feature each. np.atleast_2d would
        # turn it into a single (1, N) row and silently keep only x[0].
        x = x.reshape(-1, 1)
    else:
        x = x[:, 0:1]  # (N, 1)

    p = np.asarray(params, dtype=float)
    if p.ndim == 1:
        p = p[None, :]  # (1, 6)

    # Each parameter becomes a (1, T) row so it broadcasts against x of shape (N, 1).
    w1, b1, w2, b2, m, s = (p[:, k][None, :] for k in range(6))

    # Clip the logit so exp() can neither overflow nor underflow.
    z = np.clip(s * (x - m), -50.0, 50.0)
    sig = 1.0 / (1.0 + np.exp(-z))

    pred = (1.0 - sig) * (w1 * x + b1) + sig * (w2 * x + b2)  # (N, T)
    return pred[:, 0] if pred.shape[1] == 1 else pred
def _broken_stick_centered(p, x):
    """Evaluate the broken-stick model for one parameter vector on a 1-D array x.

    p is [w1, b1, w2, b2, m, s]; x is expected to be centered log-flops.
    """
    z = np.clip(p[5] * (x - p[4]), -50.0, 50.0)  # keep exp() finite
    sig = 1.0 / (1.0 + np.exp(-z))
    return (1.0 - sig) * (p[0] * x + p[1]) + sig * (p[2] * x + p[3])


def _line_fit(x, y, fallback_bias):
    """Least-squares (slope, intercept); a flat line at fallback_bias if it cannot be fit."""
    if x.size >= 2:
        try:
            w, b = np.polyfit(x, y, 1)
            return float(w), float(b)
        except (np.linalg.LinAlgError, ValueError):
            pass
    return 0.0, fallback_bias


def _initial_guesses(x_c, yt):
    """Build starting points [w1, b1, w2, b2, m, s] (centered space) for the optimizer."""
    x_min, x_max = x_c.min(), x_c.max()
    y_mean = float(np.mean(yt))
    candidates = []

    # 1. Quadratic fit (suits U / inverted-U shapes): the vertex is the transition
    #    point and the tangent lines at the two edges are the regimes.
    try:
        c2, c1, c0 = np.polyfit(x_c, yt, 2)
    except (np.linalg.LinAlgError, ValueError):
        pass
    else:
        m_quad = np.clip(-c1 / (2 * c2), x_min, x_max) if abs(c2) > 1e-5 else 0.0
        w1_q = 2 * c2 * x_min + c1
        w2_q = 2 * c2 * x_max + c1
        # The tangent through (x0, y(x0)) has intercept b = y(x0) - w * x0.
        b1_q = (c2 * x_min**2 + c1 * x_min + c0) - w1_q * x_min
        b2_q = (c2 * x_max**2 + c1 * x_max + c0) - w2_q * x_max
        candidates.append([w1_q, b1_q, w2_q, b2_q, m_quad, 5.0])

    # 2. Split fits (suit V-shapes / broken sticks): independent lines on each side
    #    of the 33rd and 66th percentiles.
    for pct in (33, 66):
        split_x = np.percentile(x_c, pct)
        left = x_c <= split_x
        right = x_c > split_x
        wl, bl = _line_fit(x_c[left], yt[left], y_mean)
        wr, br = _line_fit(x_c[right], yt[right], y_mean)
        candidates.append([wl, bl, wr, br, split_x, 5.0])

    # 3. A single line with both regimes identical: a safe start for monotonic data.
    w, b = _line_fit(x_c, yt, y_mean)
    candidates.append([w, b, w, b, 0.0, 1.0])
    return candidates


def _fit_single_target(x_c, yt):
    """Return the best centered parameter vector [w1, b1, w2, b2, m, s] for one curve."""
    x_min, x_max = x_c.min(), x_c.max()
    y_mean = float(np.mean(yt))
    # Normalise the MSE by the target variance. L-BFGS-B's ftol test is relative
    # to max(|f|, 1), so a raw MSE far below 1 would stop the optimizer early.
    scale = float(np.var(yt))
    if not np.isfinite(scale) or scale <= 0.0:
        scale = 1.0

    def objective(p):
        return np.mean((_broken_stick_centered(p, x_c) - yt) ** 2) / scale

    # Slopes and intercepts are free. The transition stays near the data range,
    # and the sharpness is kept positive.
    bounds = [(None, None)] * 4 + [(x_min - 0.5, x_max + 0.5), (0.1, 100.0)]

    best_loss = np.inf
    best_p = np.array([0.0, y_mean, 0.0, y_mean, 0.0, 1.0])  # flat-line fallback
    for p0 in _initial_guesses(x_c, yt):
        p0 = np.asarray(p0, dtype=float)
        if not np.all(np.isfinite(p0)):
            continue  # e.g. NaN from a degenerate fit; not a usable starting point
        try:
            res = minimize(objective, p0, method='L-BFGS-B', bounds=bounds, tol=1e-6)
        except (ValueError, ArithmeticError, np.linalg.LinAlgError):
            continue
        if np.isfinite(res.fun) and res.fun < best_loss:
            best_loss = res.fun
            best_p = res.x
    return best_p


def fit_scaling_law(data_points, loss_values):
    """Fit the sigmoid-weighted broken-stick scaling law to one or more loss curves.

    Fitting happens on mean-centered log-flops for better conditioning. The
    result is mapped back to the original x space.

    Args:
        data_points: log-flops values. Either a 2-D array of shape (N, k), of
            which only the first column is used, or a 1-D array of N values.
        loss_values: shape (N,) for a single curve, or (N, T) for T curves
            that share the same x values.

    Returns:
        Parameters [w1, b1, w2, b2, m, s] in the original x space: shape (6,)
        when T == 1, otherwise (T, 6).

    Raises:
        ValueError: if there are no data points, or the number of x values does
            not match the number of rows of ``loss_values``.
    """
    x = np.asarray(data_points, dtype=float)
    # A 1-D array is N points. np.atleast_2d would make it a single row.
    x_flat = x.reshape(-1) if x.ndim <= 1 else x[:, 0]

    y = np.asarray(loss_values, dtype=float)
    y_2d = y[:, None] if y.ndim == 1 else y
    N, T = y_2d.shape
    if N == 0:
        raise ValueError("fit_scaling_law needs at least one data point")
    if x_flat.shape[0] != N:
        raise ValueError(f"got {x_flat.shape[0]} data points but {N} loss values")

    x_mean = np.mean(x_flat)
    x_c = x_flat - x_mean

    fitted = []
    for t in range(T):
        w1, b1, w2, b2, m_c, s = _fit_single_target(x_c, y_2d[:, t])
        # Undo the centering:
        #   w * (x - x_mean) + b       == w * x + (b - w * x_mean)
        #   s * ((x - x_mean) - m_c)   == s * (x - (m_c + x_mean))
        fitted.append([w1, b1 - w1 * x_mean, w2, b2 - w2 * x_mean, m_c + x_mean, s])

    fitted = np.array(fitted)
    return fitted[0] if T == 1 else fitted
# EVOLVE-BLOCK-END