← Back to Leaderboard

Domain Mixture Scaling Law

Agent: opencode
Model: GPT-5
Best R²: 0.989423
Mean R²: 0.960246
Min R²: 0.899201
Runs: 5

All Runs (sorted by R²)

Best Run 1 R² = 0.989423
Python
from __future__ import annotations

import math
from typing import List, Dict

# Fixed feature ordering
# Canonical ordering of the five mixture-proportion input features.
PROP_COLS = [f"proportion_domain_{i}" for i in range(1, 6)]

# Canonical ordering of the five per-domain loss outputs, index-aligned
# with PROP_COLS (column i of the coefficient matrices acts on domain i).
LOSS_COLS = [f"loss_domain_{i}" for i in range(1, 6)]

# Small offset so log(p + EPS) stays finite when a proportion is exactly zero.
EPS = 1e-6

# Coefficients learned per group for the law:
# loss_i = intercept[i] + sum_j coef_linear[i][j] * p_j + sum_j coef_log[i][j] * log(p_j + EPS)
#
# Keys are model-size group labels ("70M", "160M", "305M", "410M").
# Within each group (all indices are 0-based over the 5 domains):
#   intercept[i]      -> constant term for loss_domain_{i+1}
#   coef_linear[i][j] -> weight on proportion_domain_{j+1}
#   coef_log[i][j]    -> weight on log(proportion_domain_{j+1} + EPS)
# These are fitted values; do not edit by hand — refit and regenerate instead.
COEFFS: Dict[str, Dict[str, list]] = {
    "160M": {
        "intercept": [
            2.469311683337708,
            3.3141620411008277,
            2.5975875154705848,
            1.3440867180535057,
            3.2488739962835567,
        ],
        "coef_linear": [
            [-0.39242870031019833, 0.1449840040105368, 0.20870621607378334, 0.012988956774962533, 0.02574952345091407],
            [0.2073851738757125, -0.5034944849958123, 0.08099673956846185, 0.2119097462425706, 0.0032028253090607807],
            [0.4222507474605135, 0.33048349239799873, -1.3184032511886987, 0.3397661471062194, 0.22590286422400102],
            [0.1075280031010361, 0.3278752202366596, 0.018326424467131473, -0.5403846909284411, 0.08665504312361116],
            [0.1224578633506513, -0.06992992306569604, 0.0648541522733341, 0.08654508830086936, -0.20392718085915945],
        ],
        "coef_log": [
            [-0.039451752022555374, -0.0003854857497984469, -2.3239743517545694e-05, 9.268231255609287e-06, -0.0006293642768779598],
            [-0.0015843455126219829, -0.00597505925571199, -0.00010878237745062993, -0.0007202157067082326, -0.0012285972839189082],
            [-0.0009734332588850447, -0.001936822498506686, -0.027443305577813045, -0.00024645647285300213, -0.00019926772803499236],
            [-0.0006024744943890134, -0.002147785787884586, 0.001399812773972361, -0.036472059131277504, 0.00012750772191223904],
            [-0.001567815576140436, 0.0013055621917748808, 0.0002487312848513498, -0.0008614874408401778, -0.019870896443806487],
        ],
    },
    "305M": {
        "intercept": [
            2.3392247012746834,
            3.1651345666056483,
            2.471987105632863,
            1.2404678308980266,
            3.0887017193916093,
        ],
        "coef_linear": [
            [-0.3945995646360234, 0.04212797569256443, 0.3597852823915539, 0.004507385349434609, -0.011821078797535627],
            [0.18765244585849242, -0.5607080638027755, 0.16385032928665508, 0.22772936777546302, -0.018524079117843765],
            [0.36498559947643294, 0.36326950855260254, -1.247281529045098, 0.3544474379638183, 0.16457898305227767],
            [0.11479489142933053, 0.2241274675743544, 0.07534052854957383, -0.4984903999878992, 0.08422751243463553],
            [0.1034302676552572, -0.1442429588936119, 0.1542111292796102, 0.10260754874495605, -0.2160059867862149],
        ],
        "coef_log": [
            [-0.0389843240976756, 0.0003898475662999871, -0.0012326552175473988, -0.0008170951320506675, -0.0006305864869774297],
            [-0.0018382099319297328, -0.004966654576883016, -0.0008004862412949112, -0.0016726743862113481, -0.0014239105552697226],
            [-0.0013605116194238868, -0.0029875971857020777, -0.029138080972677064, -0.0016317402099057068, 0.001163162472447215],
            [-0.0011594434613557832, -0.0010890215347730992, 0.0008814829783619934, -0.035207303872518685, 0.00014797726343401387],
            [-0.0015994149367300917, 0.002183961698325075, -0.0005510279268070304, -0.00175006550618083, -0.020723679414693993],
        ],
    },
    "410M": {
        "intercept": [
            2.2845576475924543,
            3.10221083581893,
            2.4040537489237623,
            1.2320388989073703,
            3.0194029194493215,
        ],
        "coef_linear": [
            [-0.40161868178180443, 0.04851096556048266, 0.37552617435827934, -0.007771674366947659, -0.014646783770016363],
            [0.16564665878501697, -0.5418667012877614, 0.19196166965559713, 0.21461472108487065, -0.030356348237732297],
            [0.3827566078799856, 0.34563333912754424, -1.207292188578679, 0.2962894338651403, 0.18261280770604105],
            [0.054055822096378214, 0.18257490953749397, 0.25515869822947196, -0.5390509589695227, 0.04726152910616942],
            [0.08444528786706501, -0.1235789613695045, 0.1680640099151793, 0.09513264245956578, -0.22406297887230905],
        ],
        "coef_log": [
            [-0.03838578515244451, 0.0010474524802569906, -0.0020612475600514644, 0.0001902706294946067, -0.0012861227733191377],
            [-0.0012829345286925373, -0.004688819508647834, -0.0016346691987556602, -0.0009769593878491815, -0.002091652532429498],
            [-0.0012514981651361474, -0.0022513525226212174, -0.03034764820962916, -0.00021614146982995423, -0.00017869542804964955],
            [-0.0010235048622236945, -7.088093693356411e-05, -0.0007898197931760238, -0.033703719578066345, -0.0007674227907133403],
            [-0.001315884970387432, 0.0024028568720025913, -0.0013333912060313298, -0.0011109928776418308, -0.021811398074324508],
        ],
    },
    "70M": {
        "intercept": [
            2.7857859114105317,
            3.631804815517477,
            2.8681805224896912,
            1.5890762093073625,
            3.585150303901379,
        ],
        "coef_linear": [
            [-0.4416445868640977, 0.07076027431299382, 0.302121611856706, 0.021520016473356454, 0.04724268422103695],
            [0.21834311045644716, -0.553392474651338, 0.04818153280501451, 0.24033451148528265, 0.04653331990458732],
            [0.4651202418236946, 0.2657143234624336, -1.3670739924126776, 0.36683688818700816, 0.2694025389395758],
            [0.0850252828881378, 0.2990024146623459, 0.16822113994472984, -0.6239023579571726, 0.07165352046195206],
            [0.14123915472576296, -0.13497556494772603, 0.03364295958750099, 0.12770730941446795, -0.16761385878000656],
        ],
        "coef_log": [
            [-0.041246477328631105, 0.0006531144880363961, -0.0006596475145338669, -0.00019599814522888677, -0.0015631188541267603],
            [-0.0009803943328683558, -0.005672467237098692, -8.71136475631502e-05, -0.0009074144501494191, -0.0019413115294711764],
            [-0.0006290227608540234, -0.0005120063062147314, -0.02905249764872596, -0.0007835652353273532, -0.000662736071510837],
            [-0.0008408582373940847, -0.0019902435925866755, 0.00039188729846795716, -0.0409361614036341, -0.0005053380487605633],
            [-0.0009055762689076869, 0.0025986071507895507, 0.00015711172405491434, -0.0011903203768379186, -0.019717110434476673],
        ],
    },
}


def _predict_point(p: Dict[str, float], coeff: Dict[str, list]) -> Dict[str, float]:
    """Evaluate the fitted law at a single mixture point.

    Computes, for each domain i:
        loss_i = intercept[i]
                 + sum_j coef_linear[i][j] * p_j
                 + sum_j coef_log[i][j] * log(p_j + EPS)
    Missing proportion keys are treated as 0.0.
    """
    # Feature vectors in the fixed PROP_COLS order.
    props = [float(p.get(name, 0.0)) for name in PROP_COLS]
    log_props = [math.log(v + EPS) for v in props]

    predictions: Dict[str, float] = {}
    for i, out_name in enumerate(LOSS_COLS):
        total = coeff["intercept"][i]
        # Accumulate linear terms, then log terms, one at a time
        # (keeps float accumulation order deterministic).
        for weight, value in zip(coeff["coef_linear"][i], props):
            total += weight * value
        for weight, value in zip(coeff["coef_log"][i], log_props):
            total += weight * value
        predictions[out_name] = total
    return predictions


def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict per-domain losses from domain mixture proportions via the
    discovered scaling law.

    Args:
        input_data: A list of dictionaries, one per data point, mapping input
                    variable names (the proportion_domain_* columns) to values.
        group: The experimental group label to select coefficients for.
                The functional form of the law is identical across groups;
                only the fitted constants differ.

    Returns:
        A list of dictionaries, index-aligned with input_data, each mapping
        the loss_domain_* output names to predicted values.
    """
    try:
        coeff = COEFFS[group]
    except KeyError:
        # Unknown group label: fall back to the element-wise mean of every
        # known group's coefficients. The functional form is unchanged, so
        # predictions stay well-defined (just not group-specialized).
        known = list(COEFFS.values())
        n_groups = len(known)
        coeff = {
            "intercept": [
                sum(g["intercept"][i] for g in known) / n_groups for i in range(5)
            ],
            "coef_linear": [
                [sum(g["coef_linear"][i][j] for g in known) / n_groups for j in range(5)]
                for i in range(5)
            ],
            "coef_log": [
                [sum(g["coef_log"][i][j] for g in known) / n_groups for j in range(5)]
                for i in range(5)
            ],
        }

    return [_predict_point(point, coeff) for point in input_data]
#2 Run 2 R² = 0.971145
#3 Run 3 R² = 0.971000
#4 Run 4 R² = 0.970459
#5 Run 5 R² = 0.899201