# Auto-generated scaling law implementation
# Model form: log-quadratic
# Features: 1, logE, logD, logE2, logD2, logE_logD
from math import log10
# Per-group fitted coefficients for the log-quadratic model.
# Each "weights" list follows the feature order given in _FEATURES below.
# "r2", "bic", and "n" appear to be fit diagnostics from the generator
# (presumably coefficient of determination, Bayesian information
# criterion, and sample count — TODO confirm against the generator);
# only "weights" is read by the code in this file.
_COEFFICIENTS = {
"all_data": {
"weights": [
13.849928930727911,
-0.7685396744020193,
-2.266750584688975,
-0.005468530952045237,
0.10924879564098447,
0.0786451523043135
],
"r2": 0.9613252957444444,
"bic": -1120.0506497593492,
"n": 193
},
# Fallback entry used when an unknown group name is requested.
# NOTE(review): identical to "all_data" in this generated output.
"_default_": {
"weights": [
13.849928930727911,
-0.7685396744020193,
-2.266750584688975,
-0.005468530952045237,
0.10924879564098447,
0.0786451523043135
],
"r2": 0.9613252957444444,
"bic": -1120.0506497593492,
"n": 193
}
}
# Feature order matching each "weights" vector above. Not referenced at
# runtime in this file; kept as documentation of the model form.
_FEATURES = ["1", "logE", "logD", "logE2", "logD2", "logE_logD"]
def _predict_one(num_experts: float, dense_parameter_count: float, weights: list[float]) -> float:
# Guard against non-positive inputs for log
e = max(float(num_experts), 1e-12)
d = max(float(dense_parameter_count), 1e-12)
le = log10(e)
ld = log10(d)
# X = [1, le, ld, le^2, ld^2, le*ld]
x = [1.0, le, ld, le*le, ld*ld, le*ld]
return sum(w*v for w,v in zip(weights, x))
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.
    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
            point containing input variable names as keys and their
            corresponding values.
        group: The name of the experimental group for which to make predictions.
            The functional form of the law is the same for all groups, but
            the coefficients can differ per group.
    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s).
    Raises:
        KeyError: If neither *group* nor the "_default_" fallback exists in
            the coefficient table.
    """
    # Single lookup with an explicit fallback; raises KeyError (rather
    # than a confusing TypeError on a None subscript) if the fallback
    # entry is ever missing from the generated table.
    coeffs = _COEFFICIENTS.get(group, _COEFFICIENTS["_default_"])
    weights = coeffs["weights"]
    # Missing input keys default to 0.0; _predict_one clamps them to a
    # positive floor before taking logs.
    return [
        {
            "loss_validation": float(
                _predict_one(
                    float(row.get("num_experts", 0.0)),
                    float(row.get("dense_parameter_count", 0.0)),
                    weights,
                )
            )
        }
        for row in input_data
    ]