# Auto-generated scaling law for parallel model training
# Formula:
# loss = c0_g + c1_g * N^(-alpha) + c2_g * P^(-beta) + c3_g * (N^(-alpha) * P^(-beta))
# with exponents alpha, beta shared across groups.
from typing import List, Dict
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict training loss from model size and parallelism via a scaling law.

    loss = c0 + c1 * N^(-alpha) + c2 * P^(-beta) + c3 * N^(-alpha) * P^(-beta)

    The exponents alpha/beta are shared across all groups; the coefficients
    c0..c3 differ per group. An unknown group falls back to the element-wise
    mean of all known groups' coefficients.

    Args:
        input_data: A list of dictionaries, one per data point. Each must
            supply the model size under 'num_params' (aliases: 'n_params',
            'params', 'n') and the parallelism degree under 'parallel_size'
            (aliases: 'p', 'degree', 'mp_size', 'world_size').
        group: The experimental group name selecting the coefficient set.

    Returns:
        A list parallel to input_data; each dict holds the predicted 'loss'.

    Raises:
        KeyError: If a row lacks either input under every recognized alias.
    """
    ALPHA = 0.25125
    BETA = 0.513333333333
    COEFS: dict[str, list[float]] = {
        "pile": [1.38133216084, 97.066659052, 0.0772792357873, 4.4654429466],
        "stack": [0.755424951209, 54.0083122959, 0.0433880237503, 4.26602562249],
    }

    coeffs = COEFS.get(group)
    if coeffs is None:
        # Fallback for an unknown group: average each coefficient across all
        # known groups. Computed lazily — only when actually needed — instead
        # of on every call.
        coeffs = [
            sum(cs[i] for cs in COEFS.values()) / len(COEFS) for i in range(4)
        ]

    def _first_present(row: dict[str, float], aliases: tuple[str, ...]):
        # Return the first alias present in `row` with a non-None value.
        # Unlike a nested-get chain, an explicit None stored under an early
        # alias does not shadow a real value under a later one.
        for key in aliases:
            value = row.get(key)
            if value is not None:
                return value
        return None

    N_ALIASES = ('num_params', 'n_params', 'params', 'n')
    P_ALIASES = ('parallel_size', 'p', 'degree', 'mp_size', 'world_size')

    # Loop-invariant: unpack once, not per row.
    c0, c1, c2, c3 = coeffs
    out: list[dict[str, float]] = []
    for row in input_data:
        n = _first_present(row, N_ALIASES)
        p = _first_present(row, P_ALIASES)
        if n is None or p is None:
            raise KeyError("Each input dict must contain 'num_params' and 'parallel_size' (or recognized aliases).")
        # Clamp to a tiny positive value so non-positive inputs cannot yield
        # complex results or a zero-division in the negative-power terms.
        n = max(float(n), 1e-12)
        p = max(float(p), 1e-12)
        x1 = n ** (-ALPHA)
        x2 = p ** (-BETA)
        pred = float(c0 + c1 * x1 + c2 * x2 + c3 * x1 * x2)
        out.append({'loss': pred})
    return out