diff --git a/lshash/lshash.py b/lshash/lshash.py index 5c895a6..1cec6bb 100644 --- a/lshash/lshash.py +++ b/lshash/lshash.py @@ -42,10 +42,14 @@ class LSHash(object): stored if the file does not exist yet. :param overwrite: (optional) Whether to overwrite the matrices file if it already exist + :param seed + (optional) Seed for PRNG. To get consistant results over multiple + instantiation you must use the same seed. """ def __init__(self, hash_size, input_dim, num_hashtables=1, - storage_config=None, matrices_filename=None, overwrite=False): + storage_config=None, matrices_filename=None, + overwrite=False, seed=None): self.hash_size = hash_size self.input_dim = input_dim @@ -60,6 +64,9 @@ def __init__(self, hash_size, input_dim, num_hashtables=1, self.matrices_filename = matrices_filename self.overwrite = overwrite + self.seed = seed + np.random.seed(self.seed) + self._init_uniform_planes() self._init_hashtables() diff --git a/tests/test_seed.py b/tests/test_seed.py new file mode 100644 index 0000000..df5b185 --- /dev/null +++ b/tests/test_seed.py @@ -0,0 +1,15 @@ +import lshash +import numpy as np + +def test_fixed_seed(): + """ fixed seeds should generate the same uniform_planes """ + fixed_seed_lsh = [lshash.LSHash(10, 100, seed=20) for i in range(10)] + uniform_plane_sum = [np.sum(ls.uniform_planes) for ls in fixed_seed_lsh] + assert len(set(uniform_plane_sum)) == 1 + +def test_nonfixed_seed(): + """ when seed is not specified uniform planes should be different """ + nonfixed_seed_lsh = [lshash.LSHash(10, 100) for i in range(10)] + uniform_plane_sum = [np.sum(ls.uniform_planes) for ls in nonfixed_seed_lsh] + assert len(set(uniform_plane_sum)) > 1 +