from __future__ import annotations
from math import log10
from typing import Dict, List
# Quadratic-in-log scaling law with key interactions for language modeling loss.
#
# Let ld = log10(data_size), lp = log10(non_embedding_param_size),
# llr = log10(lr), lb = log10(bsz).
#
# lm_loss = c0
#           + c1 * ld
#           + c2 * lp
#           + c3 * llr
#           + c4 * lb
#           + c5 * llr**2
#           + c6 * ld * lp
#           + c7 * ld**2
#           + c8 * lp**2
#           + c9 * llr * ld
#           + c10 * llr * lp
#           + c11 * lb * ld
#           + c12 * lb * lp
#           + c13 * lb * llr
#           + c14 * lb**2
#
# Coefficients are fitted per experimental group. If an unknown group is
# requested, we fall back to the 'all_data' coefficients.
_COEFFS_BY_GROUP: Dict[str, List[float]] = {
# Order:
# [c0, c1(ld), c2(lp), c3(llr), c4(lb), c5(llr^2), c6(ld*lp), c7(ld^2), c8(lp^2),
# c9(llr*ld), c10(llr*lp), c11(lb*ld), c12(lb*lp), c13(lb*llr), c14(lb^2)]
# Fitted on the provided dataset (/app/data)
# Using least squares on 2702 points, R^2 ≈ 0.977 (5-fold CV ≈ 0.976)
"all_data": [
1.681388886e01, # c0
-2.14226036e00, # c1 (ld)
-3.48992730e-01, # c2 (lp)
2.62425420e-01, # c3 (llr)
9.04917660e-01, # c4 (lb)
1.48530750e-01, # c5 (llr^2)
-8.06989200e-02, # c6 (ld*lp)
1.35736300e-01, # c7 (ld^2)
7.86298100e-02, # c8 (lp^2)
-2.47657100e-02, # c9 (llr*ld)
1.22298120e-01, # c10 (llr*lp)
-1.23088430e-01, # c11 (lb*ld)
-5.30003800e-02, # c12 (lb*lp)
-8.19605000e-02, # c13 (lb*llr)
1.26955570e-01, # c14 (lb^2)
],
}
# Default/fallback coefficients
_DEFAULT_GROUP = "all_data"
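

# The helper below is an illustrative sketch, not the original fitting code:
# it shows how coefficients in the order documented above could be refit with
# ordinary least squares. It assumes numpy is available and that each row
# carries the same input keys as `law` plus an observed 'lm_loss' target.
def _fit_coeffs(rows: List[Dict[str, float]]) -> List[float]:
    """Refit the 15 coefficients via least squares (illustrative only)."""
    import numpy as np  # assumed dependency; imported lazily on purpose

    X: List[List[float]] = []
    y: List[float] = []
    for r in rows:
        ld = log10(r["data_size"])
        lp = log10(r["non_embedding_param_size"])
        llr = log10(r["lr"])
        lb = log10(r["bsz"])
        # Column order must match _COEFFS_BY_GROUP: c0 .. c14.
        X.append([1.0, ld, lp, llr, lb, llr**2, ld * lp, ld**2, lp**2,
                  llr * ld, llr * lp, lb * ld, lb * lp, lb * llr, lb**2])
        y.append(r["lm_loss"])
    coeffs, *_ = np.linalg.lstsq(np.asarray(X), np.asarray(y), rcond=None)
    return coeffs.tolist()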
def _safe_log10(x: float) -> float:
"""Compute log10 with a tiny positive floor for numerical safety.
The dataset and expected inputs should be strictly positive for all variables,
but we guard against accidental non-positive inputs by flooring to a tiny
positive value to avoid math domain errors and keep the function robust.
"""
# Floor near double-precision minimum, but not too extreme to avoid inf
tiny = 1e-300
if not isinstance(x, (int, float)):
raise TypeError(f"Expected a number, got {type(x)}")
if x <= 0 or x != x: # also handles NaN
x = tiny
return log10(x)
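

# Example behavior of the floor above: _safe_log10(10000.0) returns 4.0,
# while _safe_log10(0.0) and _safe_log10(float("nan")) both floor to
# log10(1e-300) = -300.0 instead of raising.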
def _predict_row(row: Dict[str, float], coeffs: List[float]) -> float:
ld = _safe_log10(float(row["data_size"]))
lp = _safe_log10(float(row["non_embedding_param_size"]))
llr = _safe_log10(float(row["lr"]))
lb = _safe_log10(float(row["bsz"]))
(
c0, c1, c2, c3, c4,
c5, c6, c7, c8,
c9, c10, c11, c12, c13, c14,
) = coeffs
y = (
c0
+ c1 * ld
+ c2 * lp
+ c3 * llr
+ c4 * lb
+ c5 * (llr ** 2)
+ c6 * ld * lp
+ c7 * (ld ** 2)
+ c8 * (lp ** 2)
+ c9 * llr * ld
+ c10 * llr * lp
+ c11 * lb * ld
+ c12 * lb * lp
+ c13 * lb * llr
+ c14 * (lb ** 2)
)
return float(y)
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
"""
Predicts output variables based on input variables according to a discovered scaling law.
Args:
input_data: A list of dictionaries, where each dictionary is a single data
point containing input variable names as keys and their
corresponding values. Required keys per dict:
- 'lr'
- 'bsz'
- 'data_size'
- 'non_embedding_param_size'
group: The name of the experimental group for which to make predictions.
The functional form of the law is the same for all groups, but
the constant parameters/coefficients can differ per group.
    Returns:
        A list of dictionaries, parallel to input_data, each containing one key:
            - 'lm_loss': the predicted language modeling loss.
"""
coeffs = _COEFFS_BY_GROUP.get(group, _COEFFS_BY_GROUP[_DEFAULT_GROUP])
outputs: List[Dict[str, float]] = []
for row in input_data:
y = _predict_row(row, coeffs)
outputs.append({"lm_loss": y})
return outputs
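

if __name__ == "__main__":
    # Minimal usage sketch; the input values below are made up for
    # illustration and do not come from the fitted dataset.
    sample = [{
        "lr": 3e-4,
        "bsz": 256.0,
        "data_size": 1e9,
        "non_embedding_param_size": 1e8,
    }]
    print(law(sample, "all_data"))  # -> [{'lm_loss': ...}]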