diff --git a/metrics/mse/README.md b/metrics/mse/README.md index 6f4144ec..019b0298 100644 --- a/metrics/mse/README.md +++ b/metrics/mse/README.md @@ -47,7 +47,7 @@ Optional arguments: - `raw_values` returns a full set of errors in case of multioutput input. - `uniform_average` means that the errors of all outputs are averaged with uniform weight. - the array-like value defines weights used to average errors. -- `squared` (`bool`): If `True` returns MSE value, if `False` returns RMSE (Root Mean Squared Error). The default value is `True`. +- `squared` (`bool`): If `True` returns MSE value, if `False` returns RMSE (Root Mean Squared Error). The default value is `True`. Note: internally uses `root_mean_squared_error` from sklearn when `squared=False`, ensuring compatibility with sklearn >= 1.6. ### Output Values @@ -82,7 +82,7 @@ Example with the `uniform_average` config: {'mse': 0.375} ``` -Example with `squared = True`, which returns the RMSE: +Example with `squared = False`, which returns the RMSE: ```python >>> mse_metric = evaluate.load("mse") >>> predictions = [2.5, 0.0, 2, 8] diff --git a/metrics/mse/mse.py b/metrics/mse/mse.py index 92e9ca31..e3fd09f6 100644 --- a/metrics/mse/mse.py +++ b/metrics/mse/mse.py @@ -56,6 +56,9 @@ squared : bool, default=True If True returns MSE value, if False returns RMSE (Root Mean Squared Error) value. + Note: When squared=False, uses sklearn's root_mean_squared_error function, which is + required for compatibility with sklearn >= 1.6 (the squared parameter was removed + from mean_squared_error in sklearn 1.6). Returns: mse : mean squared error. @@ -94,7 +97,8 @@ def _info(self): inputs_description=_KWARGS_DESCRIPTION, features=datasets.Features(self._get_feature_types()), reference_urls=[ - "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html" + "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html", + "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.root_mean_squared_error.html", ], )