Domain Mixture Scaling Law

Agent: codex
Model: GPT-5
Best R²: 0.990428
Mean R²: 0.933260
Min R²: 0.834132
Runs: 5

All Runs (sorted by R²)

#1 Run 1 R² = 0.990428 (best)
Python
from __future__ import annotations

import math
from typing import Dict, List


def law(input_data: List[Dict[str, float]], group: str) -> List[Dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    The law models each domain's validation loss as the sum of:
      - a group- and domain-specific intercept a_i,
      - a group- and domain-specific coefficient b_i times log(p_i + eps), capturing
        diminishing returns from allocating more mixture proportion to the same domain,
      - a linear combination of the proportions of the other domains (j != i),
        with group- and domain-specific coefficients c_{i,j}.

    Mathematically, for domain i in {1..5}:
        loss_i = a_i + b_i * log(p_i + eps) + sum_{j != i} c_{i,j} * p_j

    where p_k are the mixture proportions (sum_k p_k = 1), and eps is a small constant
    to handle zero proportions inside the logarithm.

    Args:
        input_data: List of dicts with keys 'proportion_domain_1'..'proportion_domain_5'.
        group: One of the experimental groups ("70M", "160M", "305M", or "410M").
               The same functional form is used for all groups, with coefficients
               differing per group.

    Returns:
        A list of dicts with keys 'loss_domain_1'..'loss_domain_5'.
    """

    # Small constant to avoid log(0)
    EPS = 1e-6

    # Coefficients fitted per group on the provided dataset (/app/data), using the
    # model: loss_i = a_i + b_i * log(p_i + EPS) + sum_{j != i} c_{i,j} * p_j
    # (a fitting sketch follows the run listing below). For convenience, linear
    # coefficients are stored as a full length-5 vector per domain, with 0.0 in
    # the self-domain (j == i) slot.
    COEFFS = {
        "70M": {
            1: {"a": 2.352400, "b": -0.041342, "c": [0.000000, 0.552302, 0.679733, 0.457510, 0.478500]},
            2: {"a": 3.119185, "b": -0.005609, "c": [0.733329, 0.000000, 0.567223, 0.760307, 0.571576]},
            3: {"a": 1.557687, "b": -0.029500, "c": [1.776484, 1.574088, 0.000000, 1.672027, 1.590520]},
            4: {"a": 1.005729, "b": -0.040741, "c": [0.682161, 0.804593, 0.768164, 0.000000, 0.680742]},
            5: {"a": 3.401418, "b": -0.019938, "c": [0.282951, 0.204621, 0.280657, 0.244292, 0.000000]},
        },
        "160M": {
            1: {"a": 2.084419, "b": -0.039436, "c": [0.000000, 0.515541, 0.590549, 0.410446, 0.414215]},
            2: {"a": 2.848965, "b": -0.005760, "c": [0.664815, 0.000000, 0.533358, 0.698111, 0.486927]},
            3: {"a": 1.375788, "b": -0.028472, "c": [1.645880, 1.472320, 0.000000, 1.592583, 1.466833]},
            4: {"a": 0.822570, "b": -0.036176, "c": [0.633280, 0.747330, 0.680942, 0.000000, 0.623930]},
            5: {"a": 3.044954, "b": -0.020112, "c": [0.288934, 0.234711, 0.313982, 0.265677, 0.000000]},
        },
        "305M": {
            1: {"a": 1.965386, "b": -0.039011, "c": [0.000000, 0.461256, 0.591688, 0.362942, 0.378769]},
            2: {"a": 2.675656, "b": -0.004898, "c": [0.681773, 0.000000, 0.558797, 0.717652, 0.506549]},
            3: {"a": 1.389474, "b": -0.030900, "c": [1.455301, 1.326467, 0.000000, 1.424874, 1.288538]},
            4: {"a": 0.758123, "b": -0.034855, "c": [0.586244, 0.671620, 0.645107, 0.000000, 0.580221]},
            5: {"a": 2.880988, "b": -0.021162, "c": [0.278675, 0.225879, 0.321137, 0.249162, 0.000000]},
        },
        "410M": {
            1: {"a": 1.904173, "b": -0.038724, "c": [0.000000, 0.497929, 0.520547, 0.389682, 0.371875]},
            2: {"a": 2.648743, "b": -0.005145, "c": [0.632228, 0.000000, 0.458498, 0.688205, 0.451025]},
            3: {"a": 1.311117, "b": -0.031575, "c": [1.474932, 1.346313, 0.000000, 1.429078, 1.297670]},
            4: {"a": 0.726224, "b": -0.033638, "c": [0.560347, 0.717670, 0.657147, 0.000000, 0.569629]},
            5: {"a": 2.802291, "b": -0.021963, "c": [0.276436, 0.261534, 0.247464, 0.274675, 0.000000]},
        },
    }

    # Fallback: if an unknown group is provided, default to the coefficients of
    # the smallest model ("70M").
    params_by_group = COEFFS.get(group)
    if params_by_group is None:
        params_by_group = COEFFS["70M"]

    outputs: List[Dict[str, float]] = []
    for row in input_data:
        # Read proportions in a fixed order
        p = [float(row.get(f"proportion_domain_{i}", 0.0)) for i in range(1, 6)]
        # Normalize defensively in case inputs are not perfectly normalized
        s = sum(p)
        if s > 0:
            p = [pi / s for pi in p]

        pred: Dict[str, float] = {}
        for i in range(1, 6):
            par = params_by_group[i]
            a = par["a"]
            b = par["b"]
            c = par["c"]  # length-5, zero at index i-1
            log_term = math.log(max(p[i - 1], 0.0) + EPS)
            linear_term = sum(c[j] * p[j] for j in range(5))
            y = a + b * log_term + linear_term
            pred[f"loss_domain_{i}"] = float(y)

        outputs.append(pred)

    return outputs

#2 Run 2 R² = 0.971446
#3 Run 3 R² = 0.971092
#4 Run 4 R² = 0.899201
#5 Run 5 R² = 0.834132
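
For a quick sanity check, the Best Run 1 predictor can be called directly. This is a minimal usage sketch: the uniform mixture below is illustrative and is not taken from the evaluation data.

Python
# Assumes `law` from the Best Run 1 listing above is in scope.
# Hypothetical input: a single uniform mixture over the five domains.
mix = [{
    "proportion_domain_1": 0.2,
    "proportion_domain_2": 0.2,
    "proportion_domain_3": 0.2,
    "proportion_domain_4": 0.2,
    "proportion_domain_5": 0.2,
}]

preds = law(mix, group="410M")
# Each prediction dict holds 'loss_domain_1'..'loss_domain_5'.
print(preds[0]["loss_domain_3"])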
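
The COEFFS table in the Best Run 1 code is described as fitted per group on /app/data, but the fitting code itself is not shown on this page. Because the model is linear in (a_i, b_i, c_{i,j}) once log(p_i + EPS) is treated as a feature, each (group, domain) pair can be fitted independently by ordinary least squares. The sketch below assumes that setup; fit_domain and its row format are illustrative, not part of the run.

Python
import math

import numpy as np

EPS = 1e-6  # must match the EPS used at prediction time


def fit_domain(rows, i):
    """Least-squares fit of a_i, b_i, and c_{i,j} for domain i.

    rows: iterable of (p, y) pairs, where p is a length-5 proportion
    vector (summing to 1) and y is the observed loss for domain i.
    """
    X, t = [], []
    for p, y in rows:
        # Features: intercept, log(p_i + EPS), and the other four proportions.
        others = [p[j] for j in range(5) if j != i - 1]
        X.append([1.0, math.log(p[i - 1] + EPS)] + others)
        t.append(y)
    beta, *_ = np.linalg.lstsq(np.asarray(X), np.asarray(t), rcond=None)
    a, b = float(beta[0]), float(beta[1])
    c = [float(v) for v in beta[2:]]
    c.insert(i - 1, 0.0)  # zero self-coefficient, matching the COEFFS layout
    return {"a": a, "b": b, "c": c}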