← Back to Leaderboard

Domain Mixture Scaling Law

Agent: claude-code
Model: claude-sonnet-4-5
Best R²: 0.971598
Mean R²: 0.970543
Min R²: 0.968336
Runs: 5

All Runs (sorted by R²)

Best Run 1 R² = 0.971598
Python
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predicts output variables based on input variables according to a discovered scaling law.

    For each domain i in 1..5 the per-domain loss is modelled as

        loss_domain_i = L_inf_i - a_i * (proportion_domain_i) ** b_i

    where (L_inf_i, a_i, b_i) are fitted constants that depend on the group.
    The functional form is the same for every group.

    Args:
        input_data: A list of dictionaries, where each dictionary is a single data
                    point containing input variable names as keys and their
                    corresponding values. The keys read are
                    'proportion_domain_1' .. 'proportion_domain_5'. A missing key
                    is treated as a proportion of 0.0. Negative values are clamped
                    to 0.0, because a fractional power of a negative float would
                    otherwise yield a complex number.
        group: The name of the experimental group for which to make predictions.
               Fitted constants exist for '70M', '160M', '305M' and '410M'.

    Returns:
        A list of dictionaries, corresponding to the input_data list, with each
        dictionary containing the predicted 'loss_domain_1' .. 'loss_domain_5'.

    Raises:
        KeyError: If `group` has no fitted parameters and input_data is non-empty.
    """

    # Fitted parameters for each domain and group
    # Formula: loss_domain_i = L_inf_i - a_i * (proportion_domain_i)^b_i
    params = {
        1: {
            '70M': {'L_inf': 3.414908, 'a': 0.919426, 'b': 0.242859},
            '160M': {'L_inf': 3.060407, 'a': 0.841478, 'b': 0.227141},
            '305M': {'L_inf': 2.898031, 'a': 0.812498, 'b': 0.220675},
            '410M': {'L_inf': 2.831881, 'a': 0.797224, 'b': 0.213714}
        },
        2: {
            '70M': {'L_inf': 3.818429, 'a': 0.286092, 'b': 0.325881},
            '160M': {'L_inf': 3.472137, 'a': 0.229089, 'b': 0.253344},
            '305M': {'L_inf': 3.306184, 'a': 0.231382, 'b': 0.286684},
            '410M': {'L_inf': 3.230276, 'a': 0.194197, 'b': 0.211294}
        },
        3: {
            '70M': {'L_inf': 3.600640, 'a': 0.884553, 'b': 0.258117},
            '160M': {'L_inf': 3.285555, 'a': 0.821255, 'b': 0.248967},
            '305M': {'L_inf': 3.155623, 'a': 0.780708, 'b': 0.218846},
            '410M': {'L_inf': 3.098252, 'a': 0.789583, 'b': 0.216252}
        },
        4: {
            '70M': {'L_inf': 2.266335, 'a': 0.933792, 'b': 0.235431},
            '160M': {'L_inf': 1.963058, 'a': 0.833632, 'b': 0.236056},
            '305M': {'L_inf': 1.832974, 'a': 0.793849, 'b': 0.233302},
            '410M': {'L_inf': 1.779367, 'a': 0.778080, 'b': 0.236914}
        },
        5: {
            '70M': {'L_inf': 3.931742, 'a': 0.511339, 'b': 0.352486},
            '160M': {'L_inf': 3.594913, 'a': 0.526879, 'b': 0.354306},
            '305M': {'L_inf': 3.434413, 'a': 0.530724, 'b': 0.338928},
            '410M': {'L_inf': 3.374611, 'a': 0.548299, 'b': 0.334021}
        }
    }

    # Nothing to predict: return early so an empty batch never fails on group
    # lookup (matches the original per-point lookup behaviour).
    if not input_data:
        return []

    # The group's constants are the same for every data point, so resolve
    # them once. Unknown groups get a descriptive KeyError instead of a bare key.
    if group not in params[1]:
        known = ", ".join(params[1])
        raise KeyError(f"unknown group {group!r}; expected one of: {known}")
    group_params = [
        (f'proportion_domain_{i}', f'loss_domain_{i}',
         params[i][group]['L_inf'], params[i][group]['a'], params[i][group]['b'])
        for i in range(1, 6)
    ]

    results = []
    for data_point in input_data:
        prediction = {}
        for prop_key, loss_key, L_inf, a, b in group_params:
            # Clamp at 0: a negative base raised to a fractional b yields a
            # complex number in Python, which would break the float contract.
            proportion = max(0.0, float(data_point.get(prop_key, 0.0)))
            prediction[loss_key] = L_inf - a * (proportion ** b)
        results.append(prediction)

    return results
#2 Run 2 R² = 0.971145
#3 Run 3 R² = 0.971145
#4 Run 4 R² = 0.970491
#5 Run 5 R² = 0.968336