Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions txrestapi/json_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
import six

from six import PY2
from six import PY2, b

from functools import wraps

Expand Down Expand Up @@ -36,11 +36,11 @@ def _setHeaders(self, request):
Those headers will allow you to call API methods from web browsers, they require CORS:
https://en.wikipedia.org/wiki/Cross-origin_resource_sharing
"""
request.responseHeaders.addRawHeader(b'content-type', b'application/json')
request.responseHeaders.addRawHeader(b'Access-Control-Allow-Origin', b'*')
request.responseHeaders.addRawHeader(b'Access-Control-Allow-Methods', b'GET, POST, PUT, DELETE')
request.responseHeaders.addRawHeader(b'Access-Control-Allow-Headers', b'x-prototype-version,x-requested-with')
request.responseHeaders.addRawHeader(b'Access-Control-Max-Age', 2520)
request.responseHeaders.addRawHeader(b('content-type'), b('application/json'))
request.responseHeaders.addRawHeader(b('Access-Control-Allow-Origin'), b('*'))
request.responseHeaders.addRawHeader(b('Access-Control-Allow-Methods'), b('GET, POST, PUT, DELETE'))
request.responseHeaders.addRawHeader(b('Access-Control-Allow-Headers'), b('x-prototype-version,x-requested-with'))
request.responseHeaders.addRawHeader(b('Access-Control-Max-Age'), 2520)
return request

def render(self, request):
Expand Down Expand Up @@ -123,12 +123,11 @@ def __init__(self, *args, **kwargs):
self._registry = []

def _get_callback(self, request):
request_method = request.method
path_to_check = getattr(request, '_remaining_path', request.path)
if not isinstance(path_to_check, six.text_type):
path_to_check = path_to_check.decode()
if not isinstance(path_to_check, six.binary_type):
path_to_check = path_to_check.encode()
for m, r, cb in self._registry:
if m == request_method or m == b'ALL':
if m == request.method or m == b('ALL'):
result = r.search(path_to_check)
if result:
request._remaining_path = path_to_check[result.span()[1]:]
Expand Down
14 changes: 11 additions & 3 deletions txrestapi/resource.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from six import PY2, PY3, b, u
import six
from six import PY2, b, u
if PY2:
from itertools import ifilter as filter
from functools import wraps
Expand Down Expand Up @@ -48,6 +49,8 @@ def __init__(self, *args, **kwargs):
def _get_callback(self, request):
filterf = lambda t:t[0] in (request.method, b('ALL'))
path_to_check = getattr(request, '_remaining_path', request.path)
if not isinstance(path_to_check, six.binary_type):
path_to_check = path_to_check.encode()
for m, r, cb in filter(filterf, self._registry):
result = r.search(path_to_check)
if result:
Expand All @@ -56,10 +59,15 @@ def _get_callback(self, request):
return None, None

def register(self, method, regex, callback):
self._registry.append((method, re.compile(regex.decode()), callback))
if not isinstance(regex, six.text_type):
regex = regex.decode()
self._registry.append((method, re.compile(regex), callback))

def unregister(self, method=None, regex=None, callback=None):
if regex is not None: regex = re.compile(regex.decode())
if regex is not None:
if not isinstance(regex, six.text_type):
regex = regex.decode()
regex = re.compile(regex)
for m, r, cb in self._registry[:]:
if not method or (method and m==method):
if not regex or (regex and r==regex):
Expand Down
94 changes: 84 additions & 10 deletions txrestapi/tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import txrestapi
__package__="txrestapi"
import re
import json
import base64
import os.path
import doctest
from six import PY2, b, u
Expand All @@ -11,21 +13,33 @@
from twisted.web.client import getPage
from twisted.trial import unittest
from .resource import APIResource
from .json_resource import JsonAPIResource
from .methods import GET, PUT


class FakeChannel(object):
transport = None

def getPeer(self):
return None

def getHost(self):
return None


def getRequest(method, url):
req = Request(FakeChannel(), None)
req.method = method
req.path = url
return req


class APIResourceTest(unittest.TestCase):

test_class = APIResource

def test_returns_normal_resources(self):
r = APIResource()
r = self.test_class()
a = Resource()
r.putChild(b('a'), a)
req = Request(FakeChannel(), None)
Expand All @@ -34,13 +48,18 @@ def test_returns_normal_resources(self):

def test_registry(self):
compiled = re.compile(b('regex'))
r = APIResource()
r = self.test_class()
r.register(b('GET'), b('regex'), None)
self.assertEqual([x[0] for x in r._registry], [b('GET')])
self.assertEqual(r._registry[0], (b('GET'), compiled, None))
self.assertEqual(r._registry[0][0], b('GET'))
# This doesn't work:
# FailTest: <_sre.SRE_Pattern object at 0x1052f6990> != <_sre.SRE_Pattern object at 0x1052f6d78>
# self.assertEqual(r._registry[0][1], compiled)
self.assertEqual(r._registry[0][1].pattern, compiled.pattern)
self.assertEqual(r._registry[0][2], None)

def test_method_matching(self):
r = APIResource()
r = self.test_class()
r.register(b('GET'), b('regex'), 1)
r.register(b('PUT'), b('regex'), 2)
r.register(b('GET'), b('another'), 3)
Expand Down Expand Up @@ -68,15 +87,15 @@ def test_callback(self):
marker = object()
def cb(request):
return marker
r = APIResource()
r = self.test_class()
r.register(b('GET'), b('regex'), cb)
req = getRequest(b('GET'), b('regex'))
result = r.getChild(b('regex'), req)
self.assertEqual(result.render(req), marker)

def test_longerpath(self):
marker = object()
r = APIResource()
r = self.test_class()
def cb(request):
return marker
r.register(b('GET'), b('/regex/a/b/c'), cb)
Expand All @@ -85,7 +104,7 @@ def cb(request):
self.assertEqual(result.render(req), marker)

def test_args(self):
r = APIResource()
r = self.test_class()
def cb(request, **kwargs):
return kwargs
r.register(b('GET'), b('/(?P<a>[^/]*)/a/(?P<b>[^/]*)/c'), cb)
Expand All @@ -94,7 +113,7 @@ def cb(request, **kwargs):
self.assertEqual(sorted(result.render(req).keys()), ['a', 'b'])

def test_order(self):
r = APIResource()
r = self.test_class()
def cb1(request, **kwargs):
kwargs.update({'cb1':True})
return kwargs
Expand All @@ -109,14 +128,14 @@ def cb(request, **kwargs):
self.assert_('cb1' in result.render(req))

def test_no_resource(self):
r = APIResource()
r = self.test_class()
r.register(b('GET'), b('^/(?P<a>[^/]*)/a/(?P<b>[^/]*)$'), None)
req = getRequest(b('GET'), b('/definitely/not/a/match'))
result = r.getChild(b('regex'), req)
self.assert_(isinstance(result, NoResource))

def test_all(self):
r = APIResource()
r = self.test_class()
def get_cb(r): return b('GET')
def put_cb(r): return b('PUT')
def all_cb(r): return b('ALL')
Expand All @@ -129,6 +148,61 @@ def all_cb(r): return b('ALL')
result = r.getChild(b('path'), req)
self.assertEqual(result.render(req), b('ALL') if method==b('PUT') else method)


class JSONAPIResourceTest(APIResourceTest):

test_class = JsonAPIResource

def test_all(self):
r = self.test_class()
def get_cb(r): return {'method': b('GET'), }
def put_cb(r): return {'method': b('PUT'), }
def all_cb(r): return {'method': b('ALL'), }
r.register(b('GET'), b('^path'), get_cb)
r.register(b('ALL'), b('^path'), all_cb)
r.register(b('PUT'), b('^path'), put_cb)
# Test that the ALL registration picks it up before the PUT one
for method in (b('GET'), b('PUT'), b('ALL')):
req = getRequest(method, b('path'))
result = r.getChild(b('path'), req)
self.assertEqual(json.loads(result.render(req))['method'], b('ALL') if method==b('PUT') else method)

def test_args(self):
r = self.test_class()
def cb(request, **kwargs):
return {'kwargs': kwargs, }
r.register(b('GET'), b('/(?P<a>[^/]*)/a/(?P<b>[^/]*)/c'), cb)
req = getRequest(b('GET'), b('/regex/a/b/c'))
result = r.getChild(b('regex'), req)
self.assertEqual(sorted(json.loads(result.render(req))['kwargs'].keys()), ['a', 'b'])

def test_callback(self):
marker = base64.b64encode(os.urandom(20))
def cb(request):
return {'marker': marker, }
r = self.test_class()
r.register(b('GET'), b('regex'), cb)
req = getRequest(b('GET'), b('regex'))
result = r.getChild(b('regex'), req)
self.assertEqual(json.loads(result.render(req))['marker'], marker)

def test_longerpath(self):
marker = base64.b64encode(os.urandom(20))
r = self.test_class()
def cb(request):
return {'marker': marker, }
r.register(b('GET'), b('/regex/a/b/c'), cb)
req = getRequest(b('GET'), b('/regex/a/b/c'))
result = r.getChild(b('regex'), req)
self.assertEqual(json.loads(result.render(req))['marker'], marker)

def test_no_resource(self):
r = self.test_class()
r.register(b('GET'), b('^/(?P<a>[^/]*)/a/(?P<b>[^/]*)$'), None)
req = getRequest(b('GET'), b('/definitely/not/a/match'))
result = r.getChild(b('regex'), req)
self.assertEqual(json.loads(result.render(req))['errors'], ["path 'regex' not found", ])


class TestResource(Resource):
isLeaf = True
Expand Down