Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
# Python egg metadata, regenerated from source files by setuptools.
/*.egg-info
/*.egg

# PyCharm-Files
/.idea
12 changes: 12 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
services:
- redis-server
language: python
python:
- "2.6"
- "2.7"
- "3.3"
- "3.4"
# command to install dependencies
install: "pip install -r requirements.txt"
# command to run tests
script: nosetests
2 changes: 1 addition & 1 deletion lshash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
__license__ = 'MIT'
__version__ = '0.0.4dev'

from lshash import LSHash
from .lshash import LSHash
15 changes: 11 additions & 4 deletions lshash/lshash.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@
# This module is part of lshash and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import sys

if sys.version_info[0] >= 3:
basestring = str
else:
range = xrange

import os
import json
import numpy as np

from storage import storage
from .storage import storage

try:
from bitarray import bitarray
Expand Down Expand Up @@ -92,7 +99,7 @@ def _init_uniform_planes(self):
self.uniform_planes = [t[1] for t in npzfiles]
else:
self.uniform_planes = [self._generate_uniform_planes()
for _ in xrange(self.num_hashtables)]
for _ in range(self.num_hashtables)]
try:
np.savez_compressed(self.matrices_filename,
*self.uniform_planes)
Expand All @@ -101,14 +108,14 @@ def _init_uniform_planes(self):
raise
else:
self.uniform_planes = [self._generate_uniform_planes()
for _ in xrange(self.num_hashtables)]
for _ in range(self.num_hashtables)]

def _init_hashtables(self):
""" Initialize the hash tables such that each record will be in the
form of "[storage1, storage2, ...]" """

self.hash_tables = [storage(self.storage_config, i)
for i in xrange(self.num_hashtables)]
for i in range(self.num_hashtables)]

def _generate_uniform_planes(self):
""" Generate uniformly distributed hyperplanes and return it as a 2D
Expand Down
56 changes: 23 additions & 33 deletions lshash/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This module is part of lshash and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from __future__ import unicode_literals

import json

try:
Expand All @@ -21,29 +23,16 @@ def storage(storage_config, index):
if 'dict' in storage_config:
return InMemoryStorage(storage_config['dict'])
elif 'redis' in storage_config:
storage_config['redis']['db'] = index
return RedisStorage(storage_config['redis'])
return RedisStorage(storage_config['redis'], index)
else:
raise ValueError("Only in-memory dictionary and Redis are supported.")


class BaseStorage(object):
def __init__(self, config):
""" An abstract class used as an adapter for storages. """
raise NotImplementedError

def keys(self):
""" Returns a list of binary hashes that are used as dict keys. """
raise NotImplementedError

def set_val(self, key, val):
""" Set `val` at `key`, note that the `val` must be a string. """
raise NotImplementedError

def get_val(self, key):
""" Return `val` at `key`, note that the `val` must be a string. """
raise NotImplementedError

def append_val(self, key, val):
""" Append `val` to the list stored at `key`.

Expand All @@ -62,44 +51,45 @@ def get_list(self, key):


class InMemoryStorage(BaseStorage):
def __init__(self, config):
def __init__(self, h_index):
self.name = 'dict'
self.storage = dict()

def keys(self):
return self.storage.keys()

def set_val(self, key, val):
self.storage[key] = val

def get_val(self, key):
return self.storage[key]

def append_val(self, key, val):
self.storage.setdefault(key, []).append(val)
self.storage.setdefault(key, set()).update([val])

def get_list(self, key):
return self.storage.get(key, [])
return list(self.storage.get(key, []))


class RedisStorage(BaseStorage):
def __init__(self, config):
def __init__(self, config, h_index):
if not redis:
raise ImportError("redis-py is required to use Redis as storage.")
self.name = 'redis'
self.storage = redis.StrictRedis(**config)
# a single db handles multiple hash tables, each one has prefix ``h[h_index].``
self.h_index = 'h%.2i.' % int(h_index)

def keys(self, pattern="*"):
return self.storage.keys(pattern)

def set_val(self, key, val):
self.storage.set(key, val)
def _list(self, key):
return self.h_index + key

def get_val(self, key):
return self.storage.get(key)
def keys(self, pattern='*'):
# return the keys BUT be agnostic with reference to the hash table
return [k.decode('ascii').split('.')[1] for k in self.storage.keys(self.h_index + pattern)]

def append_val(self, key, val):
self.storage.rpush(key, json.dumps(val))
self.storage.sadd(self._list(key), json.dumps(val))

def get_list(self, key):
return self.storage.lrange(key, 0, -1)
_list = list(self.storage.smembers(self._list(key))) # list elements are plain strings here
_list = [json.loads(el.decode('ascii')) for el in _list] # transform strings into python tuples
for el in _list:
# if len(el) is 2, then el[1] is the extra value associated to the element
if len(el) == 2 and type(el[0]) == list:
el[0] = tuple(el[0])
_list = [tuple(el) for el in _list]
return _list
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
numpy>=1.9.1
redis==2.10.3
Empty file added tests/__init__.py
Empty file.
127 changes: 127 additions & 0 deletions tests/test_lsh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import random
import string
from unittest import TestCase
from redis import StrictRedis
from pprint import pprint
import sys
import os

# add the LSHash package to the current python path
sys.path.insert(0, os.path.abspath('../'))
# now we can use our lshash package and not the standard one
from lshash import LSHash


class TestLSHash(TestCase):
num_elements = 100

def setUp(self):
self.els = []
self.el_names = []
for i in range(self.num_elements):
el = [random.randint(0, 100) for _ in range(8)]
elname = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
self.els.append(tuple(el))
self.el_names.append(elname)

def test_lshash(self):
lsh = LSHash(6, 8, 1)
for i in range(self.num_elements):
lsh.index(list(self.els[i]))
lsh.index(list(self.els[i])) # multiple insertions
hasht = lsh.hash_tables[0]
itms = [hasht.get_list(k) for k in hasht.keys()]
for itm in itms:
assert itms.count(itm) == 1
for el in itm:
assert el in self.els
for el in self.els:
res = lsh.query(list(el), num_results=1, distance_func='euclidean')[0]
# res is a tuple containing the vector and the distance
el_v, el_dist = res
assert el_v in self.els
assert el_dist == 0
del lsh

def test_lshash_extra_val(self):
lsh = LSHash(6, 8, 1)
for i in range(self.num_elements):
lsh.index(list(self.els[i]), self.el_names[i])
hasht = lsh.hash_tables[0]
itms = [hasht.get_list(k) for k in hasht.keys()]
for itm in itms:
for el in itm:
assert el[0] in self.els
assert el[1] in self.el_names
for el in self.els:
# res is a list, so we need to select the first entry only
res = lsh.query(list(el), num_results=1, distance_func='euclidean')[0]
# vector an name are in the first element of the tuple res[0]
el_v, el_name = res[0]
# the distance is in the second element of the tuple
el_dist = res[1]
assert el_v in self.els
assert el_name in self.el_names
assert el_dist == 0
del lsh

def test_lshash_redis(self):
"""
Test external lshash module
"""
config = {"redis": {"host": 'localhost', "port": 6379, "db": 15}}
sr = StrictRedis(**config['redis'])
sr.flushdb()

lsh = LSHash(6, 8, 1, config)
for i in range(self.num_elements):
lsh.index(list(self.els[i]))
lsh.index(list(self.els[i])) # multiple insertions should be prevented by the library

hasht = lsh.hash_tables[0]
itms = [hasht.get_list(k) for k in hasht.keys()]

for itm in itms:
for el in itm:
assert itms.count(itm) == 1 # have multiple insertions been prevented?
assert el in self.els

for el in self.els:
res = lsh.query(list(el), num_results=1, distance_func='euclidean')[0]
el_v, el_dist = res
assert el_v in self.els
assert el_dist == 0
del lsh
sr.flushdb()

def test_lshash_redis_extra_val(self):
"""
Test external lshash module
"""
config = {"redis": {"host": 'localhost', "port": 6379, "db": 15}}
sr = StrictRedis(**config['redis'])
sr.flushdb()

lsh = LSHash(6, 8, 1, config)
for i in range(self.num_elements):
lsh.index(list(self.els[i]), self.el_names[i])
lsh.index(list(self.els[i]), self.el_names[i]) # multiple insertions
hasht = lsh.hash_tables[0]
itms = [hasht.get_list(k) for k in hasht.keys()]
for itm in itms:
assert itms.count(itm) == 1
for el in itm:
assert el[0] in self.els
assert el[1] in self.el_names
for el in self.els:
res = lsh.query(list(el), num_results=1, distance_func='euclidean')[0]
# vector an name are in the first element of the tuple res[0]
el_v, el_name = res[0]
# the distance is in the second element of the tuple
el_dist = res[1]
assert el_v in self.els
assert el_name in self.el_names
assert el_dist == 0
del lsh
sr.flushdb()