from __future__ import annotations
from typing import Dict, List
import math
# Discovered scaling law (shared functional form across groups):
# loss = L + K * (P**alpha * E**beta) ** (-gamma)
# where:
# P = dense_parameter_count (float, > 0)
# E = num_experts (float, > 0)
# Parameters (L, K, gamma, alpha, beta) are group-specific constants.
# Fitted parameters per group from the provided dataset.
# Values are rounded to six decimal places for stability/readability.
# Fitted constants per experimental group, keyed by group name.
# Each entry is the 5-tuple (L, K, gamma, alpha, beta) plugged into
# loss = L + K * (P**alpha * E**beta) ** (-gamma).
_PARAMS_BY_GROUP: Dict[str, tuple[float, float, float, float, float]] = {
    # group: (L, K, gamma, alpha, beta)
    "all_data": (
        1.616974,  # L: irreducible-loss asymptote (predicted loss as scale -> inf)
        43.469602,  # K: overall magnitude of the power-law term
        0.190978,  # gamma: decay exponent applied to the effective scale
        1.041879,  # alpha: exponent weighting dense parameter count P
        0.387373,  # beta: exponent weighting number of experts E
    ),
}
def _predict_loss(P: float, E: float, params: tuple[float, float, float, float, float]) -> float:
L, K, gamma, alpha, beta = params
# Guard against non-positive inputs; fall back to returning L if invalid.
if P <= 0 or E <= 0:
return float(L)
# Compute effective scale and apply the power-law decay.
# Use logs for numerical stability: (P**alpha * E**beta)**(-gamma) = exp(-gamma * (alpha*ln P + beta*ln E))
s_log = alpha * math.log(P) + beta * math.log(E)
decay = math.exp(-gamma * s_log)
return float(L + K * decay)
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.
    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law must be the same for all groups,
               but the constant parameters/coefficients can differ per group.
    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    """
    # Resolve group parameters with a two-level fallback: exact group first,
    # then the aggregate "all_data" fit, and finally a conservative baseline
    # that keeps the predicted loss near a plausible constant.
    baseline = (2.0, 1.0, 0.2, 1.0, 0.5)
    fallback = _PARAMS_BY_GROUP.get("all_data", baseline)
    params = _PARAMS_BY_GROUP.get(group, fallback)
    # Missing input keys default to 0.0, which _predict_loss treats as invalid
    # and maps to the asymptote L.
    return [
        {
            "loss_validation": _predict_loss(
                float(point.get("dense_parameter_count", 0.0)),
                float(point.get("num_experts", 0.0)),
                params,
            )
        }
        for point in input_data
    ]