From 173054b1c075a5424787c70d70a7488cab61805b Mon Sep 17 00:00:00 2001
From: Karla Saur
Date: Mon, 3 Apr 2023 23:12:17 +0000
Subject: [PATCH 01/12] attempting to implement infrequent_if_exist in skl OHE

---
 .../_one_hot_encoder_implementations.py       | 18 ++++++++++++++++--
 .../sklearn/one_hot_encoder.py                | 14 ++++++++++++--
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py
index 22b6e64bd..6f3822b10 100644
--- a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py
+++ b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py
@@ -22,10 +22,12 @@ class OneHotEncoderString(PhysicalOperator, torch.nn.Module):
     Because we are dealing with tensors, strings require additional length information for processing.
     """
 
-    def __init__(self, logical_operator, categories, device, extra_config={}):
+    def __init__(self, logical_operator, categories, handle_unknown, infrequent, device, extra_config={}):
         super(OneHotEncoderString, self).__init__(logical_operator, transformer=True)
 
         self.num_columns = len(categories)
+        self.handle_unknown = handle_unknown
+        self.infrequent = infrequent
         self.max_word_length = max([max([len(c) for c in cat]) for cat in categories])
 
         # Strings are casted to int32, therefore we need to properly size the tensor to be dividable by 4.
         while self.max_word_length % 4 != 0:
@@ -74,10 +76,12 @@ class OneHotEncoder(PhysicalOperator, torch.nn.Module):
     Class implementing OneHotEncoder operators for ints in PyTorch.
     """
 
-    def __init__(self, logical_operator, categories, device):
+    def __init__(self, logical_operator, categories, handle_unknown, infrequent, device):
         super(OneHotEncoder, self).__init__(logical_operator, transformer=True)
 
         self.num_columns = len(categories)
+        self.handle_unknown = handle_unknown
+        self.infrequent = infrequent
 
         condition_tensors = []
         for arr in categories:
@@ -87,6 +91,16 @@ def __init__(self, logical_operator, categories, device):
 
     def forward(self, *x):
         encoded_tensors = []
+
+        if self.handle_unknown == "ignore":
+            pass
+        elif self.handle_unknown == "infrequent_if_exist":
+            pass
+        elif self.handle_unknown == "error":
+            pass
+        else:
+            raise RuntimeError("Unsupported handle_unknown setting: {0}".format(self.handle_unknown))
+
         if len(x) > 1:
             assert len(x) == self.num_columns
 
diff --git a/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py b/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py
index a03f2b39e..53af1f988 100644
--- a/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py
+++ b/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py
@@ -29,6 +29,14 @@ def convert_sklearn_one_hot_encoder(operator, device, extra_config):
     """
     assert operator is not None, "Cannot convert None operator"
 
+    # scikit-learn >= 1.1 with handle_unknown = 'infrequent_if_exist'
+    if hasattr(operator.raw_operator, "infrequent_categories_"):
+        infrequent = operator.raw_operator.infrequent_categories_
+    else:
+        infrequent = None
+
+    # TODO: What to do about min_frequency and max_categories? Either support them or raise an error.
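+    # e.g. for an encoder fitted with categories_=[array([1, 2, 3, 4])] and min_frequency=15,
+    # infrequent_categories_ might be [array([1, 3, 4])]; `infrequent` then carries that list.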
+ if all( [ np.array(c).dtype == object or np.array(c).dtype.kind in constants.SUPPORTED_STRING_TYPES @@ -36,9 +44,11 @@ def convert_sklearn_one_hot_encoder(operator, device, extra_config): ] ): categories = [[str(x) for x in c.tolist()] for c in operator.raw_operator.categories_] - return OneHotEncoderString(operator, categories, device, extra_config) + return OneHotEncoderString(operator, categories, operator.raw_operator.handle_unknown, + infrequent, device, extra_config) else: - return OneHotEncoder(operator, operator.raw_operator.categories_, device) + return OneHotEncoder(operator, operator.raw_operator.categories_, operator.raw_operator.handle_unknown, + infrequent, device) register_converter("SklearnOneHotEncoder", convert_sklearn_one_hot_encoder) From 9df54601856fe1596ffd24e147f35e9ca243aa28 Mon Sep 17 00:00:00 2001 From: Karla Saur Date: Tue, 4 Apr 2023 16:22:17 +0000 Subject: [PATCH 02/12] adding testcase that shows the failure --- .../sklearn/one_hot_encoder.py | 1 + .../test_sklearn_one_hot_encoder_converter.py | 38 ++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py b/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py index 53af1f988..89516573a 100644 --- a/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py +++ b/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py @@ -36,6 +36,7 @@ def convert_sklearn_one_hot_encoder(operator, device, extra_config): infrequent = None # TODO: What to do about min_frequency and max_categories? Either support them or raise an error. + # see https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/preprocessing/_encoders.py#L178 if all( [ diff --git a/tests/test_sklearn_one_hot_encoder_converter.py b/tests/test_sklearn_one_hot_encoder_converter.py index daabb371e..7c909cfbd 100644 --- a/tests/test_sklearn_one_hot_encoder_converter.py +++ b/tests/test_sklearn_one_hot_encoder_converter.py @@ -4,10 +4,11 @@ import unittest import numpy as np -import torch +import sklearn from sklearn.preprocessing import OneHotEncoder import hummingbird.ml +from packaging.version import Version, parse class TestSklearnOneHotEncoderConverter(unittest.TestCase): def test_model_one_hot_encoder_int(self): @@ -92,5 +93,40 @@ def test_model_one_hot_encoder_ts_string_not_mod4_len(self): np.testing.assert_allclose(model.transform(data).todense(), pytorch_model.transform(data), rtol=1e-06, atol=1e-06) + @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") + def test_infrequent_if_exists(self): + + # This test is a copy of the test in sklearn. + # https://github.com/scikit-learn/scikit-learn/blob/ + # ecb9a70e82d4ee352e2958c555536a395b53d2bd/sklearn/preprocessing/tests/test_encoders.py#L868 + + X_train = np.array([["a"] * 5 + ["b"] * 2000 + ["c"] * 10 + ["d"] * 3]).T + model = OneHotEncoder( + categories=[["a", "b", "c", "d"]], + handle_unknown="infrequent_if_exist", + sparse_output=False, + min_frequency=15, + + ).fit(X_train) + np.testing.assert_array_equal(model.infrequent_categories_, [["a", "c", "d"]]) + + + pytorch_model = hummingbird.ml.convert(model, "torch", device="cpu") + self.assertIsNotNone(pytorch_model) + + X_test = [["b"], ["a"], ["c"], ["d"], ["e"]] + expected = np.array([[1, 0], [0, 1], [0, 1], [0, 1], [0, 1]]) + orig = model.transform(X_test) + np.testing.assert_allclose(expected, orig) + + + hb = pytorch_model.transform(X_test) + + print("In progress. 
This is where it fails.")
+        print("orig: ", orig)
+        print("hb: ", hb)
+        np.testing.assert_allclose(orig, hb, rtol=1e-06, atol=1e-06)
+        np.testing.assert_allclose(orig.shape, hb.shape, rtol=1e-06, atol=1e-06)
+
 
 if __name__ == "__main__":
     unittest.main()

From 8b0eb2f4f23e74d4979d2fe7f1896968dae5ebb4 Mon Sep 17 00:00:00 2001
From: Karla Saur
Date: Tue, 4 Apr 2023 16:33:39 +0000
Subject: [PATCH 03/12] adding temporary notes

---
 .../ml/operator_converters/sklearn/one_hot_encoder.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py b/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py
index 89516573a..3a49ae4e9 100644
--- a/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py
+++ b/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py
@@ -35,8 +35,12 @@ def convert_sklearn_one_hot_encoder(operator, device, extra_config):
     else:
         infrequent = None
 
-    # TODO: What to do about min_frequency and max_categories? Either support them or raise an error.
+    # TODO: What to do about min_frequency and max_categories?
+    # If I understand correctly, they are only used prior to "fit", and we won't need them for inference.
+    # Both min_frequency and max_categories trigger the creation of the "infrequent" categories, but then
+    # are not used again. So, we can ignore them for HB... I think?
     # see https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/preprocessing/_encoders.py#L178
+    # and the comment on line 503 of the same file.
 
     if all(

From 89ba28ce3a0c4ced08225fc3502674c7fc29f5d7 Mon Sep 17 00:00:00 2001
From: Karla Saur
Date: Tue, 4 Apr 2023 22:30:40 +0000
Subject: [PATCH 04/12] test working for int

---
 .../_one_hot_encoder_implementations.py       | 77 ++++++++++++++++---
 .../test_sklearn_one_hot_encoder_converter.py | 60 +++++++++++----
 2 files changed, 110 insertions(+), 27 deletions(-)

diff --git a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py
index 6f3822b10..f67ffd5fe 100644
--- a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py
+++ b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py
@@ -26,9 +26,9 @@ def __init__(self, logical_operator, categories, handle_unknown, infrequent, dev
         super(OneHotEncoderString, self).__init__(logical_operator, transformer=True)
 
         self.num_columns = len(categories)
-        self.handle_unknown = handle_unknown
-        self.infrequent = infrequent
         self.max_word_length = max([max([len(c) for c in cat]) for cat in categories])
+        self.handle_unknown = handle_unknown
+        self.mask = None
 
         # Strings are casted to int32, therefore we need to properly size the tensor to be dividable by 4.
         while self.max_word_length % 4 != 0:
@@ -57,17 +57,54 @@ def __init__(self, logical_operator, categories, handle_unknown, infrequent, dev
         self.condition_tensors = torch.nn.Parameter(torch.IntTensor(condition_tensors), requires_grad=False)
         self.categories_idx = categories_idx
 
+        if infrequent is not None:
+            infrequent_tensors = []
+            categories_idx = [0]
+            for arr in infrequent:
+                cats = (
+                    np.array(arr, dtype="|S" + str(self.max_word_length))  # Encode objects into 4 byte strings.
+                    .view("int32")
+                    .reshape(-1, self.max_word_length // 4)
+                    .tolist()
+                )
+                # We merge all categories for all columns into a single tensor
+                infrequent_tensors.extend(cats)
+                # Since all categories are merged together, we need to keep track of indexes to retrieve them at inference time.
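+                # e.g. with 3, 1, and 2 infrequent categories in columns 0-2, categories_idx grows to [0, 3, 4, 6],
+                # so column i's infrequent encodings live at infrequent_tensors[categories_idx[i] : categories_idx[i + 1]]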
+ categories_idx.append(categories_idx[-1] + len(cats)) + self.infrequent_tensors = torch.nn.Parameter(torch.IntTensor(infrequent_tensors), requires_grad=False) + + # We need to create a mask to filter out infrequent categories. + self.mask = torch.nn.ParameterList([]) + for i in range(len(self.condition_tensors[0])): + if self.condition_tensors[0][i] not in self.infrequent_tensors[0]: + self.mask.append(torch.nn.Parameter(self.condition_tensors[0][i], requires_grad=False)) + else: + self.infrequent_tensors = None + def forward(self, x): encoded_tensors = [] + + # TODO: implement 'error' case separately + if self.handle_unknown == "ignore" or self.handle_unknown == "error": + compare_tensors = self.condition_tensors + elif self.handle_unknown == "infrequent_if_exist": + compare_tensors = self.mask if self.mask is not None else self.condition_tensors + else: + raise RuntimeError("Unsupported handle_unknown setting: {0}".format(self.handle_unknown)) + for i in range(self.num_columns): # First we fetch the condition for the particular column. - conditions = self.condition_tensors[self.categories_idx[i] : self.categories_idx[i + 1], :].view( + conditions = compare_tensors[self.categories_idx[i] : self.categories_idx[i + 1], :].view( 1, -1, self.max_word_length // 4 ) # Differently than the numeric case where eq is enough, here we need to aggregate per object (dim = 2) # because objects can span multiple integers. We use product here since all ints must match to get encoding of 1. encoded_tensors.append(torch.prod(torch.eq(x[:, i : i + 1, :], conditions), dim=2)) + # if self.infrequent_tensors is not None, then append another tensor that is the "not" of the sum of the encoded tensors. + if self.infrequent_tensors is not None: + encoded_tensors.append(torch.logical_not(torch.sum(torch.stack(encoded_tensors), dim=0))) + return torch.cat(encoded_tensors, dim=1).float() @@ -81,23 +118,35 @@ def __init__(self, logical_operator, categories, handle_unknown, infrequent, dev self.num_columns = len(categories) self.handle_unknown = handle_unknown - self.infrequent = infrequent + self.mask = None condition_tensors = [] for arr in categories: condition_tensors.append(torch.nn.Parameter(torch.LongTensor(arr).detach().clone(), requires_grad=False)) self.condition_tensors = torch.nn.ParameterList(condition_tensors) + if infrequent is not None: + infrequent_tensors = [] + for arr in infrequent: + infrequent_tensors.append(torch.nn.Parameter(torch.LongTensor(arr).detach().clone(), requires_grad=False)) + self.infrequent_tensors = torch.nn.ParameterList(infrequent_tensors) + + # We need to create a mask to filter out infrequent categories. 
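+            # e.g. (mirroring the int test): categories [1, 2, 3, 4] with infrequent [1, 3, 4]
+            # leaves only the frequent value 2 in the mask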
+ self.mask = torch.nn.ParameterList([]) + for i in range(len(self.condition_tensors[0])): + if self.condition_tensors[0][i] not in self.infrequent_tensors[0]: + self.mask.append(torch.nn.Parameter(self.condition_tensors[0][i], requires_grad=False)) + + else: + self.infrequent_tensors = None + def forward(self, *x): encoded_tensors = [] - - if self.handle_unknown == "ignore": - pass + if self.handle_unknown == "ignore" or self.handle_unknown == "error": # TODO: error + compare_tensors = self.condition_tensors elif self.handle_unknown == "infrequent_if_exist": - pass - elif self.handle_unknown == "error": - pass + compare_tensors = self.mask if self.mask is not None else self.condition_tensors else: raise RuntimeError("Unsupported handle_unknown setting: {0}".format(self.handle_unknown)) @@ -109,7 +158,7 @@ def forward(self, *x): if input.dtype != torch.int64: input = input.long() - encoded_tensors.append(torch.eq(input, self.condition_tensors[i])) + encoded_tensors.append(torch.eq(input, compare_tensors[i])) else: # This is already a tensor. x = x[0] @@ -117,6 +166,10 @@ def forward(self, *x): x = x.long() for i in range(self.num_columns): - encoded_tensors.append(torch.eq(x[:, i : i + 1], self.condition_tensors[i])) + encoded_tensors.append(torch.eq(x[:, i : i + 1], compare_tensors[i])) + + # if self.infrequent_tensors is not None, then append another tensor that is the "not" of the sum of the encoded tensors. + if self.infrequent_tensors is not None: + encoded_tensors.append(torch.logical_not(torch.sum(torch.stack(encoded_tensors), dim=0))) return torch.cat(encoded_tensors, dim=1).float() diff --git a/tests/test_sklearn_one_hot_encoder_converter.py b/tests/test_sklearn_one_hot_encoder_converter.py index 7c909cfbd..9bc54a881 100644 --- a/tests/test_sklearn_one_hot_encoder_converter.py +++ b/tests/test_sklearn_one_hot_encoder_converter.py @@ -10,6 +10,7 @@ from packaging.version import Version, parse + class TestSklearnOneHotEncoderConverter(unittest.TestCase): def test_model_one_hot_encoder_int(self): model = OneHotEncoder() @@ -92,41 +93,70 @@ def test_model_one_hot_encoder_ts_string_not_mod4_len(self): np.testing.assert_allclose(model.transform(data).todense(), pytorch_model.transform(data), rtol=1e-06, atol=1e-06) - @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") - def test_infrequent_if_exists(self): + def test_infrequent_if_exists_str(self): + pass + + # # This test is a copy of the test in sklearn. + # # https://github.com/scikit-learn/scikit-learn/blob/ + # # ecb9a70e82d4ee352e2958c555536a395b53d2bd/sklearn/preprocessing/tests/test_encoders.py#L868 + + # X_train = np.array([["a"] * 5 + ["b"] * 2000 + ["c"] * 10 + ["d"] * 3]).T + # model = OneHotEncoder( + # categories=[["a", "b", "c", "d"]], + # handle_unknown="infrequent_if_exist", + # sparse_output=False, + # min_frequency=15, + + # ).fit(X_train) + # np.testing.assert_array_equal(model.infrequent_categories_, [["a", "c", "d"]]) + + # pytorch_model = hummingbird.ml.convert(model, "torch", device="cpu") + # self.assertIsNotNone(pytorch_model) + + # X_test = [["b"], ["a"], ["c"], ["d"], ["e"]] + # expected = np.array([[1, 0], [0, 1], [0, 1], [0, 1], [0, 1]]) + # orig = model.transform(X_test) + # np.testing.assert_allclose(expected, orig) + + # hb = pytorch_model.transform(X_test) - # This test is a copy of the test in sklearn. 
- # https://github.com/scikit-learn/scikit-learn/blob/ - # ecb9a70e82d4ee352e2958c555536a395b53d2bd/sklearn/preprocessing/tests/test_encoders.py#L868 + # print("In progress. This is where it fails.") + # print("orig: ", orig) + # print("hb: ", hb) + # np.testing.assert_allclose(orig, hb, rtol=1e-06, atol=1e-06) + # np.testing.assert_allclose(orig.shape, hb.shape, rtol=1e-06, atol=1e-06) - X_train = np.array([["a"] * 5 + ["b"] * 2000 + ["c"] * 10 + ["d"] * 3]).T + @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") + def test_infrequent_if_exists_int(self): + + X_train = np.array([[1] * 5 + [2] * 2000 + [3] * 10 + [4] * 3]).T model = OneHotEncoder( - categories=[["a", "b", "c", "d"]], + categories=[[1, 2, 3, 4]], handle_unknown="infrequent_if_exist", sparse_output=False, min_frequency=15, - ).fit(X_train) - np.testing.assert_array_equal(model.infrequent_categories_, [["a", "c", "d"]]) - + np.testing.assert_array_equal(model.infrequent_categories_, [[1, 3, 4]]) pytorch_model = hummingbird.ml.convert(model, "torch", device="cpu") self.assertIsNotNone(pytorch_model) - X_test = [["b"], ["a"], ["c"], ["d"], ["e"]] + X_test = [[2], [1], [3], [4], [5]] expected = np.array([[1, 0], [0, 1], [0, 1], [0, 1], [0, 1]]) orig = model.transform(X_test) np.testing.assert_allclose(expected, orig) - hb = pytorch_model.transform(X_test) - print("In progress. This is where it fails.") - print("orig: ", orig) - print("hb: ", hb) np.testing.assert_allclose(orig, hb, rtol=1e-06, atol=1e-06) np.testing.assert_allclose(orig.shape, hb.shape, rtol=1e-06, atol=1e-06) + # TODO also hardcode a sample from issue #684 + @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") + def test_user_provided_example(self): + pass + + if __name__ == "__main__": unittest.main() From 0711d50fe0238ebf787abb6d74fd97df8b7c254e Mon Sep 17 00:00:00 2001 From: Karla Saur Date: Wed, 5 Apr 2023 03:52:58 +0000 Subject: [PATCH 05/12] setting optional value for infrequent, moving to end --- .../operator_converters/_one_hot_encoder_implementations.py | 4 ++-- hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py index f67ffd5fe..93763cafe 100644 --- a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py +++ b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py @@ -22,7 +22,7 @@ class OneHotEncoderString(PhysicalOperator, torch.nn.Module): Because we are dealing with tensors, strings require additional length information for processing. """ - def __init__(self, logical_operator, categories, handle_unknown, infrequent, device, extra_config={}): + def __init__(self, logical_operator, categories, handle_unknown, device, infrequent=None, extra_config={}): super(OneHotEncoderString, self).__init__(logical_operator, transformer=True) self.num_columns = len(categories) @@ -113,7 +113,7 @@ class OneHotEncoder(PhysicalOperator, torch.nn.Module): Class implementing OneHotEncoder operators for ints in PyTorch. 
""" - def __init__(self, logical_operator, categories, handle_unknown, infrequent, device): + def __init__(self, logical_operator, categories, handle_unknown, device, infrequent=None): super(OneHotEncoder, self).__init__(logical_operator, transformer=True) self.num_columns = len(categories) diff --git a/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py b/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py index 3a49ae4e9..5df663694 100644 --- a/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py +++ b/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py @@ -50,10 +50,10 @@ def convert_sklearn_one_hot_encoder(operator, device, extra_config): ): categories = [[str(x) for x in c.tolist()] for c in operator.raw_operator.categories_] return OneHotEncoderString(operator, categories, operator.raw_operator.handle_unknown, - infrequent, device, extra_config) + device, infrequent, extra_config) else: return OneHotEncoder(operator, operator.raw_operator.categories_, operator.raw_operator.handle_unknown, - infrequent, device) + device, infrequent) register_converter("SklearnOneHotEncoder", convert_sklearn_one_hot_encoder) From 6f18e2587f939b9ad40633ceebffbdad8200b570 Mon Sep 17 00:00:00 2001 From: Karla Saur Date: Wed, 5 Apr 2023 04:16:43 +0000 Subject: [PATCH 06/12] param ordering with defaults --- .../_one_hot_encoder_implementations.py | 4 ++-- .../ml/operator_converters/sklearn/one_hot_encoder.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py index 93763cafe..f66fbc044 100644 --- a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py +++ b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py @@ -22,7 +22,7 @@ class OneHotEncoderString(PhysicalOperator, torch.nn.Module): Because we are dealing with tensors, strings require additional length information for processing. """ - def __init__(self, logical_operator, categories, handle_unknown, device, infrequent=None, extra_config={}): + def __init__(self, logical_operator, categories, device, extra_config={}, handle_unknown='error', infrequent=None): super(OneHotEncoderString, self).__init__(logical_operator, transformer=True) self.num_columns = len(categories) @@ -113,7 +113,7 @@ class OneHotEncoder(PhysicalOperator, torch.nn.Module): Class implementing OneHotEncoder operators for ints in PyTorch. 
""" - def __init__(self, logical_operator, categories, handle_unknown, device, infrequent=None): + def __init__(self, logical_operator, categories, device, handle_unknown='error', infrequent=None): super(OneHotEncoder, self).__init__(logical_operator, transformer=True) self.num_columns = len(categories) diff --git a/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py b/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py index 5df663694..a51e5dccf 100644 --- a/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py +++ b/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py @@ -49,11 +49,11 @@ def convert_sklearn_one_hot_encoder(operator, device, extra_config): ] ): categories = [[str(x) for x in c.tolist()] for c in operator.raw_operator.categories_] - return OneHotEncoderString(operator, categories, operator.raw_operator.handle_unknown, - device, infrequent, extra_config) + return OneHotEncoderString(operator, categories, device, extra_config=extra_config, + handle_unknown=operator.raw_operator.handle_unknown, infrequent=infrequent) else: - return OneHotEncoder(operator, operator.raw_operator.categories_, operator.raw_operator.handle_unknown, - device, infrequent) + return OneHotEncoder(operator, operator.raw_operator.categories_, device, + handle_unknown=operator.raw_operator.handle_unknown, infrequent=infrequent) register_converter("SklearnOneHotEncoder", convert_sklearn_one_hot_encoder) From ed45c21c34e81df257c66b22aa8a12d4d41c950e Mon Sep 17 00:00:00 2001 From: Karla Saur Date: Wed, 5 Apr 2023 04:19:41 +0000 Subject: [PATCH 07/12] param ordering with defaults --- hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py b/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py index a51e5dccf..cdfcae6b2 100644 --- a/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py +++ b/hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py @@ -50,7 +50,7 @@ def convert_sklearn_one_hot_encoder(operator, device, extra_config): ): categories = [[str(x) for x in c.tolist()] for c in operator.raw_operator.categories_] return OneHotEncoderString(operator, categories, device, extra_config=extra_config, - handle_unknown=operator.raw_operator.handle_unknown, infrequent=infrequent) + handle_unknown=operator.raw_operator.handle_unknown, infrequent=infrequent) else: return OneHotEncoder(operator, operator.raw_operator.categories_, device, handle_unknown=operator.raw_operator.handle_unknown, infrequent=infrequent) From 0539a3104a60d53d453286a6540130f30de6326c Mon Sep 17 00:00:00 2001 From: Karla Saur Date: Wed, 5 Apr 2023 05:32:00 +0000 Subject: [PATCH 08/12] strings passing in some cases, user test still broken --- .../_one_hot_encoder_implementations.py | 9 +-- .../test_sklearn_one_hot_encoder_converter.py | 72 +++++++++++-------- 2 files changed, 48 insertions(+), 33 deletions(-) diff --git a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py index f66fbc044..68f831923 100644 --- a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py +++ b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py @@ -74,10 +74,11 @@ def __init__(self, logical_operator, categories, device, extra_config={}, handle self.infrequent_tensors = torch.nn.Parameter(torch.IntTensor(infrequent_tensors), requires_grad=False) # 
We need to create a mask to filter out infrequent categories. - self.mask = torch.nn.ParameterList([]) - for i in range(len(self.condition_tensors[0])): - if self.condition_tensors[0][i] not in self.infrequent_tensors[0]: - self.mask.append(torch.nn.Parameter(self.condition_tensors[0][i], requires_grad=False)) + mask = [] + for i in range(len(self.condition_tensors)): + if self.condition_tensors[i] not in self.infrequent_tensors: + mask.append(self.condition_tensors[i]) + self.mask = torch.nn.Parameter(torch.tensor([mask]).T, requires_grad=False) else: self.infrequent_tensors = None diff --git a/tests/test_sklearn_one_hot_encoder_converter.py b/tests/test_sklearn_one_hot_encoder_converter.py index 9bc54a881..d2340e875 100644 --- a/tests/test_sklearn_one_hot_encoder_converter.py +++ b/tests/test_sklearn_one_hot_encoder_converter.py @@ -97,38 +97,35 @@ def test_model_one_hot_encoder_ts_string_not_mod4_len(self): def test_infrequent_if_exists_str(self): pass - # # This test is a copy of the test in sklearn. - # # https://github.com/scikit-learn/scikit-learn/blob/ - # # ecb9a70e82d4ee352e2958c555536a395b53d2bd/sklearn/preprocessing/tests/test_encoders.py#L868 + # This test is a copy of the test in sklearn. + # https://github.com/scikit-learn/scikit-learn/blob/ + # ecb9a70e82d4ee352e2958c555536a395b53d2bd/sklearn/preprocessing/tests/test_encoders.py#L868 - # X_train = np.array([["a"] * 5 + ["b"] * 2000 + ["c"] * 10 + ["d"] * 3]).T - # model = OneHotEncoder( - # categories=[["a", "b", "c", "d"]], - # handle_unknown="infrequent_if_exist", - # sparse_output=False, - # min_frequency=15, + X_train = np.array([["a"] * 5 + ["b"] * 2000 + ["c"] * 10 + ["d"] * 3]).T + model = OneHotEncoder( + categories=[["a", "b", "c", "d"]], + handle_unknown="infrequent_if_exist", + sparse_output=False, + min_frequency=15, - # ).fit(X_train) - # np.testing.assert_array_equal(model.infrequent_categories_, [["a", "c", "d"]]) + ).fit(X_train) + np.testing.assert_array_equal(model.infrequent_categories_, [["a", "c", "d"]]) - # pytorch_model = hummingbird.ml.convert(model, "torch", device="cpu") - # self.assertIsNotNone(pytorch_model) + pytorch_model = hummingbird.ml.convert(model, "torch", device="cpu") + self.assertIsNotNone(pytorch_model) - # X_test = [["b"], ["a"], ["c"], ["d"], ["e"]] - # expected = np.array([[1, 0], [0, 1], [0, 1], [0, 1], [0, 1]]) - # orig = model.transform(X_test) - # np.testing.assert_allclose(expected, orig) + X_test = [["b"], ["a"], ["c"], ["d"], ["e"]] + expected = np.array([[1, 0], [0, 1], [0, 1], [0, 1], [0, 1]]) + orig = model.transform(X_test) + np.testing.assert_allclose(expected, orig) - # hb = pytorch_model.transform(X_test) + hb = pytorch_model.transform(X_test) - # print("In progress. 
This is where it fails.") - # print("orig: ", orig) - # print("hb: ", hb) - # np.testing.assert_allclose(orig, hb, rtol=1e-06, atol=1e-06) - # np.testing.assert_allclose(orig.shape, hb.shape, rtol=1e-06, atol=1e-06) + np.testing.assert_allclose(orig, hb, rtol=1e-06, atol=1e-06) + np.testing.assert_allclose(orig.shape, hb.shape, rtol=1e-06, atol=1e-06) - @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") - def test_infrequent_if_exists_int(self): + # @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") + # def test_infrequent_if_exists_int(self): X_train = np.array([[1] * 5 + [2] * 2000 + [3] * 10 + [4] * 3]).T model = OneHotEncoder( @@ -152,10 +149,27 @@ def test_infrequent_if_exists_int(self): np.testing.assert_allclose(orig, hb, rtol=1e-06, atol=1e-06) np.testing.assert_allclose(orig.shape, hb.shape, rtol=1e-06, atol=1e-06) - # TODO also hardcode a sample from issue #684 - @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") - def test_user_provided_example(self): - pass + # @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") + # def test_user_provided_example(self): + + # from sklearn.impute import SimpleImputer + # from sklearn.pipeline import Pipeline + + # X_train = np.array([[22.0, 1.0, 0.0, 1251.0, 123.0, 124.0, 123.0, 0, 0, 0, 0, 0, 0, 0, 0] * 10 + # + [10.0, 1.0, 0.0, 1251.0, 123.0, 124.0, 134.0, 0, 0, 0, 0, 0, 0, 0, 0] * 2 + # + [14.0, 1.0, 0.0, 1251.0, 123.0, 124.0, 134.0, 0, 0, 0, 0, 0, 0, 0, 0] * 3 + # + [12.0, 2.0, 0.0, 1251.0, 123.0, 124.0, 134.0, 0, 0, 0, 0, 0, 0, 0, 0] * 1]) + # pipe = Pipeline( + # [ + # ("imputer", SimpleImputer(strategy="most_frequent")), + # ("encoder", OneHotEncoder(sparse_output=False, handle_unknown="infrequent_if_exist", min_frequency=9)), + # ], + # verbose=True, + # ).fit(X_train) + + # hb = hummingbird.ml.convert(pipe, "pytorch", device="cpu") + + # np.testing.assert_allclose(pipe.transform(X_train), hb.transform(X_train), rtol=1e-06, atol=1e-06) if __name__ == "__main__": From c888458abfe558a37ab381ce6dd687fe252a763e Mon Sep 17 00:00:00 2001 From: Karla Saur Date: Mon, 10 Apr 2023 21:36:22 +0000 Subject: [PATCH 09/12] minor cleanup --- .../_one_hot_encoder_implementations.py | 14 ++++++++------ .../test_sklearn_one_hot_encoder_converter.py | 19 ++++++++++++++++--- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py index 68f831923..cf6ee5a59 100644 --- a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py +++ b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py @@ -132,12 +132,14 @@ def __init__(self, logical_operator, categories, device, handle_unknown='error', infrequent_tensors.append(torch.nn.Parameter(torch.LongTensor(arr).detach().clone(), requires_grad=False)) self.infrequent_tensors = torch.nn.ParameterList(infrequent_tensors) - # We need to create a mask to filter out infrequent categories. 
- self.mask = torch.nn.ParameterList([]) - for i in range(len(self.condition_tensors[0])): - if self.condition_tensors[0][i] not in self.infrequent_tensors[0]: - self.mask.append(torch.nn.Parameter(self.condition_tensors[0][i], requires_grad=False)) - + # Filter out infrequent categories by creating a mask + self.mask = [] + for i in range(len(self.condition_tensors)): + row_mask = [] + for j in range(len(self.infrequent_tensors[0])): + if self.condition_tensors[i][j] not in self.infrequent_tensors[i]: + row_mask.append(self.condition_tensors[i][j]) + self.mask.append(torch.nn.Parameter(torch.tensor(row_mask), requires_grad=False)) else: self.infrequent_tensors = None diff --git a/tests/test_sklearn_one_hot_encoder_converter.py b/tests/test_sklearn_one_hot_encoder_converter.py index d2340e875..39613294c 100644 --- a/tests/test_sklearn_one_hot_encoder_converter.py +++ b/tests/test_sklearn_one_hot_encoder_converter.py @@ -95,7 +95,6 @@ def test_model_one_hot_encoder_ts_string_not_mod4_len(self): @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") def test_infrequent_if_exists_str(self): - pass # This test is a copy of the test in sklearn. # https://github.com/scikit-learn/scikit-learn/blob/ @@ -124,8 +123,8 @@ def test_infrequent_if_exists_str(self): np.testing.assert_allclose(orig, hb, rtol=1e-06, atol=1e-06) np.testing.assert_allclose(orig.shape, hb.shape, rtol=1e-06, atol=1e-06) - # @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") - # def test_infrequent_if_exists_int(self): + @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") + def test_infrequent_if_exists_int(self): X_train = np.array([[1] * 5 + [2] * 2000 + [3] * 10 + [4] * 3]).T model = OneHotEncoder( @@ -149,6 +148,20 @@ def test_infrequent_if_exists_str(self): np.testing.assert_allclose(orig, hb, rtol=1e-06, atol=1e-06) np.testing.assert_allclose(orig.shape, hb.shape, rtol=1e-06, atol=1e-06) + # @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") + # def test_2d_infrequent(self): + + # X_train = np.array([[10.0, 1.0]] * 10 + # + [[22.0, 1.0]] * 2 + # + [[10.0, 1.0]] * 3 + # + [[14.0, 2.0]] * 1) + # ohe = OneHotEncoder(sparse_output=False, handle_unknown="infrequent_if_exist", min_frequency=0.49).fit(X_train) + # #ohe = OneHotEncoder(sparse_output=False).fit(X_train) + + # hb = hummingbird.ml.convert(ohe, "pytorch", device="cpu") + + # np.testing.assert_allclose(ohe.transform(X_train), hb.transform(X_train), rtol=1e-06, atol=1e-06) + # @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") # def test_user_provided_example(self): From b43ec45cb2a0b5317673ba8705897ecb253f881b Mon Sep 17 00:00:00 2001 From: Karla Saur Date: Mon, 10 Apr 2023 23:14:05 +0000 Subject: [PATCH 10/12] fixed unkown category column --- .../_one_hot_encoder_implementations.py | 10 ++++++---- tests/test_sklearn_one_hot_encoder_converter.py | 16 ++++++---------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py index cf6ee5a59..64efe073f 100644 --- a/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py +++ b/hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py @@ 
-169,10 +169,12 @@ def forward(self, *x): x = x.long() for i in range(self.num_columns): - encoded_tensors.append(torch.eq(x[:, i : i + 1], compare_tensors[i])) + curr_column = torch.eq(x[:, i : i + 1], compare_tensors[i]) + encoded_tensors.append(curr_column) - # if self.infrequent_tensors is not None, then append another tensor that is the "not" of the sum of the encoded tensors. - if self.infrequent_tensors is not None: - encoded_tensors.append(torch.logical_not(torch.sum(torch.stack(encoded_tensors), dim=0))) + # If self.infrequent_tensors is not None, then append another tensor that is + # the logical "not" of the sum of the encoded tensors of the *current* iteration only + if self.infrequent_tensors is not None: + encoded_tensors.append(torch.logical_not(torch.sum(torch.stack([curr_column]), dim=0))) return torch.cat(encoded_tensors, dim=1).float() diff --git a/tests/test_sklearn_one_hot_encoder_converter.py b/tests/test_sklearn_one_hot_encoder_converter.py index 39613294c..27734bf43 100644 --- a/tests/test_sklearn_one_hot_encoder_converter.py +++ b/tests/test_sklearn_one_hot_encoder_converter.py @@ -148,19 +148,15 @@ def test_infrequent_if_exists_int(self): np.testing.assert_allclose(orig, hb, rtol=1e-06, atol=1e-06) np.testing.assert_allclose(orig.shape, hb.shape, rtol=1e-06, atol=1e-06) - # @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") - # def test_2d_infrequent(self): + @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") + def test_2d_infrequent(self): - # X_train = np.array([[10.0, 1.0]] * 10 - # + [[22.0, 1.0]] * 2 - # + [[10.0, 1.0]] * 3 - # + [[14.0, 2.0]] * 1) - # ohe = OneHotEncoder(sparse_output=False, handle_unknown="infrequent_if_exist", min_frequency=0.49).fit(X_train) - # #ohe = OneHotEncoder(sparse_output=False).fit(X_train) + X_train = np.array([[10.0, 1.0]] * 3 + [[14.0, 2.0]] * 2) + ohe = OneHotEncoder(sparse_output=False, handle_unknown="infrequent_if_exist", min_frequency=0.49).fit(X_train) - # hb = hummingbird.ml.convert(ohe, "pytorch", device="cpu") + hb = hummingbird.ml.convert(ohe, "pytorch", device="cpu") - # np.testing.assert_allclose(ohe.transform(X_train), hb.transform(X_train), rtol=1e-06, atol=1e-06) + np.testing.assert_allclose(ohe.transform(X_train), hb.transform(X_train), rtol=1e-06, atol=1e-06) # @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") # def test_user_provided_example(self): From 8873aceabd1aadf20922bfff867414a5cc2ff6d3 Mon Sep 17 00:00:00 2001 From: Karla Saur Date: Mon, 10 Apr 2023 23:17:01 +0000 Subject: [PATCH 11/12] one more check --- tests/test_sklearn_one_hot_encoder_converter.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_sklearn_one_hot_encoder_converter.py b/tests/test_sklearn_one_hot_encoder_converter.py index 27734bf43..9746173f8 100644 --- a/tests/test_sklearn_one_hot_encoder_converter.py +++ b/tests/test_sklearn_one_hot_encoder_converter.py @@ -156,8 +156,14 @@ def test_2d_infrequent(self): hb = hummingbird.ml.convert(ohe, "pytorch", device="cpu") + # Quick check on a dataset where all values have been seen during training np.testing.assert_allclose(ohe.transform(X_train), hb.transform(X_train), rtol=1e-06, atol=1e-06) + # Now check data not seen during training + X_test = np.array([[10.0, 1.0]] * 3 + [[14.0, 3.0]] * 2) + np.testing.assert_allclose(ohe.transform(X_test), hb.transform(X_test), rtol=1e-06, 
atol=1e-06) + + # @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") # def test_user_provided_example(self): From afff99ae3a33a3598ea3741336a32b66322a9440 Mon Sep 17 00:00:00 2001 From: Karla Saur Date: Mon, 10 Apr 2023 23:38:15 +0000 Subject: [PATCH 12/12] user test still failing --- tests/test_sklearn_one_hot_encoder_converter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_sklearn_one_hot_encoder_converter.py b/tests/test_sklearn_one_hot_encoder_converter.py index 9746173f8..91d6dae2e 100644 --- a/tests/test_sklearn_one_hot_encoder_converter.py +++ b/tests/test_sklearn_one_hot_encoder_converter.py @@ -163,9 +163,9 @@ def test_2d_infrequent(self): X_test = np.array([[10.0, 1.0]] * 3 + [[14.0, 3.0]] * 2) np.testing.assert_allclose(ohe.transform(X_test), hb.transform(X_test), rtol=1e-06, atol=1e-06) - - # @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") - # def test_user_provided_example(self): + @unittest.skipIf(parse(sklearn.__version__) < Version("1.1"), "Skipping test because sklearn version is too old.") + def test_user_provided_example(self): + pass # from sklearn.impute import SimpleImputer # from sklearn.pipeline import Pipeline
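
For reference, a minimal standalone sketch of the sklearn behavior the converter has to reproduce (assumes a recent scikit-learn; the sparse_output argument needs >= 1.2, and the values mirror test_infrequent_if_exists_int above):

    import numpy as np
    from sklearn.preprocessing import OneHotEncoder

    # 1, 3, and 4 each appear fewer than min_frequency=15 times, so they share one "infrequent" bucket.
    X_train = np.array([[1] * 5 + [2] * 2000 + [3] * 10 + [4] * 3]).T
    enc = OneHotEncoder(
        categories=[[1, 2, 3, 4]],
        handle_unknown="infrequent_if_exist",
        sparse_output=False,
        min_frequency=15,
    ).fit(X_train)

    print(enc.infrequent_categories_)  # [array([1, 3, 4])]
    # Output columns are [frequent..., infrequent]; the unseen value 5 is also routed to the
    # infrequent bucket because handle_unknown="infrequent_if_exist".
    print(enc.transform([[2], [1], [5]]))  # rows: [1. 0.], [0. 1.], [0. 1.]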