# Auto-generated scaling law implementation
# Discovered via ridge regression (manual) on quadratic polynomial features
# of the log10-transformed inputs.
# Do not modify the signature of `law`.
from typing import List, Dict
import math
# Column order of the feature vector built by _make_features_one.
# L, B, D, P are log10 of lr, bsz, data_size and non_embedding_param_size;
# "X2" is a square and a two-letter name is a pairwise product.
FEATURES = [
    "bias",
    "L", "B", "D", "P",
    "L2", "B2", "D2", "P2",
    "LB", "LD", "LP",
    "BD", "BP", "DP",
]

# Fitted coefficients per experimental group. Each "coef" list is aligned
# index-for-index with FEATURES above.
COEFS_BY_GROUP = {
    "all_data": {
        "coef": [
            16.624581903612846,     # bias
            0.2627109539547664,     # L
            0.8995972963599023,     # B
            -2.109340807436253,     # D
            -0.3416462681138454,    # P
            0.14849884087182352,    # L2
            0.1269746750542109,     # B2
            0.13485667144489863,    # D2
            0.07916170471632446,    # P2
            -0.08188202638168432,   # LB
            -0.024745326001810515,  # LD
            0.12219666925411721,    # LP
            -0.12293565944271072,   # BD
            -0.0525403200519685,    # BP
            -0.08250175820236673,   # DP
        ],
        # Not consulted by _predict_one; "bias" is always part of FEATURES.
        "uses_bias_feature": True,
    },
}
def _make_features_one(x: Dict[str, float]):
    """Build the quadratic-in-log10 feature vector for one data point.

    The base-10 logarithms of ``lr``, ``bsz``, ``data_size`` and
    ``non_embedding_param_size`` are named L, B, D and P. From these the
    bias term, the linear terms, their squares and all pairwise products
    are formed.

    Args:
        x: Mapping holding the four input variables named above.

    Returns:
        List of float feature values ordered as in the module-level
        ``FEATURES`` list.

    Raises:
        KeyError: if one of the four input keys is missing.
        ValueError: if an input value is not positive (from ``math.log10``).
    """
    # Insertion order (L, B, D, P) fixes the naming/order of derived terms.
    logs = {
        "L": math.log10(x["lr"]),
        "B": math.log10(x["bsz"]),
        "D": math.log10(x["data_size"]),
        "P": math.log10(x["non_embedding_param_size"]),
    }

    table = {"bias": 1.0}
    table.update(logs)
    for name, value in logs.items():
        table[name + "2"] = value * value

    # Pairwise products of distinct logs: LB, LD, LP, BD, BP, DP.
    names = list(logs)
    for i, first in enumerate(names):
        for second in names[i + 1:]:
            table[first + second] = logs[first] * logs[second]

    return [table[name] for name in FEATURES]
def _predict_one(x: Dict[str, float], group: str) -> float:
    """Predict the target for one data point as a linear model over features.

    Args:
        x: Input variables for a single data point, as consumed by
            ``_make_features_one``.
        group: Experimental group whose coefficients should be used. If it
            is not present in ``COEFS_BY_GROUP``, the first group in that
            mapping is used instead.

    Returns:
        Dot product of the selected group's coefficients with the feature
        vector of ``x``.

    Raises:
        ValueError: if ``COEFS_BY_GROUP`` is empty, or if the selected
            group's coefficient count differs from the feature count.
    """
    if group in COEFS_BY_GROUP:
        g = group
    else:
        # Fallback for unseen groups: reuse the first known group's fit.
        g = next(iter(COEFS_BY_GROUP), None)
    if g is None:
        raise ValueError("No coefficients available for prediction.")

    coef = COEFS_BY_GROUP[g]["coef"]
    feats = _make_features_one(x)
    # zip() silently truncates to the shorter sequence, which would yield a
    # wrong prediction without any error; fail loudly on a mismatch instead.
    if len(coef) != len(feats):
        raise ValueError(
            f"Group {g!r} has {len(coef)} coefficients but "
            f"{len(feats)} features were computed."
        )
    return sum(c * f for c, f in zip(coef, feats))
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    Each prediction is a quadratic polynomial in the base-10 logarithms of
    ``lr``, ``bsz``, ``data_size`` and ``non_embedding_param_size``.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single
            data point containing input variable names as keys and their
            corresponding values.
        group: The name of the experimental group for which to make
            predictions. The functional form of the law is the same for all
            groups, but the coefficients can differ per group. A group with
            no fitted coefficients falls back to the first fitted group.

    Returns:
        A list of dictionaries, one per entry of ``input_data`` and in the
        same order. Each dictionary holds the predicted output variable
        under the key ``"lm_loss"``.
    """
    return [{"lm_loss": float(_predict_one(point, group))} for point in input_data]