from __future__ import annotations
from typing import Dict, List
import math
# Discovered scaling law (shared functional form across groups):
# loss_validation = L0[group] + C[group] * num_experts**(-p[group]) * dense_parameter_count**(-q[group])
# Coefficients were fit via nonlinear least squares on the provided dataset.
# See /app/explain.md for methodology and diagnostics.
# Group-specific parameters. If an unknown group is requested, we fall back to "all_data".
_PARAMS: Dict[str, Dict[str, float]] = {
# Fitted on the provided data (single group present in the dataset)
"all_data": {
"L0": 1.6170181290494012,
"C": 43.47571056885845,
"p": 0.073982766969121,
"q": 0.19898568380705728,
},
}
_FALLBACK_GROUP = "all_data"
def _get_params(group: str) -> Dict[str, float]:
    """Return the coefficient dict registered for ``group``.

    When ``group`` is not a key of ``_PARAMS``, the entry stored under
    ``_FALLBACK_GROUP`` is returned instead. The returned object is the dict
    held in ``_PARAMS`` itself, not a copy.
    """
    # Resolve the default first so a missing fallback entry fails regardless
    # of which group was asked for.
    default_params = _PARAMS[_FALLBACK_GROUP]
    if group in _PARAMS:
        return _PARAMS[group]
    return default_params
def _predict_one(x: Dict[str, float], params: Dict[str, float]) -> float:
# Extract inputs with basic validation and safety clamps
E = float(x.get("num_experts", 0.0))
D = float(x.get("dense_parameter_count", 0.0))
# Guard against non-positive inputs to power operations
eps = 1e-12
E = max(E, eps)
D = max(D, eps)
L0 = params["L0"]
C = params["C"]
p = params["p"]
q = params["q"]
return float(L0 + C * (E ** (-p)) * (D ** (-q)))
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """Predict ``loss_validation`` for every data point using the fitted scaling law.

    The functional form is shared by all groups; only the coefficients,
    obtained through ``_get_params(group)``, depend on the group.

    Args:
        input_data: A list of dictionaries, one per data point, mapping input
            variable names to their values.
        group: Name of the experimental group whose coefficients should be
            used for the predictions.

    Returns:
        A list with one dictionary per element of ``input_data``, in the same
        order, each of the form ``{"loss_validation": <prediction>}``.
    """
    # Coefficients are resolved once and reused for every point.
    coefficients = _get_params(group)
    return [
        {"loss_validation": _predict_one(point, coefficients)}
        for point in input_data
    ]