from typing import List, Dict
import math
# Scaling law chosen (Model B):
# loss_validation = L0 + C * (dense_parameter_count ** -alpha) * (num_experts ** -beta)
# Coefficients are per-group; groups not listed fall back to the '__default__' entry.
# Per-group coefficients for the scaling law.
# 'all_data' was fitted on the provided dataset (group == 'all_data') using
# robust least squares; '__default__' is the fallback used for any group not
# listed and carries the same coefficients. Each group gets its own dict
# object, built fresh from the shared literal below.
_PARAMS_BY_GROUP = {
    group_name: {
        'L0': 1.61701771,
        'C': 43.47565665,
        'alpha': 0.19898560,
        'beta': 0.07398273,
    }
    for group_name in ('all_data', '__default__')
}
def _predict_one(e: float, nd: float, coeffs: Dict[str, float]) -> float:
    """Evaluate the scaling law for one (num_experts, dense_parameter_count) point.

    Both inputs are converted with ``float()`` and clamped from below at 1.0:
    any value not greater than 1.0 (including NaN) is replaced by 1.0 before
    the power terms are computed.

    Args:
        e: Number of experts (``law`` passes the row's 'num_experts').
        nd: Dense parameter count (``law`` passes 'dense_parameter_count').
        coeffs: Mapping providing 'L0', 'C', 'alpha' and 'beta'.

    Returns:
        ``L0 + C * nd ** (-alpha) * e ** (-beta)`` evaluated on the clamped inputs.
    """
    experts = float(e)
    experts = experts if experts > 1.0 else 1.0
    dense = float(nd)
    dense = dense if dense > 1.0 else 1.0

    intercept = coeffs['L0']
    scale = coeffs['C']
    dense_factor = dense ** (-coeffs['alpha'])
    expert_factor = experts ** (-coeffs['beta'])
    return intercept + scale * dense_factor * expert_factor
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict 'loss_validation' for each input row using the fitted scaling law.

    Every row is read for 'num_experts' and 'dense_parameter_count'; a missing
    key is treated as 1.0. Coefficients come from ``_PARAMS_BY_GROUP[group]``,
    or from the '__default__' entry when ``group`` is not listed. The functional
    form is shared by all groups; only the coefficients may differ.

    Args:
        input_data: Sequence of data points, each a dictionary mapping input
            variable names to their values.
        group: Name of the experimental group whose coefficients are used.

    Returns:
        One dictionary per input row, in the same order, each of the form
        ``{'loss_validation': <predicted value as float>}``.
    """
    params = _PARAMS_BY_GROUP.get(group, _PARAMS_BY_GROUP['__default__'])
    return [
        {
            'loss_validation': float(
                _predict_one(
                    point.get('num_experts', 1.0),
                    point.get('dense_parameter_count', 1.0),
                    params,
                )
            )
        }
        for point in input_data
    ]