From 4678775fb6412234237e6efeed7108b7dd19a6db Mon Sep 17 00:00:00 2001 From: BenjaminBossan Date: Sun, 13 Nov 2016 18:49:56 +0100 Subject: [PATCH 1/3] First attempt, basics work, base.py not touched. --- nolearn/lasagne/tests/test_util.py | 179 +++++++++++++++++++++++++++++ nolearn/lasagne/util.py | 53 +++++++++ 2 files changed, 232 insertions(+) create mode 100644 nolearn/lasagne/tests/test_util.py diff --git a/nolearn/lasagne/tests/test_util.py b/nolearn/lasagne/tests/test_util.py new file mode 100644 index 0000000..fec1403 --- /dev/null +++ b/nolearn/lasagne/tests/test_util.py @@ -0,0 +1,179 @@ +import numpy as np +import pytest + + +class TestSliceDict: + def assert_dicts_equal(self, d0, d1): + assert d0.keys() == d1.keys() + for key in d0.keys(): + assert np.allclose(d0[key], d1[key]) + + @pytest.fixture(scope='session') + def SliceDict(self): + from nolearn.lasagne.util import SliceDict + return SliceDict + + @pytest.fixture + def sldict(self, SliceDict): + return SliceDict( + f0=np.arange(4), + f1=np.arange(12).reshape(4, 3), + ) + + def test_init_inconsistent_shapes(self, SliceDict): + with pytest.raises(ValueError) as exc: + SliceDict(f0=np.ones((10, 5)), f1=np.ones((11, 5))) + assert str(exc.value) == ( + "Initialized with items of different shapes: 10, 11") + + @pytest.mark.parametrize('item', [ + np.ones(4), + np.ones((4, 1)), + np.ones((4, 4)), + np.ones((4, 10, 7)), + np.ones((4, 1, 28, 28)), + ]) + def test_set_item_correct_shape(self, sldict, item): + # does not raise + sldict['f2'] = item + + @pytest.mark.parametrize('item', [ + np.ones(3), + np.ones((1, 100)), + np.ones((5, 1000)), + np.ones((1, 100, 10)), + np.ones((28, 28, 1, 100)), + ]) + def test_set_item_incorrect_shape_raises(self, sldict, item): + with pytest.raises(ValueError) as exc: + sldict['f2'] = item + assert str(exc.value) == ( + "Cannot set array with shape[0] != 4") + + @pytest.mark.parametrize('key', [1, 1.2, (1, 2), [3]]) + def test_set_item_incorrect_key_type(self, sldict, key): + with pytest.raises(TypeError) as exc: + sldict[key] = np.ones((100, 5)) + assert str(exc.value).startswith("Key must be str, not 0.85 diff --git a/nolearn/lasagne/util.py b/nolearn/lasagne/util.py index 6c7199b..94744d8 100644 --- a/nolearn/lasagne/util.py +++ b/nolearn/lasagne/util.py @@ -7,6 +7,9 @@ import numpy as np from tabulate import tabulate +from nolearn._compat import basestring + + convlayers = [Conv2DLayer] maxpoollayers = [MaxPool2DLayer] try: @@ -180,3 +183,53 @@ def get_conv_infos(net, min_capacity=100. / 6, detailed=False): receptive_fields.astype(int))) return tabulate(table, header, floatfmt='.2f') + + +class SliceDict(dict, object): + def __init__(self, **kwargs): + shapes = [value.shape[0] for value in kwargs.values()] + shapes_set = set(shapes) + if shapes_set and (len(shapes_set) != 1): + raise ValueError( + "Initialized with items of different shapes: {}" + "".format(', '.join(map(str, sorted(shapes_set))))) + + if not shapes: + self._len = 0 + else: + self._len = shapes[0] + + super(SliceDict, self).__init__(**kwargs) + + def __len__(self): + return self._len + + def __getitem__(self, sl): + # if isinstance(sl, int): + # raise TypeError("Not sure what to do here?!") + if isinstance(sl, basestring): + return super(SliceDict, self).__getitem__(sl) + return SliceDict(**{k: v[sl] for k, v in self.items()}) + + def __setitem__(self, key, value): + if not isinstance(key, basestring): + raise TypeError("Key must be str, not {}.".format(type(key))) + + length = value.shape[0] + if len(self.keys()) == 0: + self._len = length + + if self._len != length: + raise ValueError( + "Cannot set array with shape[0] != {}" + "".format(self._len)) + + super(SliceDict, self).__setitem__(key, value) + + def update(self, kwargs): + for key, value in kwargs.items(): + self.__setitem__(key, value) + + def __repr__(self): + out = super(SliceDict, self).__repr__() + return "SliceDict(**{})".format(out) From c75f0cdcb88752f4a1bd5db27054f9d3429e7555 Mon Sep 17 00:00:00 2001 From: BenjaminBossan Date: Mon, 14 Nov 2016 20:27:43 +0100 Subject: [PATCH 2/3] Make tests compatible with python 3.4, add more tests. --- nolearn/lasagne/tests/test_util.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/nolearn/lasagne/tests/test_util.py b/nolearn/lasagne/tests/test_util.py index fec1403..c2d7f63 100644 --- a/nolearn/lasagne/tests/test_util.py +++ b/nolearn/lasagne/tests/test_util.py @@ -54,7 +54,7 @@ def test_set_item_incorrect_shape_raises(self, sldict, item): def test_set_item_incorrect_key_type(self, sldict, key): with pytest.raises(TypeError) as exc: sldict[key] = np.ones((100, 5)) - assert str(exc.value).startswith("Key must be str, not Date: Thu, 17 Nov 2016 22:31:53 +0100 Subject: [PATCH 3/3] SliceDict raises error when used with int, no longer inherit from object. --- nolearn/lasagne/tests/test_util.py | 8 +++++--- nolearn/lasagne/util.py | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/nolearn/lasagne/tests/test_util.py b/nolearn/lasagne/tests/test_util.py index c2d7f63..2fc5129 100644 --- a/nolearn/lasagne/tests/test_util.py +++ b/nolearn/lasagne/tests/test_util.py @@ -119,6 +119,11 @@ def test_slice_mask(self, sldict, SliceDict): f1=np.array([[0, 1, 2], [6, 7, 8]])) self.assert_dicts_equal(result, expected) + def test_slice_int(self, sldict): + with pytest.raises(ValueError) as exc: + sldict[0] + assert str(exc.value) == 'SliceDict cannot be indexed by integers.' + def test_len_sliced(self, sldict): assert len(sldict) == 4 for i in range(1, 4): @@ -137,9 +142,6 @@ def test_iter(self, sldict): expected_keys.remove(key) assert not expected_keys - def test_slice_int(self, sldict): - pass - @pytest.fixture(scope='session') def net(self, NeuralNet): from lasagne.layers import ConcatLayer, DenseLayer, InputLayer diff --git a/nolearn/lasagne/util.py b/nolearn/lasagne/util.py index 94744d8..8a1991e 100644 --- a/nolearn/lasagne/util.py +++ b/nolearn/lasagne/util.py @@ -185,7 +185,7 @@ def get_conv_infos(net, min_capacity=100. / 6, detailed=False): return tabulate(table, header, floatfmt='.2f') -class SliceDict(dict, object): +class SliceDict(dict): def __init__(self, **kwargs): shapes = [value.shape[0] for value in kwargs.values()] shapes_set = set(shapes) @@ -205,8 +205,8 @@ def __len__(self): return self._len def __getitem__(self, sl): - # if isinstance(sl, int): - # raise TypeError("Not sure what to do here?!") + if isinstance(sl, int): + raise ValueError("SliceDict cannot be indexed by integers.") if isinstance(sl, basestring): return super(SliceDict, self).__getitem__(sl) return SliceDict(**{k: v[sl] for k, v in self.items()})