from __future__ import annotations
import math
from typing import Dict, List
# Fitted coefficients of the scaling law, keyed by experimental group name.
# The 'all_data' entry was learned on the provided dataset.
#
# Each coefficient list follows the feature order built by _predict_single:
#   [1, x1, x2, x3, x4,
#    x1^2, x2^2, x3^2, x4^2,
#    x1*x2, x1*x3, x1*x4, x2*x3, x2*x4, x3*x4]
# where x1..x4 are log10 of lr, bsz, data_size and non_embedding_param_size.
_COEFS_BY_GROUP: Dict[str, List[float]] = {
    "all_data": [
        15.408655757208578,  # 1 (intercept)
        0.1479904624134041,  # x1
        0.925576816730592,  # x2
        -2.0155807017749745,  # x3
        -0.21074365992568728,  # x4
        0.1445807182504939,  # x1^2
        0.12570943660274597,  # x2^2
        0.13477282782648167,  # x3^2
        0.07811997175906828,  # x4^2
        -0.0778445730877946,  # x1*x2
        -0.02359921758963033,  # x1*x3
        0.1304365497600781,  # x1*x4
        -0.12590176704259384,  # x2*x3
        -0.050041748839094104,  # x2*x4
        -0.09213648452069143,  # x3*x4
    ],
}
def _predict_single(sample: Dict[str, float], coefs: List[float]) -> float:
    """Evaluate the quadratic-in-log10 model for one data point.

    The four inputs ``lr``, ``bsz``, ``data_size`` and
    ``non_embedding_param_size`` are read from ``sample`` (a missing key
    counts as 0.0), clamped from below to 1e-300 and passed through log10.
    The resulting x1..x4 are expanded into the feature vector
    ``[1, x1..x4, x1^2..x4^2, x1*x2, x1*x3, x1*x4, x2*x3, x2*x4, x3*x4]``
    and dotted with ``coefs``.

    Args:
        sample: Mapping from input variable name to its value.
        coefs: Coefficients in the feature order above. Features and
            coefficients are paired with ``zip``, so a surplus on either
            side is ignored.

    Returns:
        The weighted sum of the features as a float.
    """
    input_keys = ("lr", "bsz", "data_size", "non_embedding_param_size")
    raw_values = [float(sample.get(key, 0.0)) for key in input_keys]

    # Clamp before the log so zero or negative inputs cannot raise.
    floor = 1e-300
    logs = [math.log10(max(value, floor)) for value in raw_values]
    count = len(logs)

    # Intercept, linear terms, squares, then pairwise products with i < j.
    terms = [1.0]
    terms.extend(logs)
    terms.extend(x * x for x in logs)
    terms.extend(
        logs[i] * logs[j] for i in range(count) for j in range(i + 1, count)
    )

    # Plain left-to-right accumulation; sum() is avoided because its float
    # summation algorithm differs between Python versions.
    total = 0.0
    for term, weight in zip(terms, coefs):
        total += term * weight
    return float(total)
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Apply the fitted scaling law to a batch of data points.

    Args:
        input_data: One dictionary per data point, mapping input variable
                    names to their values.
        group: Name of the experimental group whose coefficients should be
               used. The functional form is shared by all groups; only the
               coefficients may differ. A group with no entry in the
               coefficient table falls back to the 'all_data' coefficients.

    Returns:
        A list aligned with input_data, where each element is a dictionary
        holding the prediction under the key "lm_loss".
    """
    # Unknown groups reuse the 'all_data' fit instead of raising.
    fallback = _COEFS_BY_GROUP["all_data"]
    weights = _COEFS_BY_GROUP.get(group, fallback)
    return [{"lm_loss": _predict_single(point, weights)} for point in input_data]