Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions digitalocean/baseapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,14 @@ def __deal_with_pagination(self, url, method, params, data):

def __init_ratelimit(self, headers):
# Add the account requests/hour limit
self.ratelimit_limit = headers.get('Ratelimit-Limit', None)
ratelimit_limit = headers.get('Ratelimit-Limit', None)
self.ratelimit_limit = int(ratelimit_limit) if ratelimit_limit is not None else None
# Add the account requests remaining
self.ratelimit_remaining = headers.get('Ratelimit-Remaining', None)
ratelimit_remaining = headers.get('Ratelimit-Remaining', None)
self.ratelimit_remaining = int(ratelimit_remaining) if ratelimit_remaining is not None else None
# Add the account requests limit reset time
self.ratelimit_reset = headers.get('Ratelimit-Reset', None)
ratelimit_reset = headers.get('Ratelimit-Reset', None)
self.ratelimit_reset = int(ratelimit_reset) if ratelimit_reset is not None else None

@property
def token(self):
Expand Down Expand Up @@ -238,6 +241,9 @@ def get_data(self, url, type=GET, params=None):
if req.status_code == 404:
raise NotFoundError()

# init request limits
self.__init_ratelimit(req.headers)

if len(req.content) == 0:
# Raise an error if the request failed and there is no response content
req.raise_for_status()
Expand All @@ -254,9 +260,6 @@ def get_data(self, url, type=GET, params=None):
msg = [data[m] for m in ("id", "message") if m in data][1]
raise DataReadError(msg)

# init request limits
self.__init_ratelimit(req.headers)

# If there are more elements available (total) than the elements per
# page, try to deal with pagination. Note: Breaking the logic on
# multiple pages,
Expand Down
36 changes: 36 additions & 0 deletions digitalocean/tests/test_baseapi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

from requests.structures import CaseInsensitiveDict

from digitalocean.baseapi import BaseAPI
try:
import mock
Expand Down Expand Up @@ -83,3 +85,37 @@ def test_get_data_error_response_no_body(self):
mock_5xx_response.return_value.status_code = random.randint(500, 599) # random 5xx status code

self.assertRaises(requests.HTTPError, self.manager.get_data, 'test')

def test_get_data_rate_limit_case_error(self):
with mock.patch.object(self.manager, '_BaseAPI__perform_request') as mock_429:
mock_429.return_value = requests.Response()
mock_429.return_value._content = b''
mock_429.return_value.status_code = 429
mock_429.return_value.headers = CaseInsensitiveDict(data={
'ratelimit-limit': "1200",
'ratelimit-remaining': "1193",
'rateLimit-reset': "1402425459"
})

self.assertRaises(requests.HTTPError, self.manager.get_data, 'test')

self.assertEqual(self.manager.ratelimit_limit, 1200)
self.assertEqual(self.manager.ratelimit_remaining, 1193)
self.assertEqual(self.manager.ratelimit_reset, 1402425459)

def test_get_data_rate_limit_case_ok(self):
with mock.patch.object(self.manager, '_BaseAPI__perform_request') as mock_200:
mock_200.return_value = requests.Response()
mock_200.return_value._content = b'{}'
mock_200.return_value.status_code = 200
mock_200.return_value.headers = CaseInsensitiveDict(data={
'ratelimit-limit': "1200",
'ratelimit-remaining': "1193",
'rateLimit-reset': "1402425459"
})

self.manager.get_data('test')

self.assertEqual(self.manager.ratelimit_limit, 1200)
self.assertEqual(self.manager.ratelimit_remaining, 1193)
self.assertEqual(self.manager.ratelimit_reset, 1402425459)