diff --git a/search/tests/test_utils.py b/search/tests/test_utils.py new file mode 100644 index 00000000..7f55c133 --- /dev/null +++ b/search/tests/test_utils.py @@ -0,0 +1,46 @@ +"""Tests for utility functions in search.utils module.""" + +import unittest +from ddt import ddt, data, unpack + +from search.utils import normalize_bool + + +@ddt +class TestNormalizeBool(unittest.TestCase): + """Test cases for normalize_bool function.""" + + @data( + (True, True), + (False, False), + ) + @unpack + def test_boolean_values(self, value, expected): + assert normalize_bool(value) is expected + + @data('y', 'Y', 'yes', 'YES', 't', 'T', 'true', 'TRUE', 'on', 'ON', '1') + def test_string_truthy_values(self, value): + assert normalize_bool(value) is True + + @data('n', 'N', 'no', 'NO', 'f', 'F', 'false', 'FALSE', 'off', 'OFF', '0') + def test_string_falsy_values(self, value): + assert normalize_bool(value) is False + + @data('invalid', '10') + def test_invalid_string_values(self, value): + with self.assertRaises(ValueError): + normalize_bool(value) + + @data( + (1, True), + (0, False), + (100, True), + ([], False), + ([1, 2, 3], True), + ({}, False), + ({'key': 'value'}, True), + (None, False), + ) + @unpack + def test_other_types(self, value, expected): + assert normalize_bool(value) is expected diff --git a/search/utils.py b/search/utils.py index ded4a759..28e98ccb 100644 --- a/search/utils.py +++ b/search/utils.py @@ -107,3 +107,20 @@ def end_time_string(self): def elapsed_time(self): """ Return the elapsed time """ return (self._end_time - self._start_time).seconds + + +def normalize_bool(value): + """ Normalize a value to a boolean. """ + if isinstance(value, bool): + return value + + if isinstance(value, str): + value = value.lower() + if value in ('y', 'yes', 't', 'true', 'on', '1'): + return True + if value in ('n', 'no', 'f', 'false', 'off', '0'): + return False + + raise ValueError(f"Invalid truth value: '{value}'") + + return bool(value) diff --git a/search/views.py b/search/views.py index 4d0499e9..89db933b 100644 --- a/search/views.py +++ b/search/views.py @@ -11,6 +11,7 @@ from eventtracking import tracker as track from .api import perform_search, course_discovery_search, course_discovery_filter_fields from .initializer import SearchInitializer +from .utils import normalize_bool # log appears to be standard name used for logger log = logging.getLogger(__name__) @@ -181,7 +182,7 @@ def _course_discovery(request, is_multivalue=False): status_code = 500 search_term = request.POST.get("search_string", None) - enable_course_sorting_by_start_date = request.POST.get("enable_course_sorting_by_start_date", False) + enable_course_sorting_by_start_date = normalize_bool(request.POST.get("enable_course_sorting_by_start_date", False)) try: size, from_, page = _process_pagination_values(request)