diff --git a/src/keycloak/openid_connect.py b/src/keycloak/openid_connect.py index 420fa3d..9be1590 100644 --- a/src/keycloak/openid_connect.py +++ b/src/keycloak/openid_connect.py @@ -5,13 +5,12 @@ except ImportError: from urllib import urlencode # noqa: F041 -from jose import jwt +from jose import jwt, ExpiredSignatureError PATH_WELL_KNOWN = "auth/realms/{}/.well-known/openid-configuration" class KeycloakOpenidConnect(WellKnownMixin): - _well_known = None _client_id = None _client_secret = None @@ -138,10 +137,10 @@ def userinfo(self, token): url = self.well_known['userinfo_endpoint'] return self._realm.client.get(url, headers={ - "Authorization": "Bearer {}".format( - token - ) - }) + "Authorization": "Bearer {}".format( + token + ) + }) def uma_ticket(self, token, **kwargs): """ @@ -196,8 +195,9 @@ def authorization_code(self, code, redirect_uri): :rtype: dict :return: Access token response """ - return self._token_request(grant_type='authorization_code', code=code, + token = self._token_request(grant_type='authorization_code', code=code, redirect_uri=redirect_uri) + return Token(token, self) def password_credentials(self, username, password, **kwargs): """ @@ -210,9 +210,10 @@ def password_credentials(self, username, password, **kwargs): :rtype: dict :return: Access token response """ - return self._token_request(grant_type='password', + token = self._token_request(grant_type='password', username=username, password=password, **kwargs) + return Token(token, self) def client_credentials(self, **kwargs): """ @@ -224,7 +225,8 @@ def client_credentials(self, **kwargs): :rtype: dict :return: Access token response """ - return self._token_request(grant_type='client_credentials', **kwargs) + token = self._token_request(grant_type='client_credentials', **kwargs) + return Token(token, self) def refresh_token(self, refresh_token, **kwargs): """ @@ -306,3 +308,26 @@ def _token_request(self, grant_type, **kwargs): return self._realm.client.post(self.get_url('token_endpoint'), data=payload) + + +class Token: + def __init__(self, token, oidc: KeycloakOpenidConnect) -> None: + self.oidc = oidc + self.key = self.oidc.certs()['keys'][0] + self.token = token + + def __getattr__(self, attr): + return self.token[attr] + + def __call__(self): + if self.is_expired(): + print("Token expired, trying a new one") + self.token = self.oidc.refresh_token(self.token['refresh_token']) + return self.token["access_token"] + + def is_expired(self): + try: + self.oidc.decode_token(self.token['access_token'], self.key) + return False + except ExpiredSignatureError: + return True diff --git a/tests/keycloak/test_openid_connect.py b/tests/keycloak/test_openid_connect.py index 5e0d134..8819ca7 100644 --- a/tests/keycloak/test_openid_connect.py +++ b/tests/keycloak/test_openid_connect.py @@ -86,7 +86,7 @@ def test_authorization_url(self): ) def test_authorization_code(self): - response = self.openid_client.authorization_code( + token = self.openid_client.authorization_code( code='some-code', redirect_uri='https://redirect-uri' ) @@ -100,7 +100,7 @@ def test_authorization_code(self): 'redirect_uri': 'https://redirect-uri' } ) - self.assertEqual(response, self.realm.client.post.return_value) + self.assertEqual(token.token, self.realm.client.post.return_value) def test_client_credentials(self): response = self.openid_client.client_credentials( @@ -115,7 +115,7 @@ def test_client_credentials(self): 'scope': 'scope another-scope' } ) - self.assertEqual(response, self.realm.client.post.return_value) + self.assertEqual(response.token, self.realm.client.post.return_value) def test_refresh_token(self): response = self.openid_client.refresh_token(