from typing import List, Dict
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """Predict validation loss from a fitted multiplicative power law.

    The discovered law models validation loss as an asymptote plus a
    multiplicative power law in the dense (non-expert) parameter count and
    the number of experts:

        loss_validation = A + B * dense_parameter_count**b * num_experts**c

    Coefficients (A, B, b, c) were fitted per experimental group. Unknown
    groups fall back to the "all_data" fit, or — if that key were ever
    absent — to the first coefficient set in the table, deterministically.

    Args:
        input_data: One dict per data point. Expected keys:
            - 'num_experts'
            - 'dense_parameter_count'
        group: Experimental group name selecting the coefficient set; the
            functional form is identical across groups.

    Returns:
        A list parallel to ``input_data``; each element is a dict with the
        single key 'loss_validation'. Points with missing or non-positive
        inputs yield NaN, since the fitted law is only defined for
        positive values.
    """
    # Coefficients fitted on the dataset under /app/data.
    # Layout: group -> (A, B, b, c)
    COEFFS = {
        # Fitted on 193 points (group == 'all_data'); training RMSE ~ 0.05255
        "all_data": (1.6057342355089967, 41.28866734867788, -0.19547840587237095, -0.07426497458304816),
    }

    # Coefficient resolution: exact group, then "all_data", then the first
    # entry in the table (deterministic because dicts preserve insertion order).
    params = COEFFS.get(group)
    if params is None:
        params = COEFFS.get("all_data")
    if params is None:
        params = COEFFS[next(iter(COEFFS))]
    A, B, b, c = params

    def _predict(point: Dict[str, float]) -> float:
        """Evaluate the law for one data point; NaN for invalid inputs."""
        dense = float(point.get("dense_parameter_count", 0.0))
        experts = float(point.get("num_experts", 0.0))
        # The power law requires strictly positive inputs; anything else
        # (including missing keys, which default to 0.0) maps to NaN.
        if dense > 0 and experts > 0:
            return A + B * (dense ** b) * (experts ** c)
        return float("nan")

    return [{"loss_validation": float(_predict(point))} for point in input_data]