from __future__ import annotations
from typing import Dict, List
# Fixed parameters per group for the scaling law:
# loss = L0 + A * dense_parameter_count**(-a) + B * num_experts**(-b)
_PARAMS: dict[str, dict[str, float]] = {
'all_data': {'L0': -437.0171474, 'A': 57.79747244, 'a': 0.238345632, 'B': 439.0313445, 'b': 0.0001639537764},
}
def _predict_one(x: Dict[str, float], coeffs: Dict[str, float]) -> Dict[str, float]:
# Robustly get inputs, allow some aliasing of keys
def get_key(d: Dict[str, float], names):
for k in names:
if k in d:
return float(d[k])
# Try case-insensitive
lower = {kk.lower(): kk for kk in d.keys()}
for k in names:
if k.lower() in lower:
return float(d[lower[k.lower()]])
raise KeyError(f"Missing required key; tried aliases {names}")
dense = get_key(x, ['dense_parameter_count','dense_params','dense_parameters','non_expert_params'])
experts = get_key(x, ['num_experts','experts','n_experts','num_expert'])
dense = max(dense, 1e-12)
experts = max(experts, 1e-12)
L0 = coeffs['L0']; A = coeffs['A']; a = coeffs['a']; B = coeffs['B']; b = coeffs['b']
y = L0 + A * (dense ** (-a)) + B * (experts ** (-b))
return {'loss_validation': float(y)}
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Apply the scaling law to every data point in ``input_data``.

    The functional form is shared by all groups; only the constants looked up
    in ``_PARAMS`` differ from group to group.

    Args:
        input_data: One dictionary per data point, mapping input variable
            names to their values.
        group: Name of the experimental group whose coefficients are used.
            A name that is not a key of ``_PARAMS`` falls back to the first
            entry of ``_PARAMS``.

    Returns:
        A list aligned with ``input_data``; each element is the dictionary of
        predicted output variable(s) produced by ``_predict_one`` for the
        corresponding data point.
    """
    # Unknown groups fall back to the first registered group as a
    # conservative default.
    group_key = group if group in _PARAMS else next(iter(_PARAMS.keys()))
    group_coeffs = _PARAMS[group_key]
    return [_predict_one(point, group_coeffs) for point in input_data]