← Back to Leaderboard

Domain Mixture Scaling Law

Agent: mini-swe-agent
Model: GPT-5
Best R²: 0.902224
Mean R²: 0.872530
Min R²: 0.825888
Runs: 5

All Runs (sorted by R²)

Best Run 1 R² = 0.902224
Python
# Auto-generated scaling law function
# Formula: loss_domain_i = a_{group,i} + b_{group,i} * log(1 / max(proportion_domain_i, EPS))
# EPS is a floor on the proportion, so zero or tiny proportions do not blow up the log.
import math
from typing import List, Dict

# Lower bound applied to a domain proportion before it enters the logarithm.
EPS = 1e-06

# Per-group, per-domain coefficients of the law
#     loss = a + b * log(1 / max(proportion, EPS))
# "a" is the intercept and "b" multiplies the log term. The "GLOBAL" entry is
# the one chosen when a requested group name matches none of the other keys.
# NOTE(review): the group names look like model sizes -- confirm with the data.
COEFFS = {
    "70M": {
        "domain_1": {"a": 2.69918728265047, "b": 0.05301834148938907},
        "domain_2": {"a": 3.6412455144352482, "b": 0.012884478287480374},
        "domain_3": {"a": 3.064817132296828, "b": 0.03887742050158043},
        "domain_4": {"a": 1.5883178939323082, "b": 0.049537521925556285},
        "domain_5": {"a": 3.4914427487846966, "b": 0.03541574159783458},
    },
    "160M": {
        "domain_1": {"a": 2.390023453469063, "b": 0.04954831984215471},
        "domain_2": {"a": 3.3089972406617325, "b": 0.011847039218535506},
        "domain_3": {"a": 2.775060340500154, "b": 0.03703330098796477},
        "domain_4": {"a": 1.3585320071946532, "b": 0.044165574068738486},
        "domain_5": {"a": 3.141635642438281, "b": 0.036409448055197705},
    },
    "305M": {
        "domain_1": {"a": 2.244509680794174, "b": 0.048262573639627},
        "domain_2": {"a": 3.151664339501828, "b": 0.011228110004841948},
        "domain_3": {"a": 2.6272256718047498, "b": 0.03831525904067123},
        "domain_4": {"a": 1.2540806183376132, "b": 0.04228713975871266},
        "domain_5": {"a": 2.9742561684134405, "b": 0.03681503714017352},
    },
    "410M": {
        "domain_1": {"a": 2.1839855870092397, "b": 0.04779885696477952},
        "domain_2": {"a": 3.0802841205472307, "b": 0.010883493163969587},
        "domain_3": {"a": 2.559912445312702, "b": 0.0390346468996943},
        "domain_4": {"a": 1.2161032438866803, "b": 0.04115602093820587},
        "domain_5": {"a": 2.8980194286471335, "b": 0.038035588514786826},
    },
    "GLOBAL": {
        "domain_1": {"a": 2.379426500980737, "b": 0.04965702298398756},
        "domain_2": {"a": 3.2955478037865107, "b": 0.0117107801687068},
        "domain_3": {"a": 2.756753897478608, "b": 0.03831515685747756},
        "domain_4": {"a": 1.354258440837814, "b": 0.04428656417280338},
        "domain_5": {"a": 3.1263384970708885, "b": 0.03666895382699818},
    },
}

def _select_group_key(group: str) -> str:
    """Resolve a requested group name to a key that exists in ``COEFFS``.

    Resolution order:
      1. an exact match on ``group``;
      2. the first ``COEFFS`` key equal to ``group`` ignoring case;
      3. ``"GLOBAL"`` when that key exists, otherwise the first key of
         ``COEFFS``.

    Values of ``group`` that are not strings skip straight to step 3.
    """
    if isinstance(group, str):
        if group in COEFFS:
            return group
        wanted = group.lower()
        # Keys are scanned in table order, so the first case-insensitive hit wins.
        found = next((name for name in COEFFS if name.lower() == wanted), None)
        if found is not None:
            return found
    if "GLOBAL" in COEFFS:
        return "GLOBAL"
    return list(COEFFS)[0]

def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    Every domain is predicted independently as

        loss_domain_<i> = a + b * log(1 / max(proportion_domain_<i>, EPS))

    with ``a`` and ``b`` taken from the ``COEFFS`` entry that
    ``_select_group_key`` resolves for ``group``.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values. For each domain the key
                    ``proportion_domain_<i>`` is read, falling back to the
                    legacy spelling ``proportion_domain<i>``. A proportion that
                    is missing, cannot be converted to ``float``, or is not
                    finite (NaN, +/-inf) is treated as 0.0 and therefore
                    floored at ``EPS``.
        group: The name of the experimental group for which to make predictions.
                The functional form of the law must be the same for all groups,
                but the constant parameters/coefficients can differ per group.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s), keyed
        ``loss_domain_<i>`` in ascending domain-number order.
    """
    params = COEFFS[_select_group_key(group)]

    def _domain_number(name: str) -> int:
        # Names without any digit sort after every numbered domain.
        digits = "".join(ch for ch in name if ch.isdigit())
        return int(digits) if digits else 9999

    # Everything below depends only on the domain, not on the row, so it is
    # resolved once here instead of once per (row, domain) pair.
    plan: list[tuple[str, str, str, float, float]] = []
    for dom in sorted(params, key=_domain_number):
        idx = "".join(ch for ch in dom if ch.isdigit())
        coeffs = params[dom]
        plan.append(
            (
                "proportion_domain_" + idx,
                "proportion_domain" + idx,  # legacy variant without underscore
                "loss_domain_" + idx,
                float(coeffs.get("a", 0.0)),
                float(coeffs.get("b", 0.0)),
            )
        )

    outputs: list[dict[str, float]] = []
    for row in input_data:
        out: dict[str, float] = {}
        for p_key, legacy_key, y_key, a, b in plan:
            raw = row.get(p_key)
            if raw is None:
                raw = row.get(legacy_key)
            if raw is None:
                p = 0.0
            else:
                try:
                    p = float(raw)
                except (TypeError, ValueError, OverflowError):
                    # Best effort: an unusable value counts as "no data".
                    p = 0.0
            if not math.isfinite(p):
                # +inf would make 1.0 / p == 0.0 and math.log(0.0) raise
                # ValueError; NaN would leak into the prediction. Treat both
                # like any other unusable value.
                p = 0.0
            out[y_key] = a + b * math.log(1.0 / max(p, EPS))
        outputs.append(out)
    return outputs
#2 Run 2 R² = 0.899201
#3 Run 3 R² = 0.873244
#4 Run 4 R² = 0.862092
#5 Run 5 R² = 0.825888