diff --git a/digitalocean/baseapi.py b/digitalocean/baseapi.py index 6647753..23d3c6d 100644 --- a/digitalocean/baseapi.py +++ b/digitalocean/baseapi.py @@ -178,11 +178,14 @@ def __deal_with_pagination(self, url, method, params, data): def __init_ratelimit(self, headers): # Add the account requests/hour limit - self.ratelimit_limit = headers.get('Ratelimit-Limit', None) + ratelimit_limit = headers.get('Ratelimit-Limit', None) + self.ratelimit_limit = int(ratelimit_limit) if ratelimit_limit is not None else None # Add the account requests remaining - self.ratelimit_remaining = headers.get('Ratelimit-Remaining', None) + ratelimit_remaining = headers.get('Ratelimit-Remaining', None) + self.ratelimit_remaining = int(ratelimit_remaining) if ratelimit_remaining is not None else None # Add the account requests limit reset time - self.ratelimit_reset = headers.get('Ratelimit-Reset', None) + ratelimit_reset = headers.get('Ratelimit-Reset', None) + self.ratelimit_reset = int(ratelimit_reset) if ratelimit_reset is not None else None @property def token(self): @@ -238,6 +241,9 @@ def get_data(self, url, type=GET, params=None): if req.status_code == 404: raise NotFoundError() + # init request limits + self.__init_ratelimit(req.headers) + if len(req.content) == 0: # Raise an error if the request failed and there is no response content req.raise_for_status() @@ -254,9 +260,6 @@ def get_data(self, url, type=GET, params=None): msg = [data[m] for m in ("id", "message") if m in data][1] raise DataReadError(msg) - # init request limits - self.__init_ratelimit(req.headers) - # If there are more elements available (total) than the elements per # page, try to deal with pagination. Note: Breaking the logic on # multiple pages, diff --git a/digitalocean/tests/test_baseapi.py b/digitalocean/tests/test_baseapi.py index 9d6b112..8dc3977 100644 --- a/digitalocean/tests/test_baseapi.py +++ b/digitalocean/tests/test_baseapi.py @@ -1,5 +1,7 @@ import os +from requests.structures import CaseInsensitiveDict + from digitalocean.baseapi import BaseAPI try: import mock @@ -83,3 +85,37 @@ def test_get_data_error_response_no_body(self): mock_5xx_response.return_value.status_code = random.randint(500, 599) # random 5xx status code self.assertRaises(requests.HTTPError, self.manager.get_data, 'test') + + def test_get_data_rate_limit_case_error(self): + with mock.patch.object(self.manager, '_BaseAPI__perform_request') as mock_429: + mock_429.return_value = requests.Response() + mock_429.return_value._content = b'' + mock_429.return_value.status_code = 429 + mock_429.return_value.headers = CaseInsensitiveDict(data={ + 'ratelimit-limit': "1200", + 'ratelimit-remaining': "1193", + 'rateLimit-reset': "1402425459" + }) + + self.assertRaises(requests.HTTPError, self.manager.get_data, 'test') + + self.assertEqual(self.manager.ratelimit_limit, 1200) + self.assertEqual(self.manager.ratelimit_remaining, 1193) + self.assertEqual(self.manager.ratelimit_reset, 1402425459) + + def test_get_data_rate_limit_case_ok(self): + with mock.patch.object(self.manager, '_BaseAPI__perform_request') as mock_200: + mock_200.return_value = requests.Response() + mock_200.return_value._content = b'{}' + mock_200.return_value.status_code = 200 + mock_200.return_value.headers = CaseInsensitiveDict(data={ + 'ratelimit-limit': "1200", + 'ratelimit-remaining': "1193", + 'rateLimit-reset': "1402425459" + }) + + self.manager.get_data('test') + + self.assertEqual(self.manager.ratelimit_limit, 1200) + self.assertEqual(self.manager.ratelimit_remaining, 1193) + self.assertEqual(self.manager.ratelimit_reset, 1402425459)