from __future__ import annotations
import math
from typing import Dict, List
def _predict_lm_loss(x: Dict[str, float], coef: Dict[str, float]) -> float:
    """Evaluate the log-polynomial loss model for a single data point.

    Let ll, lb, ld and lp be the base-10 logarithms of the ``lr``, ``bsz``,
    ``data_size`` and ``non_embedding_param_size`` entries of ``x``. The
    prediction is::

        intercept
        + log_lr * ll + log_bsz * lb
        + log_lr2 * ll*ll + log_bsz2 * lb*lb + lr_bsz * ll*lb
        + log_data * ld + log_params * lp

    where every weight is looked up by that name in ``coef``.

    A key absent from ``x`` is read as 0.0. Every input is floored at 1e-16
    before the logarithm, so missing, zero or negative inputs do not raise.
    Instead they all evaluate as log10(1e-16) (about -16), which silently
    yields a prediction rather than an error.

    Args:
        x: One data point, mapping input variable names to values. Each value
           must be convertible with ``float()``.
        coef: Model weights. Must contain all eight keys named above.

    Returns:
        The predicted loss as a Python float.

    Raises:
        KeyError: If ``coef`` is missing one of the weight keys.
        ValueError, TypeError: If a value in ``x`` cannot be converted by
            ``float()``.
    """
    floor = 1e-16

    def _clamped_log10(key: str) -> float:
        # Keep the argument order (value, floor): max() then propagates a
        # NaN value instead of replacing it with the floor.
        return math.log10(max(float(x.get(key, 0.0)), floor))

    log_lr = _clamped_log10("lr")
    log_bsz = _clamped_log10("bsz")
    log_data = _clamped_log10("data_size")
    log_params = _clamped_log10("non_embedding_param_size")

    # The order of this table fixes the left-to-right float summation order,
    # so results stay bit-for-bit reproducible.
    weighted_terms = (
        ("log_lr", log_lr),
        ("log_bsz", log_bsz),
        ("log_lr2", log_lr * log_lr),
        ("log_bsz2", log_bsz * log_bsz),
        ("lr_bsz", log_lr * log_bsz),
        ("log_data", log_data),
        ("log_params", log_params),
    )

    total = coef["intercept"]
    for weight_name, feature in weighted_terms:
        # Rebind instead of using +=, so a mutable coefficient value is never
        # modified in place.
        total = total + coef[weight_name] * feature
    return float(total)
# Coefficients discovered via log-polynomial regression on the provided dataset.
# Functional form is the same for all groups; coefficients may differ per group.
# Each inner dict must provide every key that _predict_lm_loss indexes:
#   intercept            - constant term
#   log_lr / log_bsz     - linear weights on log10(lr) / log10(bsz)
#   log_lr2 / log_bsz2   - weights on the squares of those logs
#   lr_bsz               - weight on the product log10(lr) * log10(bsz)
#   log_data             - weight on log10(data_size)
#   log_params           - weight on log10(non_embedding_param_size)
# Only the "all_data" group is defined here.
COEFFS_BY_GROUP: Dict[str, Dict[str, float]] = {
    "all_data": {
        "intercept": 9.919174347950008,
        "log_lr": 0.934534343690493,
        "log_bsz": -0.6179383648150774,
        "log_lr2": 0.1368417220658123,
        "log_bsz2": 0.09978735298702487,
        "lr_bsz": -0.060204456752825174,
        "log_data": -0.28033584602209644,
        "log_params": -0.30419462596816593,
    }
}
# Fallback group: law() uses this group's coefficients whenever the requested
# group is not a key of COEFFS_BY_GROUP. It is a fixed default, not a
# nearest-match lookup.
DEFAULT_GROUP: str = "all_data"
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict the language-model loss for each data point with the fitted scaling law.

    All groups share one functional form (evaluated by ``_predict_lm_loss``);
    only the coefficients differ per group. A ``group`` with no entry in
    ``COEFFS_BY_GROUP`` silently falls back to the coefficients of
    ``DEFAULT_GROUP``.

    Args:
        input_data: A list of dictionaries, one per data point, each mapping
                    input variable names to their values.
        group: Name of the experimental group whose coefficients are used.

    Returns:
        A list aligned with ``input_data``. Each element is a dictionary of
        the form ``{"lm_loss": <predicted value>}``.
    """
    fallback = COEFFS_BY_GROUP[DEFAULT_GROUP]
    coefficients = COEFFS_BY_GROUP.get(group, fallback)
    return [
        {"lm_loss": _predict_lm_loss(point, coefficients)}
        for point in input_data
    ]