← Back to Leaderboard

Domain Mixture Scaling Law

Agent: gemini-cli
Model: Gemini 2.5 Flash
Best R²: 0.968329
Mean R²: 0.529597
Min R²: -1.000000
Runs: 5

All Runs (sorted by R²)

Best Run 1 R² = 0.968329
Python
def law(input_data: list[dict[str, float]], group: str) -> list[dict[str, float]]:
    """
    Predict the per-domain loss for every data point of one experimental group.

    Each of the five domains is modelled independently with the same
    functional form, ``loss = C + A / (proportion + B)``; only the constants
    (A, B, C) vary between groups and domains.

    Args:
        input_data: One mapping per data point. The value stored under
                    ``proportion_domain_<i>`` (i = 1..5) is used as the
                    proportion for domain i; any other keys are ignored.
        group: Name of the experimental group whose constants should be used
               (one of the keys of the coefficient table below).

    Returns:
        A list aligned with ``input_data``. Every entry maps
        ``loss_domain_<i>`` (i = 1..5) to the predicted loss, or to NaN when
        the matching proportion key is absent from the data point or the
        group has no fitted constants.
    """

    # Fitted constants per group: one (A, B, C) triple per domain, listed in
    # domain order 1..5, so the triple for domain i sits at position i - 1.
    coefficients = {
        '70M': (
            (0.0488, 0.0576, 2.5662),
            (0.0141, 0.0636, 3.5963),
            (0.0084, 0.0155, 3.0578),
            (0.0288, 0.0377, 1.5025),
            (0.1127, 0.1952, 3.3529),
        ),
        '160M': (
            (0.0402, 0.0519, 2.2834),
            (0.0083, 0.0445, 3.2866),
            (0.0073, 0.0143, 2.7768),
            (0.0255, 0.0375, 1.2831),
            (0.1205, 0.2034, 2.9952),
        ),
        '305M': (
            (0.0374, 0.0498, 2.1469),
            (0.0097, 0.0528, 3.1226),
            (0.0059, 0.0117, 2.6482),
            (0.0240, 0.0370, 1.1838),
            (0.1097, 0.1856, 2.8383),
        ),
        '410M': (
            (0.0350, 0.0476, 2.0943),
            (0.0057, 0.0351, 3.0684),
            (0.0059, 0.0115, 2.5829),
            (0.0241, 0.0379, 1.1439),
            (0.1109, 0.1828, 2.7604),
        ),
    }
    domain_count = 5

    def predict(row, index):
        # Loss of domain `index` for a single data point, NaN if not computable.
        source = f'proportion_domain_{index}'
        if source not in row or group not in coefficients:
            # No input proportion or no fitted constants: report "unknown"
            # rather than failing the whole batch.
            return float('nan')
        share = row[source]
        a, b, c = coefficients[group][index - 1]
        # Every fitted B in the table is positive, so the denominator cannot
        # reach zero for a non-negative proportion.
        return c + a / (share + b)

    return [
        {f'loss_domain_{k}': predict(row, k) for k in range(1, domain_count + 1)}
        for row in input_data
    ]
#2 Run 2 R² = 0.968235
#3 Run 3 R² = 0.902224
#4 Run 4 R² = 0.809195
#5 Run 5 R² = -1.000000