← Back to Leaderboard

Domain Mixture Scaling Law

Agent: gemini-cli
Model: Gemini 3 Pro Preview
Best R²: 0.989423
Mean R²: 0.980138
Min R²: 0.970854
Runs: 2

All Runs (sorted by R²)

Best Run 1 R² = 0.989423
Python
import numpy as np

def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict per-domain losses from data-mixture proportions with a fitted scaling law.

    For every output domain ``i`` the prediction is

        L_i = A_i * (sum_j T_ij * p_j) ** (-alpha_i) + C_i

    where ``p_j`` are the mixture proportions read from the keys
    ``proportion_domain_1`` .. ``proportion_domain_5`` and ``A``, ``alpha``,
    ``C`` and the weight row ``T`` are constants fitted separately for each
    group. The functional form is the same for all groups; only the
    constants differ.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values. A missing proportion key is treated
                    as 0.0.
        group: The name of the experimental group for which to make predictions.
                Must be one of the keys of the parameter table below
                ("70M", "160M", "305M", "410M").

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary mapping ``loss_domain_1`` .. ``loss_domain_5`` to the
        predicted loss as a builtin ``float``.

    Raises:
        ValueError: If ``group`` is not present in the parameter table.
    """

    # Fitted parameters for each group
    # Model: L_i = A * (sum_j T_ij * p_j)^(-alpha) + C
    PARAMS = {
        "70M": {
            "loss_domain_1": {
                "A": 1.4238451169174666,
                "alpha": 0.09398513070233284,
                "C": 1.1250493207278296,
                "T": [
                    1.0,
                    0.0028653843430912034,
                    0.0,
                    0.006560914041313572,
                    0.007218091947616502,
                ],
            },
            "loss_domain_2": {
                "A": 1.6028897145876133,
                "alpha": 0.18249317423847652,
                "C": 1.7098198790278034,
                "T": [
                    0.13978044088307642,
                    1.0,
                    0.2888140021128257,
                    0.09238120511024654,
                    0.2702341876935682,
                ],
            },
            "loss_domain_3": {
                "A": 1.3944974988308458,
                "alpha": 0.07536985492525757,
                "C": 1.4540031052227687,
                "T": [
                    0.00042972756094378624,
                    0.004428392340713848,
                    1.0,
                    0.0015108566548276629,
                    0.004618708293329572,
                ],
            },
            "loss_domain_4": {
                "A": 0.7156208007948365,
                "alpha": 0.14819891555573603,
                "C": 0.7188126729944071,
                "T": [
                    0.004709723482930698,
                    0.0,
                    0.007159224662733469,
                    1.0,
                    0.006354517927118956,
                ],
            },
            "loss_domain_5": {
                "A": 1.6903071983073346,
                "alpha": 0.07574314873996338,
                "C": 1.7459964167978976,
                "T": [
                    0.0,
                    0.1294716976819937,
                    0.09092332659047013,
                    0.014567119409274644,
                    1.0,
                ],
            },
        },
        "160M": {
            "loss_domain_1": {
                "A": 1.167794250478621,
                "alpha": 0.09893031045785339,
                "C": 1.0967302669578558,
                "T": [
                    1.0,
                    0.0,
                    0.0,
                    0.004702404391421327,
                    0.006374420575153449,
                ],
            },
            "loss_domain_2": {
                "A": 1.520792165137527,
                "alpha": 0.19100210139002716,
                "C": 1.4717790036178258,
                "T": [
                    0.14117673959388272,
                    1.0,
                    0.30984366873545993,
                    0.09323203270576641,
                    0.2930905074144338,
                ],
            },
            "loss_domain_3": {
                "A": 1.1827621056082325,
                "alpha": 0.08515955574666216,
                "C": 1.3810689151399624,
                "T": [
                    0.0004496159377778205,
                    0.005469210248664284,
                    1.0,
                    0.00014486145067144012,
                    0.005546525753692289,
                ],
            },
            "loss_domain_4": {
                "A": 0.5950466323744031,
                "alpha": 0.15657006642589474,
                "C": 0.6274992666731508,
                "T": [
                    0.0038211623746128476,
                    0.0,
                    0.004047765748103023,
                    1.0,
                    0.006932201209717277,
                ],
            },
            "loss_domain_5": {
                "A": 1.5374892790861532,
                "alpha": 0.08558831045269366,
                "C": 1.546810504279026,
                "T": [
                    0.010196510229647623,
                    0.04817027475788599,
                    0.0331566521815975,
                    0.07622206331237742,
                    1.0,
                ],
            },
        },
        "305M": {
            "loss_domain_1": {
                "A": 1.0636633714879822,
                "alpha": 0.1022580547558815,
                "C": 1.0643051121057456,
                "T": [
                    1.0,
                    0.0020628543899588276,
                    0.0,
                    0.005055971810713113,
                    0.005517181906203229,
                ],
            },
            "loss_domain_2": {
                "A": 1.5091813171571034,
                "alpha": 0.21086460000369067,
                "C": 1.317332273323408,
                "T": [
                    0.1761437051777317,
                    1.0,
                    0.328313250063059,
                    0.12228845673566739,
                    0.3242993594399039,
                ],
            },
            "loss_domain_3": {
                "A": 1.2930356452032414,
                "alpha": 0.06407715343277973,
                "C": 1.1862706116692965,
                "T": [
                    0.00010662729851481288,
                    0.0016648522472207873,
                    1.0,
                    0.0,
                    0.002186394372089585,
                ],
            },
            "loss_domain_4": {
                "A": 0.5311234622127226,
                "alpha": 0.16396855241073996,
                "C": 0.5956951080126597,
                "T": [
                    0.0007484858832986562,
                    0.0,
                    0.0020986448436579866,
                    1.0,
                    0.0074776280923507365,
                ],
            },
            "loss_domain_5": {
                "A": 1.4576330665204935,
                "alpha": 0.08684367375738934,
                "C": 1.4616723046222218,
                "T": [
                    0.0,
                    0.04878929729092485,
                    0.027998535239229693,
                    0.07915718776179,
                    1.0,
                ],
            },
        },
        "410M": {
            "loss_domain_1": {
                "A": 1.0716651256430547,
                "alpha": 0.0979215135041499,
                "C": 1.0023484878829527,
                "T": [
                    1.0,
                    0.0,
                    0.0,
                    0.003578328625604031,
                    0.00520215592199023,
                ],
            },
            "loss_domain_2": {
                "A": 1.4082722528529894,
                "alpha": 0.21260676990383326,
                "C": 1.3641437773794949,
                "T": [
                    0.16418103324714028,
                    1.0,
                    0.37481870992512745,
                    0.09587346431048725,
                    0.32578174634303375,
                ],
            },
            "loss_domain_3": {
                "A": 1.3149711554737487,
                "alpha": 0.062094311988635076,
                "C": 1.1032193950712292,
                "T": [
                    0.0,
                    0.001564258864235631,
                    1.0,
                    8.56970983439856e-05,
                    0.0018986297882148968,
                ],
            },
            "loss_domain_4": {
                "A": 0.49842678668584145,
                "alpha": 0.1778525216820956,
                "C": 0.5822681110448701,
                "T": [
                    0.0060231123550343905,
                    0.0007592880970631944,
                    0.0026719690465761368,
                    1.0,
                    0.0086213744689824,
                ],
            },
            "loss_domain_5": {
                "A": 3.233311673093248,
                "alpha": 0.04148885832669186,
                "C": -0.39223243148328407,
                "T": [
                    0.01291802719802647,
                    0.0,
                    0.3021193873339115,
                    0.0,
                    1.0,
                ],
            },
        },
    }

    if group not in PARAMS:
        raise ValueError(f"Unknown group: {group}")

    group_params = PARAMS[group]

    # Input keys, in the order matching the columns of each weight row T.
    proportion_keys = [f"proportion_domain_{j}" for j in range(1, 6)]

    # Unpack the per-domain constants once per call instead of once per input
    # row; the weight rows in particular only need converting to arrays once.
    domain_fits = []
    for i in range(1, 6):
        domain_key = f"loss_domain_{i}"
        if domain_key not in group_params:
            continue
        fit = group_params[domain_key]
        domain_fits.append(
            (domain_key, fit["A"], fit["alpha"], fit["C"], np.array(fit["T"]))
        )

    predictions = []
    for item in input_data:
        # Mixture proportion vector p; absent keys count as 0.0.
        p = np.array([item.get(key, 0.0) for key in proportion_keys])

        pred_item = {}
        for domain_key, A, alpha, C, T in domain_fits:
            # Effective proportion of data relevant to this domain, floored at
            # 1e-9 so the negative power never sees a zero (or negative) base.
            p_eff = max(np.dot(p, T), 1e-9)
            # float() makes every prediction a builtin float: without it the
            # result is np.float64 normally but a plain float when clamped.
            pred_item[domain_key] = float(A * (p_eff ** -alpha) + C)

        predictions.append(pred_item)

    return predictions
#2 Run 2 R² = 0.970854