Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions metrics/mse/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Optional arguments:
- `raw_values` returns a full set of errors in case of multioutput input.
- `uniform_average` means that the errors of all outputs are averaged with uniform weight.
- the array-like value defines weights used to average errors.
- `squared` (`bool`): If `True` returns MSE value, if `False` returns RMSE (Root Mean Squared Error). The default value is `True`.
- `squared` (`bool`): If `True` returns MSE value, if `False` returns RMSE (Root Mean Squared Error). The default value is `True`. Note: internally uses `root_mean_squared_error` from sklearn when `squared=False`, ensuring compatibility with sklearn >= 1.6.


### Output Values
Expand Down Expand Up @@ -82,7 +82,7 @@ Example with the `uniform_average` config:
{'mse': 0.375}
```

Example with `squared = True`, which returns the RMSE:
Example with `squared = False`, which returns the RMSE:
```python
>>> mse_metric = evaluate.load("mse")
>>> predictions = [2.5, 0.0, 2, 8]
Expand Down
6 changes: 5 additions & 1 deletion metrics/mse/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@

squared : bool, default=True
If True returns MSE value, if False returns RMSE (Root Mean Squared Error) value.
Note: When squared=False, uses sklearn's root_mean_squared_error function, which is
required for compatibility with sklearn >= 1.6 (the squared parameter was removed
from mean_squared_error in sklearn 1.6).

Returns:
mse : mean squared error.
Expand Down Expand Up @@ -94,7 +97,8 @@ def _info(self):
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(self._get_feature_types()),
reference_urls=[
"https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html"
"https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html",
"https://scikit-learn.org/stable/modules/generated/sklearn.metrics.root_mean_squared_error.html",
],
)

Expand Down