import math
from typing import Dict, List

# Fitted coefficients per group for the scaling law:
# loss = L0 + a * params^(-alpha) + b * tokens^(-beta) + c * ln(unique_tokens)
#
# Notes:
# - Coefficients below were obtained via non-linear least squares on the provided dataset.
# - If an unknown group is requested, we fall back to the 'all_data' coefficients.
_COEFFS = {
# Trained from the dataset at /app/data (161 points, single group 'all_data')
"all_data": {
"L0": 5.314158928164251,
"a": 4163.742173986624,
"alpha": 0.4910050761229603,
"b": 109180.20697694572,
"beta": 0.5637776884040872,
"c": -0.11944428211525198,
}
}


def _predict_single(x: Dict[str, float], k: Dict[str, float]) -> float:
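    """Evaluate the fitted scaling-law form for a single input row x using coefficients k."""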
# Safeguards for domain constraints
params = max(float(x.get("params", 0.0)), 1e-12)
tokens = max(float(x.get("tokens", 0.0)), 1e-12)
unique_tokens = max(float(x.get("unique_tokens", 0.0)), 1.0)
return (
k["L0"]
+ k["a"] * (params ** (-k["alpha"]))
+ k["b"] * (tokens ** (-k["beta"]))
+ k["c"] * math.log(unique_tokens)
    )


def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
"""
Predicts the final validation loss ('loss') for language model pre-training
given parameter count ('params'), total training tokens ('tokens'), and the
number of unique tokens in the dataset ('unique_tokens').
Functional form (shared across groups):
loss = L0 + a * params^(-alpha) + b * tokens^(-beta) + c * ln(unique_tokens)
The coefficients (L0, a, alpha, b, beta, c) are group-specific. If the
provided group is unknown, this function falls back to 'all_data'.
Args:
input_data: List of dicts; each must contain 'params', 'tokens',
and 'unique_tokens' (floats).
group: Name of the experimental group.
Returns:
List of dicts with a single key 'loss' containing the prediction.
"""
if not isinstance(input_data, list):
raise TypeError("input_data must be a list of dictionaries")
coeffs = _COEFFS.get(group, _COEFFS["all_data"]) # fallback to all_data
out: List[Dict[str, float]] = []
for row in input_data:
y = _predict_single(row, coeffs)
out.append({"loss": float(y)})
return out
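

if __name__ == "__main__":
    # Minimal usage sketch. The input values below are illustrative placeholders,
    # not points from the fitted dataset; they assume 'params' and 'tokens' are
    # raw counts (model parameters and training tokens) rather than rescaled units.
    example_inputs = [
        {"params": 1.25e8, "tokens": 2.5e9, "unique_tokens": 2.5e9},
        {"params": 3.5e8, "tokens": 7.0e9, "unique_tokens": 3.5e9},
    ]
    predictions = law(example_inputs, group="all_data")
    for inp, pred in zip(example_inputs, predictions):
        print(
            f"params={inp['params']:.3g}, tokens={inp['tokens']:.3g}, "
            f"unique_tokens={inp['unique_tokens']:.3g} -> loss={pred['loss']:.4f}"
        )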