import numpy as np
from scipy.optimize import minimize
def scaling_law_func(data_points, params):
    """
    Predict unigram-normalized loss (Lossu) from a two-term 3D power law
    with a synergy (joint-exponent) main term and an additive floor:

        Lossu = C0 * exp[-(a*logP + b*logD + g*logV)]
              + C1 * exp[-(h*logD)]
              + C2

    Parameters
    ----------
    data_points : array-like, shape (N, 3)
        Columns are [P_non_vocab, Vocab_size, Num_characters].
    params : array-like, length 7
        [C0, C1, C2, a, b, g, h].

    Returns
    -------
    ndarray, length N
        Predicted Lossu for each row of `data_points`.
    """
    pts = np.atleast_2d(data_points)
    # Column order is P, V, D; all must be strictly positive for the logs.
    log_p, log_v, log_d = (np.log(pts[:, j]) for j in range(3))
    c0, c1, c2, a, b, g, h = params
    # Main term couples all three scales through a shared exponent.
    main_term = c0 * np.exp(-(a * log_p + b * log_d + g * log_v))
    # Secondary term depends on data size alone; c2 is the irreducible floor.
    data_term = c1 * np.exp(-h * log_d)
    return main_term + data_term + c2
def fit_scaling_law(data_points, loss_values):
    """
    Fit the 7-parameter scaling law by a three-stage procedure:

    1) Optimize the exponents (a, b, g, h) in log-space (which keeps them
       positive), solving for [C0, C1, C2] by linear least squares at each
       objective evaluation.
    2) Recover the full 7-vector initial estimate from stage 1.
    3) Refine all 7 parameters jointly with bounded L-BFGS-B to further
       reduce the MSE.

    Parameters
    ----------
    data_points : array-like, shape (N, 3)
        Columns are [P_non_vocab, Vocab_size, Num_characters].
    loss_values : array-like, length N
        Observed Lossu values to fit.

    Returns
    -------
    ndarray, length 7
        Fitted parameters [C0, C1, C2, a, b, g, h].
    """
    X = np.atleast_2d(data_points)
    targets = np.asarray(loss_values).ravel()
    log_p = np.log(X[:, 0])
    log_v = np.log(X[:, 1])
    log_d = np.log(X[:, 2])

    def design_matrix(a, b, g, h):
        # Basis columns for the two exponential terms; the ones column
        # absorbs the additive constant C2.
        col1 = np.exp(-(a * log_p + b * log_d + g * log_v))
        col2 = np.exp(-h * log_d)
        return np.column_stack((col1, col2, np.ones_like(targets)))

    # Stage 1: search over exponents only; coefficients come for free from
    # a linear least-squares solve at each candidate exponent vector.
    def exponent_objective(log_exps):
        M = design_matrix(*np.exp(log_exps))
        coeffs = np.linalg.lstsq(M, targets, rcond=None)[0]
        residual = M.dot(coeffs) - targets
        return np.mean(residual ** 2)

    stage1 = minimize(exponent_objective, np.log([0.5] * 4), method='L-BFGS-B')
    exps = np.exp(stage1.x)

    # Stage 2: recover [C0, C1, C2] at the optimized exponents.
    linear_coeffs = np.linalg.lstsq(design_matrix(*exps), targets, rcond=None)[0]
    start = np.concatenate([linear_coeffs, exps])

    # Stage 3: joint refinement of all seven parameters.
    def full_objective(theta):
        c0, c1, c2, a, b, g, h = theta
        preds = (c0 * np.exp(-(a * log_p + b * log_d + g * log_v))
                 + c1 * np.exp(-h * log_d)
                 + c2)
        return np.mean((preds - targets) ** 2)

    # Exponents bounded away from zero; coefficients remain unconstrained.
    bounds = [(None, None)] * 3 + [(1e-12, None)] * 4
    stage2 = minimize(full_objective, start, method='L-BFGS-B', bounds=bounds)
    return stage2.x