@@ -22,11 +22,13 @@ class OneHotEncoderString(PhysicalOperator, torch.nn.Module):
Because we are dealing with tensors, strings require additional length information for processing.
"""

def __init__(self, logical_operator, categories, device, extra_config={}):
def __init__(self, logical_operator, categories, device, extra_config={}, handle_unknown='error', infrequent=None):
super(OneHotEncoderString, self).__init__(logical_operator, transformer=True)

self.num_columns = len(categories)
self.max_word_length = max([max([len(c) for c in cat]) for cat in categories])
self.handle_unknown = handle_unknown
self.mask = None

# Strings are cast to int32, therefore we need to size the tensor so its length is divisible by 4.
while self.max_word_length % 4 != 0:
@@ -55,17 +57,55 @@ def __init__(self, logical_operator, categories, device, extra_config={}):
self.condition_tensors = torch.nn.Parameter(torch.IntTensor(condition_tensors), requires_grad=False)
self.categories_idx = categories_idx

if infrequent is not None:
infrequent_tensors = []
categories_idx = [0]
for arr in infrequent:
cats = (
np.array(arr, dtype="|S" + str(self.max_word_length))  # Encode objects as fixed-length byte strings.
.view("int32")
.reshape(-1, self.max_word_length // 4)
.tolist()
)
# We merge all categories for all columns into a single tensor
infrequent_tensors.extend(cats)
# Since all categories are merged together, we need to keep track of the indices to retrieve them at inference time.
categories_idx.append(categories_idx[-1] + len(cats))
self.infrequent_tensors = torch.nn.Parameter(torch.IntTensor(infrequent_tensors), requires_grad=False)

# We need a mask that keeps only the frequent (non-infrequent) categories.
mask = []
for i in range(len(self.condition_tensors)):
if not any(torch.equal(self.condition_tensors[i], row) for row in self.infrequent_tensors):
mask.append(self.condition_tensors[i])
self.mask = torch.nn.Parameter(torch.stack(mask), requires_grad=False)
else:
self.infrequent_tensors = None

def forward(self, x):
encoded_tensors = []

# TODO: implement 'error' case separately
if self.handle_unknown == "ignore" or self.handle_unknown == "error":
compare_tensors = self.condition_tensors
elif self.handle_unknown == "infrequent_if_exist":
compare_tensors = self.mask if self.mask is not None else self.condition_tensors
else:
raise RuntimeError("Unsupported handle_unknown setting: {0}".format(self.handle_unknown))

for i in range(self.num_columns):
# First we fetch the condition for the particular column.
conditions = self.condition_tensors[self.categories_idx[i] : self.categories_idx[i + 1], :].view(
conditions = compare_tensors[self.categories_idx[i] : self.categories_idx[i + 1], :].view(
1, -1, self.max_word_length // 4
)
# Unlike the numeric case, where eq is enough, here we need to aggregate per object (dim=2)
# because objects can span multiple ints. We use prod since all ints must match to yield an encoding of 1.
encoded_tensors.append(torch.prod(torch.eq(x[:, i : i + 1, :], conditions), dim=2))

# If infrequent categories exist, append one extra column that is set only when no frequent category matched.
if self.infrequent_tensors is not None:
encoded_tensors.append(torch.logical_not(torch.cat(encoded_tensors, dim=1).sum(dim=1, keepdim=True)))

return torch.cat(encoded_tensors, dim=1).float()
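
A minimal standalone sketch (illustrative only, not part of the diff; names and values are made up) of the encoding trick this class relies on: strings are padded to a multiple of 4 bytes, viewed as int32 chunks, and a category matches only when every chunk matches, which is why forward takes the prod over dim 2.

# Illustration only: fixed-width byte blobs viewed as int32 chunks.
import numpy as np
import torch

cats = np.array(["apple", "kiwi"], dtype="|S8").view("int32").reshape(-1, 2)  # 8 bytes -> 2 ints per word
x = np.array(["kiwi"], dtype="|S8").view("int32").reshape(-1, 1, 2)           # shape (batch, 1, 2)
match = torch.prod(torch.eq(torch.from_numpy(x), torch.from_numpy(cats)), dim=2)
print(match)  # one column per category; only the second ("kiwi") matches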


@@ -74,19 +114,45 @@ class OneHotEncoder(PhysicalOperator, torch.nn.Module):
Class implementing OneHotEncoder operators for ints in PyTorch.
"""

def __init__(self, logical_operator, categories, device):
def __init__(self, logical_operator, categories, device, handle_unknown='error', infrequent=None):
super(OneHotEncoder, self).__init__(logical_operator, transformer=True)

self.num_columns = len(categories)
self.handle_unknown = handle_unknown
self.mask = None

condition_tensors = []
for arr in categories:
condition_tensors.append(torch.nn.Parameter(torch.LongTensor(arr).detach().clone(), requires_grad=False))
self.condition_tensors = torch.nn.ParameterList(condition_tensors)

if infrequent is not None:
infrequent_tensors = []
for arr in infrequent:
infrequent_tensors.append(torch.nn.Parameter(torch.LongTensor(arr).detach().clone(), requires_grad=False))
self.infrequent_tensors = torch.nn.ParameterList(infrequent_tensors)

# Filter out infrequent categories by creating a per-column mask.
self.mask = []
for i in range(len(self.condition_tensors)):
row_mask = []
for j in range(len(self.condition_tensors[i])):
if self.condition_tensors[i][j] not in self.infrequent_tensors[i]:
row_mask.append(self.condition_tensors[i][j])
self.mask.append(torch.nn.Parameter(torch.stack(row_mask), requires_grad=False))
else:
self.infrequent_tensors = None

def forward(self, *x):
encoded_tensors = []

if self.handle_unknown == "ignore" or self.handle_unknown == "error": # TODO: error
compare_tensors = self.condition_tensors
elif self.handle_unknown == "infrequent_if_exist":
compare_tensors = self.mask if self.mask is not None else self.condition_tensors
else:
raise RuntimeError("Unsupported handle_unknown setting: {0}".format(self.handle_unknown))

if len(x) > 1:
assert len(x) == self.num_columns

Expand All @@ -95,14 +161,20 @@ def forward(self, *x):
if input.dtype != torch.int64:
input = input.long()

encoded_tensors.append(torch.eq(input, self.condition_tensors[i]))
encoded_tensors.append(torch.eq(input, compare_tensors[i]))
else:
# This is already a tensor.
x = x[0]
if x.dtype != torch.int64:
x = x.long()

for i in range(self.num_columns):
encoded_tensors.append(torch.eq(x[:, i : i + 1], self.condition_tensors[i]))
curr_column = torch.eq(x[:, i : i + 1], compare_tensors[i])
encoded_tensors.append(curr_column)

# If infrequent categories exist, append an extra column for the current input column that is
# set only when none of its frequent categories matched.
if self.infrequent_tensors is not None:
encoded_tensors.append(torch.logical_not(curr_column.sum(dim=1, keepdim=True)))

return torch.cat(encoded_tensors, dim=1).float()
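
To make the masked comparison concrete, a small sketch (illustrative values, not the converter's API) of what forward produces when categories [1, 2, 3, 4] have infrequent [1, 3, 4], so only 2 survives in the mask:

# Illustration only: one frequent category plus a trailing "infrequent" column.
import torch

mask = torch.tensor([2])                          # frequent categories left after filtering
x = torch.tensor([[2], [1], [5]])                 # 5 was never seen during training
cols = torch.eq(x, mask)                          # shape (batch, n_frequent)
infrequent = torch.logical_not(cols.sum(dim=1, keepdim=True))
print(torch.cat([cols, infrequent], dim=1).float())
# tensor([[1., 0.],
#         [0., 1.],
#         [0., 1.]])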
19 changes: 17 additions & 2 deletions hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py
@@ -29,16 +29,31 @@ def convert_sklearn_one_hot_encoder(operator, device, extra_config):
"""
assert operator is not None, "Cannot convert None operator"

# scikit-learn >= 1.1 with handle_unknown='infrequent_if_exist'.
if hasattr(operator.raw_operator, "infrequent_categories_"):
infrequent = operator.raw_operator.infrequent_categories_
else:
infrequent = None

# TODO: What to do about min_frequency and max_categories?
# Both parameters are only used during fit: they trigger the creation of the "infrequent"
# categories but are not consulted again afterwards, so we should be able to ignore them
# for inference in HB.
# See https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/preprocessing/_encoders.py#L178
# and the comment on line 503 of the same file.

if all(
[
np.array(c).dtype == object or np.array(c).dtype.kind in constants.SUPPORTED_STRING_TYPES
for c in operator.raw_operator.categories_
]
):
categories = [[str(x) for x in c.tolist()] for c in operator.raw_operator.categories_]
return OneHotEncoderString(operator, categories, device, extra_config)
return OneHotEncoderString(operator, categories, device, extra_config=extra_config,
handle_unknown=operator.raw_operator.handle_unknown, infrequent=infrequent)
else:
return OneHotEncoder(operator, operator.raw_operator.categories_, device)
return OneHotEncoder(operator, operator.raw_operator.categories_, device,
handle_unknown=operator.raw_operator.handle_unknown, infrequent=infrequent)


register_converter("SklearnOneHotEncoder", convert_sklearn_one_hot_encoder)
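
As a sanity check on the TODO above, a short sketch (requires scikit-learn >= 1.2 for sparse_output; values are made up) showing that min_frequency only shapes the fitted state the converter reads:

# Illustration only: min_frequency is consumed at fit time; the converter only needs
# categories_, infrequent_categories_, and handle_unknown.
import numpy as np
from sklearn.preprocessing import OneHotEncoder

X = np.array([["a"] * 5 + ["b"] * 20]).T
enc = OneHotEncoder(handle_unknown="infrequent_if_exist", min_frequency=10, sparse_output=False).fit(X)
print(enc.categories_)             # [array(['a', 'b'], dtype=object)]
print(enc.infrequent_categories_)  # [array(['a'], dtype=object)]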
97 changes: 96 additions & 1 deletion tests/test_sklearn_one_hot_encoder_converter.py
@@ -4,10 +4,12 @@
import unittest

import numpy as np
import torch
import sklearn
from sklearn.preprocessing import OneHotEncoder
import hummingbird.ml

from packaging.version import Version, parse


class TestSklearnOneHotEncoderConverter(unittest.TestCase):
def test_model_one_hot_encoder_int(self):
@@ -91,6 +93,99 @@ def test_model_one_hot_encoder_ts_string_not_mod4_len(self):

np.testing.assert_allclose(model.transform(data).todense(), pytorch_model.transform(data), rtol=1e-06, atol=1e-06)

@unittest.skipIf(parse(sklearn.__version__) < Version("1.2"), "Skipping test: requires sklearn >= 1.2 (sparse_output).")
def test_infrequent_if_exists_str(self):

# This test is a copy of the test in sklearn.
# https://github.com/scikit-learn/scikit-learn/blob/
# ecb9a70e82d4ee352e2958c555536a395b53d2bd/sklearn/preprocessing/tests/test_encoders.py#L868

X_train = np.array([["a"] * 5 + ["b"] * 2000 + ["c"] * 10 + ["d"] * 3]).T
model = OneHotEncoder(
categories=[["a", "b", "c", "d"]],
handle_unknown="infrequent_if_exist",
sparse_output=False,
min_frequency=15,
).fit(X_train)
np.testing.assert_array_equal(model.infrequent_categories_, [["a", "c", "d"]])

pytorch_model = hummingbird.ml.convert(model, "torch", device="cpu")
self.assertIsNotNone(pytorch_model)

X_test = [["b"], ["a"], ["c"], ["d"], ["e"]]
expected = np.array([[1, 0], [0, 1], [0, 1], [0, 1], [0, 1]])
orig = model.transform(X_test)
np.testing.assert_allclose(expected, orig)

hb = pytorch_model.transform(X_test)

np.testing.assert_allclose(orig, hb, rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(orig.shape, hb.shape, rtol=1e-06, atol=1e-06)

@unittest.skipIf(parse(sklearn.__version__) < Version("1.2"), "Skipping test: requires sklearn >= 1.2 (sparse_output).")
def test_infrequent_if_exists_int(self):

X_train = np.array([[1] * 5 + [2] * 2000 + [3] * 10 + [4] * 3]).T
model = OneHotEncoder(
categories=[[1, 2, 3, 4]],
handle_unknown="infrequent_if_exist",
sparse_output=False,
min_frequency=15,
).fit(X_train)
np.testing.assert_array_equal(model.infrequent_categories_, [[1, 3, 4]])

pytorch_model = hummingbird.ml.convert(model, "torch", device="cpu")
self.assertIsNotNone(pytorch_model)

X_test = [[2], [1], [3], [4], [5]]
expected = np.array([[1, 0], [0, 1], [0, 1], [0, 1], [0, 1]])
orig = model.transform(X_test)
np.testing.assert_allclose(expected, orig)

hb = pytorch_model.transform(X_test)

np.testing.assert_allclose(orig, hb, rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(orig.shape, hb.shape, rtol=1e-06, atol=1e-06)

@unittest.skipIf(parse(sklearn.__version__) < Version("1.2"), "Skipping test: requires sklearn >= 1.2 (sparse_output).")
def test_2d_infrequent(self):

X_train = np.array([[10.0, 1.0]] * 3 + [[14.0, 2.0]] * 2)
ohe = OneHotEncoder(sparse_output=False, handle_unknown="infrequent_if_exist", min_frequency=0.49).fit(X_train)

hb = hummingbird.ml.convert(ohe, "pytorch", device="cpu")

# Quick check on a dataset where all values have been seen during training
np.testing.assert_allclose(ohe.transform(X_train), hb.transform(X_train), rtol=1e-06, atol=1e-06)

# Now check data not seen during training
X_test = np.array([[10.0, 1.0]] * 3 + [[14.0, 3.0]] * 2)
np.testing.assert_allclose(ohe.transform(X_test), hb.transform(X_test), rtol=1e-06, atol=1e-06)

@unittest.skipIf(parse(sklearn.__version__) < Version("1.2"), "Skipping test: requires sklearn >= 1.2 (sparse_output).")
def test_user_provided_example(self):
pass

# from sklearn.impute import SimpleImputer
# from sklearn.pipeline import Pipeline

# X_train = np.array([[22.0, 1.0, 0.0, 1251.0, 123.0, 124.0, 123.0, 0, 0, 0, 0, 0, 0, 0, 0] * 10
# + [10.0, 1.0, 0.0, 1251.0, 123.0, 124.0, 134.0, 0, 0, 0, 0, 0, 0, 0, 0] * 2
# + [14.0, 1.0, 0.0, 1251.0, 123.0, 124.0, 134.0, 0, 0, 0, 0, 0, 0, 0, 0] * 3
# + [12.0, 2.0, 0.0, 1251.0, 123.0, 124.0, 134.0, 0, 0, 0, 0, 0, 0, 0, 0] * 1])
# pipe = Pipeline(
# [
# ("imputer", SimpleImputer(strategy="most_frequent")),
# ("encoder", OneHotEncoder(sparse_output=False, handle_unknown="infrequent_if_exist", min_frequency=9)),
# ],
# verbose=True,
# ).fit(X_train)

# hb = hummingbird.ml.convert(pipe, "pytorch", device="cpu")

# np.testing.assert_allclose(pipe.transform(X_train), hb.transform(X_train), rtol=1e-06, atol=1e-06)


if __name__ == "__main__":
unittest.main()