Domain Mixture Scaling Law

Agent: claude-code
Model: claude-haiku-4-5
Best R²: 0.970840
Mean R²: 0.905449
Min R²: 0.792215
Runs: 5

All Runs (sorted by R²)

#1 Run 1 (Best) R² = 0.970840
Python
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    The scaling law discovered is a log-linear relationship:
    loss_domain_i = a_i + b_i * log(proportion_domain_i)

    where a_i and b_i are fitted coefficients that depend on the domain and group.

    For zero proportions, a baseline loss value is used.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values (proportion_domain_i).
        group: The name of the experimental group for which to make predictions.
               Supported groups: '70M', '160M', '305M', '410M'

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted output variable(s) (loss_domain_i).
    """
    import math

    # Fitted parameters: loss_domain_i = a_i + b_i * log(proportion_domain_i)
    # Format: {domain_i: {group: {a: float, b: float}}}
    params = {
        'domain_1': {
            '70M': {'a': 2.538716023062174, 'b': -0.15014064985318967},
            '160M': {'a': 2.2544843272662405, 'b': -0.1315809083635399},
            '305M': {'a': 2.1175968264069764, 'b': -0.12507426813350242},
            '410M': {'a': 2.0645520693878363, 'b': -0.1200838190624766},
        },
        'domain_2': {
            '70M': {'a': 3.5716664015145185, 'b': -0.04836401240653225},
            '160M': {'a': 3.263924262467963, 'b': -0.034830492162166084},
            '305M': {'a': 3.1006743096532867, 'b': -0.03722876491120557},
            '410M': {'a': 3.0490737355109614, 'b': -0.02679818184395711},
        },
        'domain_3': {
            '70M': {'a': 2.893348244102368, 'b': -0.09886102692916138},
            '160M': {'a': 2.623560530670634, 'b': -0.09003128254977677},
            '305M': {'a': 2.501230421480594, 'b': -0.08239118209742197},
            '410M': {'a': 2.4342910737037524, 'b': -0.08297977891663602},
        },
        'domain_4': {
            '70M': {'a': 1.4060660023243017, 'b': -0.1362756449783236},
            '160M': {'a': 1.1970241661817773, 'b': -0.12103109712578866},
            '305M': {'a': 1.101723420311251, 'b': -0.11479764855541283},
            '410M': {'a': 1.0636521193056327, 'b': -0.11371123168518989},
        },
        'domain_5': {
            '70M': {'a': 3.434597631728354, 'b': -0.12648654647382065},
            '160M': {'a': 3.0839186234824836, 'b': -0.12887711537787372},
            '305M': {'a': 2.9180204995169983, 'b': -0.1269094532480054},
            '410M': {'a': 2.8412870151502942, 'b': -0.12892583253832435},
        },
    }

    # Baseline losses at p=0 (computed from training data)
    baseline_losses = {
        'domain_1': {
            '70M': 3.4149081168601354,
            '160M': 3.060407302873726,
            '305M': 2.8980309491789353,
            '410M': 2.831881281851642,
        },
        'domain_2': {
            '70M': 3.818428775380711,
            '160M': 3.4721373995346867,
            '305M': 3.3061837391603213,
            '410M': 3.230276246563029,
        },
        'domain_3': {
            '70M': 3.60063959478022,
            '160M': 3.2855554601648356,
            '305M': 3.1556234171102338,
            '410M': 3.0982517116672392,
        },
        'domain_4': {
            '70M': 2.266334699876238,
            '160M': 1.9630575752887789,
            '305M': 1.83297445467203,
            '410M': 1.7793668136344885,
        },
        'domain_5': {
            '70M': 3.931742488662131,
            '160M': 3.594912574404762,
            '305M': 3.4344130881519273,
            '410M': 3.374610814377834,
        },
    }

    results = []

    for data_point in input_data:
        output = {}

        # Predict loss for each domain
        for domain_idx in range(1, 6):
            domain_key = f'domain_{domain_idx}'
            proportion_key = f'proportion_domain_{domain_idx}'
            loss_key = f'loss_domain_{domain_idx}'

            # Get the proportion for this domain
            proportion = data_point.get(proportion_key, 0.0)

            # Get parameters for this domain and group
            if domain_key in params and group in params[domain_key]:
                # For zero proportion, use baseline loss
                if proportion == 0:
                    if domain_key in baseline_losses and group in baseline_losses[domain_key]:
                        loss = baseline_losses[domain_key][group]
                    else:
                        loss = 3.0  # fallback value
                else:
                    # Apply the log-linear model: loss = a + b*log(proportion)
                    a = params[domain_key][group]['a']
                    b = params[domain_key][group]['b']
                    loss = a + b * math.log(proportion)

                output[loss_key] = loss

        results.append(output)

    return results
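
The best run's law function is self-contained and can be called directly. Below is a minimal usage sketch; the mixture proportions and the quoted value are illustrative only and are not drawn from the benchmark data.

Python
# Illustrative example: one mixture point with made-up proportions summing to 1.
example_input = [
    {
        'proportion_domain_1': 0.30,
        'proportion_domain_2': 0.10,
        'proportion_domain_3': 0.25,
        'proportion_domain_4': 0.20,
        'proportion_domain_5': 0.15,
    }
]

predictions = law(example_input, group='160M')
# Each output dict maps 'loss_domain_1' ... 'loss_domain_5' to a predicted loss,
# e.g. loss_domain_1 = 2.2545 - 0.1316 * log(0.30) ≈ 2.413.
print(predictions[0])
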
#2 Run 2 R² = 0.968671
#3 Run 3 R² = 0.965420
#4 Run 4 R² = 0.830099
#5 Run 5 R² = 0.792215
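
For context, coefficients of the form a + b * log(proportion) are typically obtained by an ordinary least-squares fit of observed loss against log(proportion), separately for each domain and model size. The sketch below is an assumption about how such a fit could be reproduced, not the agent's actual fitting code; the observation arrays are made up for illustration.

Python
import numpy as np

def fit_log_linear(proportions, losses):
    """Least-squares fit of loss = a + b * log(p), ignoring zero proportions."""
    p = np.asarray(proportions, dtype=float)
    y = np.asarray(losses, dtype=float)
    mask = p > 0  # log(p) is undefined at p = 0; the submission uses a baseline loss there
    b, a = np.polyfit(np.log(p[mask]), y[mask], 1)  # returns (slope, intercept)
    return a, b

# Illustrative observations only (not benchmark data):
a, b = fit_log_linear([0.05, 0.10, 0.20, 0.40], [2.85, 2.76, 2.68, 2.59])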