from __future__ import annotations
import math
from typing import List, Dict
# Names of the regression features, in the exact order the fitted coefficient
# vectors below expect (coefs[i] multiplies the value of FEATURES[i]).
FEATURES = ['log_params', 'log_tokens', 'log_unique_tokens', 'log_params:log_tokens', 'tokens_inv_sqrt']

# Per-experimental-group fitted parameters: an intercept plus one coefficient
# per entry of FEATURES.
GROUP_PARAMS = {
    "all_data": {
        "intercept": 21.017514457355812,
        "coefs": [
            -0.8745591289420206,
            -0.5044278099541184,
            -0.11831988273483411,
            0.029264245542033336,
            38424.77315308764,
        ],
    },
}

# Fallback parameters used when an unknown group is requested. The original
# code duplicated the "all_data" numbers verbatim; aliasing the same dict
# keeps the fallback and the group fit from ever drifting apart.
# NOTE: shared reference — callers must treat these parameters as read-only.
GLOBAL_PARAMS = GROUP_PARAMS["all_data"]
def _feature_vector(params: float, tokens: float, unique_tokens: float):
    """Build the regression feature row for a single data point.

    Values are emitted in the order given by the module-level ``FEATURES``
    list; any feature name not recognized here contributes 0.0.

    Args:
        params: Model parameter count (must be positive; passed to log).
        tokens: Training token count (must be positive; passed to log/sqrt).
        unique_tokens: Unique token count (must be positive; passed to log).

    Returns:
        A list of floats, one per entry of ``FEATURES``.
    """
    log_p = math.log(params)
    log_t = math.log(tokens)
    log_u = math.log(unique_tokens)
    # Name -> value table replaces the original if/elif chain; iterating
    # FEATURES preserves the required ordering, and .get(..., 0.0) preserves
    # the original's zero fallback for unknown names.
    feature_values = {
        "log_params": log_p,
        "log_tokens": log_t,
        "log_unique_tokens": log_u,
        "log_params:log_tokens": log_p * log_t,
        "tokens_inv_sqrt": 1.0 / math.sqrt(tokens),
    }
    return [feature_values.get(name, 0.0) for name in FEATURES]
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """Predict the output variable for each input point under the fitted law.

    The functional form (intercept + linear combination of the engineered
    features) is shared across groups; only the fitted constants may differ
    per group. Unknown groups fall back to the global fit.

    Args:
        input_data: One dict per data point, with keys "params", "tokens",
            and "unique_tokens" mapping to numeric values.
        group: Experimental group whose fitted constants should be used.

    Returns:
        One dict per input point, each containing the predicted "loss".
    """
    # Unknown group names silently fall back to the globally-fitted constants.
    fit = GROUP_PARAMS.get(group, GLOBAL_PARAMS)
    b0 = fit["intercept"]
    betas = fit["coefs"]
    predictions: list[dict[str, float]] = []
    for point in input_data:
        features = _feature_vector(
            float(point["params"]),
            float(point["tokens"]),
            float(point["unique_tokens"]),
        )
        # Linear model: intercept + dot(betas, features).
        loss = b0 + sum(b * f for b, f in zip(betas, features))
        predictions.append({"loss": float(loss)})
    return predictions