import numpy as np
from scipy.optimize import least_squares
def scaling_law_func(data_points, params):
    # data_points: (N, 3) array of [unique_tokens, model_params, tokens]
    # params: length-7 array [E, A, alpha, B, beta, C, delta], or a (K, 7) batch of parameter sets
# Ensure inputs are at least 2D
X = np.atleast_2d(np.asarray(data_points))
P = np.asarray(params)
# Handle batch of parameters vs single set
squeeze_output = False
if P.ndim == 1:
P = P[None, :]
squeeze_output = True
# Extract features
unique_tokens = X[:, 0:1]
model_params = X[:, 1:2]
tokens = X[:, 2:3]
    # Normalization constants (rough geometric means of the expected domain:
    # N ~ 3e8 parameters, D ~ 3e10 tokens)
    # Centering the inputs near 1.0 keeps the power-law terms well-scaled for the optimizer
N_SCALE = 3e8
D_SCALE = 3e10
n = model_params / N_SCALE
d = tokens / D_SCALE
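    # Example: a 3e8-parameter model trained on 3e10 tokens maps to n = 1.0 and d = 1.0,
    # so each power-law term starts out roughly at the size of its coefficient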
    # Repetition ratio r = D / U (roughly the number of epochs over the unique data);
    # the small epsilon guards against division by zero
    r = tokens / (unique_tokens + 1e-9)
# Extract Parameters (enforce positive via abs)
E = np.abs(P[:, 0])
A = np.abs(P[:, 1])
alpha = np.abs(P[:, 2])
B = np.abs(P[:, 3])
beta = np.abs(P[:, 4])
C = np.abs(P[:, 5])
delta = np.abs(P[:, 6])
    # Functional form (standard Chinchilla terms plus a repetition penalty):
    #   L = E + A * n^-alpha + B * d^-beta + C * r^delta
    # where n and d are the normalized parameter/token counts and r is the repetition ratio
# Broadcasting: (N, 1) and (1, K) -> (N, K)
# Add small epsilons to bases to prevent NaN gradients/values
term_model = A[None, :] * ((n + 1e-12) ** -alpha[None, :])
term_data = B[None, :] * ((d + 1e-12) ** -beta[None, :])
term_rep = C[None, :] * ((r + 1e-12) ** delta[None, :])
pred = E[None, :] + term_model + term_data + term_rep
if squeeze_output:
return pred[:, 0]
return pred
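
# Illustrative call (shapes only; the parameter values below are made up, not fitted):
#   X = np.array([[3e10, 3e8, 3e10],     # 30B unique tokens, 300M params, 30B training tokens
#                 [3e10, 3e8, 1.2e11]])  # the same unique data repeated ~4x
#   p = np.array([1.7, 0.8, 0.35, 0.6, 0.3, 0.02, 1.2])
#   scaling_law_func(X, p)               # -> shape (2,) array of predicted losses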
def fit_scaling_law(data_points, loss_values):
    # data_points: (N, 3) array of [unique_tokens, model_params, tokens]
    # loss_values: (N,) array of observed losses
    # Returns the fitted linear-space parameters [E, A, alpha, B, beta, C, delta]
    X = np.asarray(data_points)
    y = np.asarray(loss_values)
y_min = np.min(y)
    # Optimize the coefficients A, B, C in log-space: this keeps them strictly positive
    # and handles their very different scales (e.g. C might be ~1e-5 while A is ~1.0)
    # p_log layout: [E, logA, alpha, logB, beta, logC, delta]
def residuals(p_log):
# Convert log-space params back to linear for function evaluation
p_lin = np.array([
p_log[0], # E
np.exp(p_log[1]), # A
p_log[2], # alpha
np.exp(p_log[3]), # B
p_log[4], # beta
np.exp(p_log[5]), # C
p_log[6] # delta
])
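        # least_squares minimizes 0.5 * sum(rho(r_i**2)) over this residual vector,
        # so returning the raw (N,) residuals is all that is needed here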
return scaling_law_func(X, p_lin) - y
# Heuristic Initialization Strategy
# We provide guesses for linear parameters, then convert to log space
# [E, A, alpha, B, beta, C, delta]
starts = [
# Balanced Chinchilla
[1.6, 1.0, 0.33, 1.0, 0.33, 0.001, 1.0],
# Steep scaling
[1.5, 5.0, 0.5, 5.0, 0.5, 1e-4, 0.5],
# High repetition penalty
[1.8, 0.5, 0.3, 0.5, 0.3, 0.1, 2.0],
# Data limited
[1.6, 0.1, 0.1, 2.0, 0.6, 0.01, 1.0],
# Model limited
[1.6, 2.0, 0.6, 0.1, 0.1, 0.01, 1.0],
# Conservative / Flat
[y_min*0.9, 1.0, 0.1, 1.0, 0.1, 0.0, 0.0],
]
    # Bounds for the optimization variables
    # E: [0.5, ~y_min] - the irreducible loss should sit below any observed loss
    # logA, logB, logC: unbounded
    # alpha, beta: [0, 3]
    # delta: [0, 10]
    # Cap E slightly below y_min to force the scaling terms to explain the variance;
    # the floor of 0.9 keeps the upper bound above the lower bound if y_min is unusually small
    upper_E = max(0.9, y_min - 0.01)
lower_bounds = [0.5, -np.inf, 0.0, -np.inf, 0.0, -np.inf, 0.0]
upper_bounds = [upper_E, np.inf, 3.0, np.inf, 3.0, np.inf, 10.0]
best_res = None
best_cost = float('inf')
for s in starts:
# Convert start to log space
p0 = np.array(s)
        # Clip the starting E into the feasible interval defined by the bounds
        p0[0] = np.clip(p0[0], 0.55, upper_E - 0.05)
p_log_start = np.array([
p0[0],
np.log(p0[1] + 1e-16),
p0[2],
np.log(p0[3] + 1e-16),
p0[4],
np.log(p0[5] + 1e-16),
p0[6]
])
try:
            # Use the soft_l1 loss for robustness to outliers: with f_scale=0.1,
            # residuals well below 0.1 contribute roughly quadratically to the cost
            # and residuals above 0.1 contribute roughly linearly
res = least_squares(residuals, p_log_start,
bounds=(lower_bounds, upper_bounds),
method='trf',
loss='soft_l1',
f_scale=0.1,
max_nfev=1000)
if res.cost < best_cost:
best_cost = res.cost
best_res = res
except Exception:
continue
if best_res is not None:
p = best_res.x
return np.array([
p[0], np.exp(p[1]), p[2], np.exp(p[3]), p[4], np.exp(p[5]), p[6]
])
    # Fallback: if every fit attempt failed, return the first (linear-space) initial guess
    return np.array(starts[0])
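
# --- Example usage: an illustrative sketch, not part of the fitting code above ---
# Synthetic data is generated from a hypothetical "true" parameter set purely for
# demonstration; real runs would pass measured (unique_tokens, params, tokens) rows
# and their observed losses.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_points = 64
    # Random training configurations: unique tokens, model parameters, total training tokens
    unique_tokens = rng.uniform(1e9, 5e10, n_points)
    model_params = rng.uniform(1e7, 1e10, n_points)
    tokens = unique_tokens * rng.uniform(1.0, 10.0, n_points)  # 1-10 epochs over the unique data
    X_demo = np.stack([unique_tokens, model_params, tokens], axis=1)

    # Hypothetical ground-truth parameters [E, A, alpha, B, beta, C, delta]
    true_params = np.array([1.7, 0.8, 0.35, 0.6, 0.3, 0.02, 1.2])
    y_demo = scaling_law_func(X_demo, true_params) + rng.normal(0.0, 0.01, n_points)

    fitted = fit_scaling_law(X_demo, y_demo)
    print("true:  ", true_params)
    print("fitted:", fitted)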