
Domain Mixture Scaling Law

Agent: codex
Model: o4-mini
Best R²: 0.974767
Mean R²: 0.941711
Min R²: 0.845752
Runs: 4

All Runs (sorted by R²)

#1 Run 1 (Best) R² = 0.974767
Python
from typing import List, Dict

# Fitted parameters for the per-domain scaling law: loss_i(p_i) = c_i - a_i * p_i**b_i,
# where p_i is the training-mixture proportion of domain i.
# Keyed by model-size group ('70M', '160M', '305M', '410M'), then by domain index (1-5).
_PARAMS: Dict[str, Dict[int, Dict[str, float]]] = {
    '70M': {
        1: {'a': 0.9228, 'b': 0.2453, 'c': 3.4149},
        2: {'a': 0.3726, 'b': 0.5065, 'c': 3.8184},
        3: {'a': 0.7930, 'b': 0.2212, 'c': 3.6006},
        4: {'a': 0.9436, 'b': 0.2406, 'c': 2.2663},
        5: {'a': 0.5175, 'b': 0.3754, 'c': 3.9317},
    },
    '160M': {
        1: {'a': 0.8432, 'b': 0.2285, 'c': 3.0604},
        2: {'a': 0.3059, 'b': 0.4616, 'c': 3.4721},
        3: {'a': 0.7277, 'b': 0.2081, 'c': 3.2856},
        4: {'a': 0.8371, 'b': 0.2382, 'c': 1.9631},
        5: {'a': 0.5291, 'b': 0.3623, 'c': 3.5949},
    },
    '305M': {
        1: {'a': 0.8159, 'b': 0.2234, 'c': 2.8980},
        2: {'a': 0.4262, 'b': 0.6940, 'c': 3.3062},
        3: {'a': 0.7023, 'b': 0.1831, 'c': 3.1556},
        4: {'a': 0.7988, 'b': 0.2365, 'c': 1.8330},
        5: {'a': 0.5343, 'b': 0.3516, 'c': 3.4344},
    },
    '410M': {
        1: {'a': 0.7997, 'b': 0.2158, 'c': 2.8319},
        2: {'a': 0.3518, 'b': 0.6247, 'c': 3.2303},
        3: {'a': 0.7099, 'b': 0.1805, 'c': 3.0983},
        4: {'a': 0.7849, 'b': 0.2413, 'c': 1.7794},
        5: {'a': 0.5501, 'b': 0.3404, 'c': 3.3746},
    },
}

def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values.
        group: The name of the experimental group for which to make predictions.
               The functional form of the law is the same for all groups,
               but the constant parameters/coefficients differ per group.

    Returns:
        A list of dictionaries, one per entry in input_data, each containing the
        predicted losses under the keys 'loss_domain_1' through 'loss_domain_5'.
    """
    if group not in _PARAMS:
        raise ValueError(f"Unknown group: {group}")
    group_params = _PARAMS[group]
    results: List[Dict[str, float]] = []
    # Compute prediction for each data point
    for entry in input_data:
        preds: Dict[str, float] = {}
        for i in range(1, 6):
            p = entry.get(f'proportion_domain_{i}')
            if p is None:
                raise KeyError(f"Missing proportion_domain_{i} in input data")
            a = group_params[i]['a']
            b = group_params[i]['b']
            c = group_params[i]['c']
            # scaling law: loss = c - a * p**b
            preds[f'loss_domain_{i}'] = c - a * (p ** b)
        results.append(preds)
    return results
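
A minimal usage sketch, assuming a hypothetical uniform mixture over the five domains (these proportions are illustrative, not values from the benchmark data):

# Hypothetical example: uniform mixture (p_i = 0.2) evaluated for the '410M' group.
example_input = [{f'proportion_domain_{i}': 0.2 for i in range(1, 6)}]
predictions = law(example_input, group='410M')
# predictions[0]['loss_domain_1'] = c - a * 0.2**b
#                                 = 2.8319 - 0.7997 * 0.2**0.2158 ≈ 2.27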
#2 Run 2 R² = 0.974745
#3 Run 3 R² = 0.971579
#4 Run 4 R² = 0.845752