diff --git a/nolearn/lasagne/base.py b/nolearn/lasagne/base.py
index fd54fe1..f794b1c 100644
--- a/nolearn/lasagne/base.py
+++ b/nolearn/lasagne/base.py
@@ -2,6 +2,7 @@
 
 from .._compat import pickle
 from collections import OrderedDict
+from difflib import SequenceMatcher
 import functools
 import itertools
 import operator
@@ -395,28 +396,55 @@ def get_all_params(self):
         params = sum([l.get_params() for l in layers], [])
         return unique(params)
 
-    def load_weights_from(self, source):
-        self.initialize()
-
-        if isinstance(source, str):
-            source = np.load(source)
-
-        if isinstance(source, NeuralNet):
-            source = source.get_all_params()
-
-        source_weights = [
-            w.get_value() if hasattr(w, 'get_value') else w for w in source]
-
-        for w1, w2 in zip(source_weights, self.get_all_params()):
-            if w1.shape != w2.get_value().shape:
-                continue
-            w2.set_value(w1)
-
     def save_weights_to(self, fname):
         weights = [w.get_value() for w in self.get_all_params()]
         with open(fname, 'wb') as f:
             pickle.dump(weights, f, -1)
 
+    @staticmethod
+    def _param_alignment(shapes0, shapes1):
+        shapes0 = list(map(str, shapes0))
+        shapes1 = list(map(str, shapes1))
+        matcher = SequenceMatcher(a=shapes0, b=shapes1)
+        matches = []
+        for block in matcher.get_matching_blocks():
+            if block.size == 0:
+                continue
+            matches.append((list(range(block.a, block.a + block.size)),
+                            list(range(block.b, block.b + block.size))))
+        result = [pair for match in matches for pair in zip(*match)]
+        return result
+
+    def load_weights_from(self, src):
+        if not hasattr(self, '_initialized'):
+            raise AttributeError(
+                "Please initialize the net (using the '.initialize()' "
+                "method) before loading weights.")
+
+        if isinstance(src, str):
+            src = np.load(src)
+        if isinstance(src, NeuralNet):
+            src = src.get_all_params()
+
+        target = self.get_all_params()
+        src_params = [p.get_value() if hasattr(p, 'get_value') else p
+                      for p in src]
+        target_params = [p.get_value() for p in target]
+
+        src_shapes = [p.shape for p in src_params]
+        target_shapes = [p.shape for p in target_params]
+        matches = self._param_alignment(src_shapes, target_shapes)
+
+        for i, j in matches:
+            target[j].set_value(src_params[i])
+
+            if not self.verbose:
+                continue
+            param_shape = 'x'.join(map(str, src_params[i].shape))
+            param_name = target[j].name + ' ' if target[j].name else ''
+            print("* Loaded parameter {}(shape: {})".format(
+                param_name, param_shape))
+
     def __getstate__(self):
         state = dict(self.__dict__)
         for attr in (
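
The helper pairs parameters by shape rather than by position: _param_alignment
stringifies each shape and lets difflib.SequenceMatcher find the longest
matching runs between the two shape sequences, so any contiguous stretch of
identically-shaped parameters transfers even when the nets differ in depth.
A standalone sketch of the idea (align_by_shape is a hypothetical
re-implementation for illustration, fed the parameter shapes of the two nets
in the test below):

    from difflib import SequenceMatcher

    def align_by_shape(shapes0, shapes1):
        # Stringify so SequenceMatcher sees one hashable token per shape.
        a = list(map(str, shapes0))
        b = list(map(str, shapes1))
        pairs = []
        for block in SequenceMatcher(a=a, b=b).get_matching_blocks():
            # The terminating zero-size block contributes no pairs.
            pairs.extend(
                (block.a + k, block.b + k) for k in range(block.size))
        return pairs

    # net0: input-dense0-dense1-output; net1 has an extra dense2 layer.
    src = [(784, 100), (100,), (100, 200), (200,), (200, 10), (10,)]
    dst = [(784, 100), (100,), (100, 20), (20,), (20, 200), (200,),
           (200, 10), (10,)]

    print(align_by_shape(src, dst))
    # [(0, 0), (1, 1), (3, 5), (4, 6), (5, 7)]

dense0's W and b and the output layer's W and b all transfer; the mismatched
(100, 200) weight at source index 2 is skipped. Note the quirk that net0's
dense1 bias (index 3) lands on net1's dense2 bias (index 5), since both have
shape (200,).
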
diff --git a/nolearn/tests/test_lasagne.py b/nolearn/tests/test_lasagne.py
index ca1283b..f42660d 100644
--- a/nolearn/tests/test_lasagne.py
+++ b/nolearn/tests/test_lasagne.py
@@ -115,10 +115,61 @@ def on_epoch_finished(nn, train_history):
 
     # Use load_weights_from to initialize an untrained model:
     nn3 = clone(nn_def)
+    nn3.initialize()
     nn3.load_weights_from(nn2)
     assert np.array_equal(nn3.predict(X_test), y_pred)
 
 
+def test_lasagne_loading_params_matches():
+    # The loading mechanism should find parameters with matching
+    # shapes even when the layers are not perfectly aligned.
+    from nolearn.lasagne import NeuralNet
+
+    layers0 = [('input', InputLayer),
+               ('dense0', DenseLayer),
+               ('dense1', DenseLayer),
+               ('output', DenseLayer)]
+    net0 = NeuralNet(
+        layers=layers0,
+        input_shape=(None, 784),
+        dense0_num_units=100,
+        dense1_num_units=200,
+        output_nonlinearity=softmax, output_num_units=10,
+        update=nesterov_momentum,
+        update_learning_rate=0.01,
+        max_epochs=5,
+    )
+    net0.initialize()
+    net0.save_weights_to('tmp_params.np')
+
+    layers1 = [('input', InputLayer),
+               ('dense0', DenseLayer),
+               ('dense1', DenseLayer),
+               ('dense2', DenseLayer),
+               ('output', DenseLayer)]
+    net1 = NeuralNet(
+        layers=layers1,
+        input_shape=(None, 784),
+        dense0_num_units=100,
+        dense1_num_units=20,
+        dense2_num_units=200,
+        output_nonlinearity=softmax, output_num_units=10,
+        update=nesterov_momentum,
+        update_learning_rate=0.01,
+        max_epochs=5,
+    )
+    net1.initialize()
+
+    # The output weights have the same shape but should differ:
+    assert not (net0.layers_['output'].W.get_value() ==
+                net1.layers_['output'].W.get_value()).all()
+    # After loading, these weights should be equal, despite the
+    # additional dense layer:
+    net1.load_weights_from('tmp_params.np')
+    assert (net0.layers_['output'].W.get_value() ==
+            net1.layers_['output'].W.get_value()).all()
+
+
 def test_lasagne_functional_grid_search(mnist, monkeypatch):
     # Make sure that we can satisfy the grid search interface.
     from nolearn.lasagne import NeuralNet
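
A behavioral note for reviewers: load_weights_from previously called
self.initialize() itself; after this patch it raises an AttributeError on an
uninitialized net, which is why the pre-existing functional test gains an
explicit nn3.initialize() call. A minimal round-trip sketch under the new
contract (assuming net0 and net1 are constructed as in
test_lasagne_loading_params_matches above; the file name simply mirrors the
test's):

    net0.initialize()
    net0.save_weights_to('tmp_params.np')    # pickles the parameter arrays

    net1.initialize()                        # now mandatory before loading
    net1.load_weights_from('tmp_params.np')  # copies shape-aligned params

With verbose set on net1, each transfer is reported by the new print call,
along the lines of "* Loaded parameter dense0.W (shape: 784x100)" (the exact
label depends on how the shared variable was named).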