from __future__ import annotations
import math
from typing import Dict, List
# Discovered scaling law:
# Let x1 = log10(lr), x2 = log10(bsz), x3 = log10(data_size), x4 = log10(non_embedding_param_size).
# Predict lm_loss as a quadratic polynomial in these log-variables with interactions.
# The functional form is the same across groups; coefficients may differ by group.
# Coefficients fitted on the provided dataset for group "all_data" using
# Ridge regression on quadratic polynomial features of the log-variables.
# Keys correspond to polynomial feature names.
# Fitted coefficients, keyed by group name and then by polynomial feature
# name. All groups share the same quadratic functional form; only the
# coefficient values may differ per group. Feature names here must match the
# keys produced in `_predict_one`'s `terms` dict.
_COEFFICIENTS_BY_GROUP: Dict[str, Dict[str, float]] = {
    "all_data": {
        # Intercept and linear terms in the log10-transformed variables.
        "1": 16.497915,
        "log_lr": 0.266742,
        "log_bsz": 0.907321,
        "log_data_size": -2.112344,
        "log_non_embedding_param_size": -0.308876,
        # Quadratic terms
        "log_lr^2": 0.148389,
        "log_bsz^2": 0.126924,
        "log_data_size^2": 0.134987,
        "log_non_embedding_param_size^2": 0.077240,
        # Pairwise interactions
        "log_lr log_bsz": -0.081928,
        "log_lr log_data_size": -0.024850,
        "log_lr log_non_embedding_param_size": 0.121794,
        "log_bsz log_data_size": -0.123098,
        "log_bsz log_non_embedding_param_size": -0.053240,
        "log_data_size log_non_embedding_param_size": -0.082462,
    }
}
# If an unknown group is provided, fall back to this group name
_DEFAULT_GROUP: str = "all_data"
def _predict_one(d: Dict[str, float], coeffs: Dict[str, float]) -> float:
# Extract and validate input variables
try:
lr = float(d["lr"])
bsz = float(d["bsz"])
data_size = float(d["data_size"])
non_emb_params = float(d["non_embedding_param_size"])
except KeyError as e:
raise KeyError(f"Missing required key: {e.args[0]}")
if lr <= 0 or bsz <= 0 or data_size <= 0 or non_emb_params <= 0:
raise ValueError("All inputs must be positive to compute logarithms.")
# Log10 transform
log_lr = math.log10(lr)
log_bsz = math.log10(bsz)
log_data_size = math.log10(data_size)
log_non_emb = math.log10(non_emb_params)
# Compute polynomial terms
terms = {
"1": 1.0,
"log_lr": log_lr,
"log_bsz": log_bsz,
"log_data_size": log_data_size,
"log_non_embedding_param_size": log_non_emb,
"log_lr^2": log_lr * log_lr,
"log_bsz^2": log_bsz * log_bsz,
"log_data_size^2": log_data_size * log_data_size,
"log_non_embedding_param_size^2": log_non_emb * log_non_emb,
"log_lr log_bsz": log_lr * log_bsz,
"log_lr log_data_size": log_lr * log_data_size,
"log_lr log_non_embedding_param_size": log_lr * log_non_emb,
"log_bsz log_data_size": log_bsz * log_data_size,
"log_bsz log_non_embedding_param_size": log_bsz * log_non_emb,
"log_data_size log_non_embedding_param_size": log_data_size * log_non_emb,
}
# Weighted sum
y = 0.0
for name, val in terms.items():
coef = coeffs.get(name, 0.0)
y += coef * val
return y
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.
    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law must be the same for all groups,
               but the constant parameters/coefficients can differ per group.
    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    # Resolve the group's coefficient table, falling back to the default
    # group when the requested one is unknown.
    try:
        group_coeffs = _COEFFICIENTS_BY_GROUP[group]
    except KeyError:
        group_coeffs = _COEFFICIENTS_BY_GROUP[_DEFAULT_GROUP]
    # One prediction dict per input point, in the same order.
    return [
        {"lm_loss": float(_predict_one(point, group_coeffs))}
        for point in input_data
    ]