from typing import List, Dict
# Parameters for scaling law: loss(p) = c - a * p**b
_PARAMS: Dict[str, Dict[int, Dict[str, float]]] = {
'70M': {
1: {'a': 0.9228, 'b': 0.2453, 'c': 3.4149},
2: {'a': 0.3726, 'b': 0.5065, 'c': 3.8184},
3: {'a': 0.7930, 'b': 0.2212, 'c': 3.6006},
4: {'a': 0.9436, 'b': 0.2406, 'c': 2.2663},
5: {'a': 0.5175, 'b': 0.3754, 'c': 3.9317},
},
'160M': {
1: {'a': 0.8432, 'b': 0.2285, 'c': 3.0604},
2: {'a': 0.3059, 'b': 0.4616, 'c': 3.4721},
3: {'a': 0.7277, 'b': 0.2081, 'c': 3.2856},
4: {'a': 0.8371, 'b': 0.2382, 'c': 1.9631},
5: {'a': 0.5291, 'b': 0.3623, 'c': 3.5949},
},
'305M': {
1: {'a': 0.8159, 'b': 0.2234, 'c': 2.8980},
2: {'a': 0.4262, 'b': 0.6940, 'c': 3.3062},
3: {'a': 0.7023, 'b': 0.1831, 'c': 3.1556},
4: {'a': 0.7988, 'b': 0.2365, 'c': 1.8330},
5: {'a': 0.5343, 'b': 0.3516, 'c': 3.4344},
},
'410M': {
1: {'a': 0.7997, 'b': 0.2158, 'c': 2.8319},
2: {'a': 0.3518, 'b': 0.6247, 'c': 3.2303},
3: {'a': 0.7099, 'b': 0.1805, 'c': 3.0983},
4: {'a': 0.7849, 'b': 0.2413, 'c': 1.7794},
5: {'a': 0.5501, 'b': 0.3404, 'c': 3.3746},
},
}
def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predict per-domain losses from domain proportions using the fitted scaling law.

    For every domain index ``i`` configured for ``group`` in ``_PARAMS``, the
    prediction is::

        loss_domain_i = c - a * proportion_domain_i ** b

    using that group's and domain's ``(a, b, c)`` coefficients.

    Args:
        input_data: A list of dictionaries, one per data point, each mapping
            ``'proportion_domain_i'`` to the proportion for domain ``i``.
        group: The name of the experimental group for which to make
            predictions; must be a key of ``_PARAMS``. The functional form
            of the law is the same for all groups, but the coefficients
            differ per group.

    Returns:
        A list with one dictionary per element of ``input_data`` (same order),
        each containing the predicted ``'loss_domain_i'`` for every domain.

    Raises:
        ValueError: If ``group`` is unknown, or if a proportion is negative
            (a negative base raised to a fractional exponent would silently
            produce a complex number instead of a float).
        KeyError: If a data point has no ``proportion_domain_i`` entry, or
            its value is ``None``.
    """
    if group not in _PARAMS:
        raise ValueError(f"Unknown group: {group}")
    group_params = _PARAMS[group]
    # Use the domains actually configured for this group, in a fixed order,
    # rather than hard-coding the domain count.
    domains = sorted(group_params)

    results: List[Dict[str, float]] = []
    for entry in input_data:
        preds: Dict[str, float] = {}
        for i in domains:
            p = entry.get(f'proportion_domain_{i}')
            if p is None:
                raise KeyError(f"Missing proportion_domain_{i} in input data")
            if p < 0:
                # (-x) ** 0.25 is complex in Python 3; reject rather than
                # return a non-float loss.
                raise ValueError(
                    f"proportion_domain_{i} must be non-negative, got {p}"
                )
            coeffs = group_params[i]
            preds[f'loss_domain_{i}'] = coeffs['c'] - coeffs['a'] * (p ** coeffs['b'])
        results.append(preds)
    return results