from __future__ import annotations
import math
from typing import Dict, List
# Discovered scaling law:
# loss = a_g + b_g * log10(num_params) + c_g * log2(parallel_size)
# Functional form is shared across groups; coefficients (a_g, b_g, c_g) differ by group.
# Coefficients were fitted via least squares on the provided dataset.
# Per-group coefficients:
# stack: a=2.59963343, b=-0.16438025, c=-0.01656783
# pile: a=4.61839962, b=-0.28879470, c=-0.02553759
# Fallback pooled coefficients (all groups combined), used for unknown groups:
# pooled: a=3.60901653, b=-0.22658747, c=-0.02105271
# Fitted coefficients keyed by group name. Each entry maps "a" (intercept),
# "b" (weight on log10(num_params)) and "c" (weight on log2(parallel_size)).
# The "__pooled__" entry holds the all-groups fit.
_COEFFICIENTS: Dict[str, Dict[str, float]] = {
    group_name: {"a": intercept, "b": params_slope, "c": parallel_slope}
    for group_name, (intercept, params_slope, parallel_slope) in (
        ("stack", (2.59963343, -0.16438025, -0.01656783)),
        ("pile", (4.61839962, -0.28879470, -0.02553759)),
        ("__pooled__", (3.60901653, -0.22658747, -0.02105271)),
    )
}
def _predict_loss(num_params: float, parallel_size: float, coeffs: Dict[str, float]) -> float:
    """Evaluate a + b * log10(num_params) + c * log2(parallel_size).

    Args:
        num_params: Model size; converted with ``float()`` and clamped below at 1.0.
        parallel_size: Parallelism degree; converted with ``float()`` and
            clamped below at 1.0.
        coeffs: Mapping providing the "a", "b" and "c" coefficients.

    Returns:
        The predicted loss for the given inputs and coefficients.
    """
    # Clamping to >= 1 keeps both logarithms inside their domain, so inputs
    # of zero or below contribute a log term of 0 instead of raising.
    log_params = math.log10(max(float(num_params), 1.0))
    log_parallel = math.log2(max(float(parallel_size), 1.0))
    intercept = coeffs["a"]
    return intercept + coeffs["b"] * log_params + coeffs["c"] * log_parallel
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict the loss for each data point using the group's fitted coefficients.

    Args:
        input_data: A list of dictionaries, one per data point. The keys read
                    are "num_params" (0.0 if absent) and "parallel_size"
                    (1.0 if absent); any other keys are ignored.
        group: Name of the experimental group. The functional form is the same
               for every group; only the coefficients differ. A group with no
               entry in the coefficient table uses the "__pooled__" entry.

    Returns:
        A list the same length and order as input_data, where each element is
        a dictionary of the form {"loss": <predicted value>}.
    """
    fitted = _COEFFICIENTS.get(group, _COEFFICIENTS["__pooled__"])

    def predict_point(point: Dict[str, float]) -> Dict[str, float]:
        # Missing inputs fall back to defaults rather than raising.
        num_params = float(point.get("num_params", 0.0))
        parallel_size = float(point.get("parallel_size", 1.0))
        return {"loss": float(_predict_loss(num_params, parallel_size, fitted))}

    return [predict_point(point) for point in input_data]