diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index ef2170d..03ed08c 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -2,7 +2,7 @@ FROM alpine:latest # Install common tools RUN apk add --no-cache bash git \ - python3 py3-pip + python3 py3-pip openssl # Setup default user ARG USERNAME=vscode diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 146e6b1..767abfb 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -33,6 +33,16 @@ jobs: python-version: ${{ matrix.python-version }} cache: "pip" + - name: Install OpenSSL (Linux) + if: runner.os == 'Linux' + run: sudo apt-get update && sudo apt-get install -y libssl-dev + + # MacOS does not require OpenSSL installation as it is pre-installed + + - name: Install OpenSSL (Windows) + if: runner.os == 'Windows' + run: choco install openssl.light --no-progress + - name: Install dependencies run: | pip install -r requirements.txt diff --git a/.gitignore b/.gitignore index d627f28..5a44d50 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,8 @@ logs/ *.log -# Ignore all pem files +# Ignore JWT keys +keys/ *.pem # Upload folder diff --git a/Dockerfile b/Dockerfile index 3df21ad..7efa99b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,23 +4,26 @@ FROM python:3.13-slim # Set a specific working directory in the container WORKDIR /app -# Install dependencies separately for better caching +# Install dependencies +RUN apt-get update && apt-get install -y --no-install-recommends curl openssl \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +# Install required Python packages COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt # Copy the rest of the application files COPY . . -# Expose the Flask port -EXPOSE 5000 - # Create a non-root user and set permissions for the /app directory RUN adduser --disabled-password --gecos '' apiuser && chown -R apiuser /app USER apiuser +# Expose the Flask port +EXPOSE 5000 + # Add health check for the container HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ CMD curl --fail http://localhost:5000/ || exit 1 -# Command to run the app -CMD ["python", "app.py"] +ENTRYPOINT ["python3", "app.py"] diff --git a/app.py b/app.py index 7a060a1..dc52048 100644 --- a/app.py +++ b/app.py @@ -1,15 +1,18 @@ import argparse +import os import traceback from flask import Flask, jsonify, request from flask_cors import CORS from werkzeug.exceptions import HTTPException +from config.jwtoken import ACTIVE_KID_FILE from config.logging import setup_logging from config.ratelimit import limiter from config.settings import Config from routes import register_routes from utility.database import extract_error_message +from utility.jwtoken.keys_rotation import rotate_keys app = Flask(__name__) app.config.from_object(Config) @@ -26,6 +29,10 @@ if app.config["TESTING"]: limiter.enabled = False +# Ensure the keys directory and active_kid.txt file exist +if not os.path.exists(ACTIVE_KID_FILE): + rotate_keys() + @app.route("/") def home(): diff --git a/config/jwtoken.py b/config/jwtoken.py new file mode 100644 index 0000000..8c22967 --- /dev/null +++ b/config/jwtoken.py @@ -0,0 +1,8 @@ +from pathlib import Path + +KEYS_DIR = Path("keys") + +ACTIVE_KID_FILE = KEYS_DIR / "active_kid.txt" +CREATED_AT_FILE = "created_at.txt" +PRIVATE_KEY_FILE = "private.pem" +PUBLIC_KEY_FILE = "public.pem" diff --git a/config/logging.py b/config/logging.py index d39258c..455780e 100644 --- a/config/logging.py +++ b/config/logging.py @@ -12,12 +12,12 @@ def setup_logging(): logger.setLevel(logging.DEBUG) # Create a file handler that rotates logs daily - timestamp = datetime.datetime.now().strftime("%Y-%m-%d") + current_time = datetime.datetime.now().strftime("%Y-%m-%d") # Create a directory for logs if it doesn't exist if not os.path.exists(LOGS_DIRECTORY): os.makedirs(LOGS_DIRECTORY) - log_filename = f"{LOGS_DIRECTORY}/{timestamp}.log" + log_filename = f"{LOGS_DIRECTORY}/{current_time}.log" # Set up timed rotating file handler file_handler = TimedRotatingFileHandler( diff --git a/config/settings.py b/config/settings.py index e12fe8f..205ab06 100644 --- a/config/settings.py +++ b/config/settings.py @@ -20,7 +20,6 @@ class Config: MYSQL_CURSORCLASS = "DictCursor" # JWT configuration - JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY") JWT_ACCESS_TOKEN_EXPIRY = timedelta(hours=1) JWT_REFRESH_TOKEN_EXPIRY = timedelta(days=30) diff --git a/jwt_helper.py b/jwt_helper.py deleted file mode 100644 index f9f4d9e..0000000 --- a/jwt_helper.py +++ /dev/null @@ -1,78 +0,0 @@ -import os -from datetime import datetime, timedelta, timezone -from functools import wraps - -import jwt -from flask import jsonify, request - -JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "SuperSecretKey") -JWT_ACCESS_TOKEN_EXPIRY = timedelta(hours=1) -JWT_REFRESH_TOKEN_EXPIRY = timedelta(days=30) - - -class TokenError(Exception): - """Custom exception for token-related errors.""" - - def __init__(self, message, status_code): - super().__init__(message) - self.status_code = status_code - self.message = message - - -def generate_access_token(person_id: int) -> str: - """Generate a short-lived JWT access token for a user.""" - payload = { - "person_id": person_id, - "exp": datetime.now(timezone.utc) + JWT_ACCESS_TOKEN_EXPIRY, # Expiration - "iat": datetime.now(timezone.utc), # Issued at - "token_type": "access", - } - return jwt.encode(payload, JWT_SECRET_KEY, algorithm="HS256") - - -def generate_refresh_token(person_id: int) -> str: - """Generate a long-lived refresh token for a user.""" - payload = { - "person_id": person_id, - "exp": datetime.now(timezone.utc) + JWT_REFRESH_TOKEN_EXPIRY, - "iat": datetime.now(timezone.utc), - "token_type": "refresh", - } - return jwt.encode(payload, JWT_SECRET_KEY, algorithm="HS256") - - -def extract_token_from_header() -> str: - """Extract the Bearer token from the Authorization header.""" - auth_header = request.headers.get("Authorization") - if not auth_header or not auth_header.startswith("Bearer "): - raise TokenError("Token is missing or improperly formatted", 401) - return auth_header.split("Bearer ")[1] - - -def verify_token(token: str, required_type: str) -> dict: - """Verify and decode a JWT token.""" - try: - decoded = jwt.decode(token, JWT_SECRET_KEY, algorithms=["HS256"]) - if decoded.get("token_type") != required_type: - raise jwt.InvalidTokenError("Invalid token type") - return decoded - except jwt.ExpiredSignatureError: - raise TokenError("Token has expired", 401) - except jwt.InvalidTokenError: - raise TokenError("Invalid token", 401) - - -def token_required(f): - """Decorator to protect routes by requiring a valid token.""" - - @wraps(f) - def decorated(*args, **kwargs): - try: - token = extract_token_from_header() - decoded = verify_token(token, required_type="access") - request.person_id = decoded["person_id"] - return f(*args, **kwargs) - except TokenError as e: - return jsonify(message=e.message), e.status_code - - return decorated diff --git a/tests/test_jwt/__init__.py b/jwtoken/__init__.py similarity index 100% rename from tests/test_jwt/__init__.py rename to jwtoken/__init__.py diff --git a/jwtoken/decorators.py b/jwtoken/decorators.py new file mode 100644 index 0000000..9ee475e --- /dev/null +++ b/jwtoken/decorators.py @@ -0,0 +1,24 @@ +from functools import wraps + +from flask import jsonify, request + +from utility.jwtoken.common import extract_token_from_header + +from .exceptions import TokenError +from .tokens import verify_token + + +def token_required(f): + """Decorator to protect routes by requiring a valid token.""" + + @wraps(f) + def decorated(*args, **kwargs): + try: + token = extract_token_from_header() + decoded = verify_token(token, required_type="access") + request.person_id = decoded["person_id"] + return f(*args, **kwargs) + except TokenError as e: + return jsonify(message=e.message), e.status_code + + return decorated diff --git a/jwtoken/exceptions.py b/jwtoken/exceptions.py new file mode 100644 index 0000000..95b0731 --- /dev/null +++ b/jwtoken/exceptions.py @@ -0,0 +1,7 @@ +class TokenError(Exception): + """Custom exception for token-related errors.""" + + def __init__(self, message, status_code): + super().__init__(message) + self.status_code = status_code + self.message = message diff --git a/jwtoken/tokens.py b/jwtoken/tokens.py new file mode 100644 index 0000000..db551d2 --- /dev/null +++ b/jwtoken/tokens.py @@ -0,0 +1,63 @@ +from datetime import datetime, timedelta, timezone + +import jwt + +from utility.jwtoken.common import get_active_kid, load_private_key, load_public_key + +from .exceptions import TokenError + +JWT_ACCESS_TOKEN_EXPIRY = timedelta(hours=1) +JWT_REFRESH_TOKEN_EXPIRY = timedelta(days=30) + + +def generate_access_token(person_id: int) -> str: + kid = get_active_kid() + private_key = load_private_key(kid) + + current_time = datetime.now(timezone.utc) + payload = { + "person_id": person_id, + "exp": current_time + JWT_ACCESS_TOKEN_EXPIRY, # Expiration + "iat": current_time, # Issued at + "token_type": "access", + } + headers = {"kid": kid} + return jwt.encode(payload, private_key, algorithm="RS256", headers=headers) + + +def generate_refresh_token(person_id: int) -> str: + """Generate a long-lived refresh token for a user.""" + kid = get_active_kid() + private_key = load_private_key(kid) + + current_time = datetime.now(timezone.utc) + payload = { + "person_id": person_id, + "exp": current_time + JWT_REFRESH_TOKEN_EXPIRY, # Expiration + "iat": current_time, # Issued at + "token_type": "refresh", + } + headers = {"kid": kid} + + return jwt.encode(payload, private_key, algorithm="RS256", headers=headers) + + +def verify_token(token: str, required_type: str) -> dict: + """Verify and decode a JWT token.""" + try: + unverified_header = jwt.get_unverified_header(token) + kid = unverified_header.get("kid") + if not kid: + raise TokenError("KID missing in token header", 401) + + public_key = load_public_key(kid) + + decoded = jwt.decode(token, public_key, algorithms=["RS256"]) + if decoded.get("token_type") != required_type: + raise jwt.InvalidTokenError("Invalid token type") + return decoded + + except jwt.ExpiredSignatureError: + raise TokenError("Token has expired", 401) + except jwt.InvalidTokenError: + raise TokenError("Invalid token", 401) diff --git a/routes/authentication.py b/routes/authentication.py index 4d5477d..eb550c4 100644 --- a/routes/authentication.py +++ b/routes/authentication.py @@ -3,15 +3,11 @@ from pymysql import MySQLError from config.ratelimit import limiter -from jwt_helper import ( - TokenError, - extract_token_from_header, - generate_access_token, - generate_refresh_token, - verify_token, -) +from jwtoken.exceptions import TokenError +from jwtoken.tokens import generate_access_token, generate_refresh_token, verify_token from utility.database import database_cursor from utility.encryption import encrypt_email, hash_email, hash_password, verify_password +from utility.jwtoken.common import extract_token_from_header from utility.validation import validate_email, validate_password authentication_blueprint = Blueprint("authentication", __name__) diff --git a/routes/picture.py b/routes/picture.py index dc50ff5..6689cde 100644 --- a/routes/picture.py +++ b/routes/picture.py @@ -3,7 +3,7 @@ from flask import Blueprint, jsonify, request, send_from_directory from config.settings import Config -from jwt_helper import token_required +from jwtoken.decorators import token_required from utility.database import database_cursor picture_blueprint = Blueprint("picture", __name__) diff --git a/tests/test_jwt/test_verify_token.py b/tests/test_jwt/test_verify_token.py deleted file mode 100644 index ca97280..0000000 --- a/tests/test_jwt/test_verify_token.py +++ /dev/null @@ -1,54 +0,0 @@ -from datetime import datetime, timedelta - -import jwt -import pytest -from flask import Flask - -from jwt_helper import JWT_SECRET_KEY, TokenError, verify_token - -app = Flask(__name__) - - -def test_verify_valid_access_token(): - """Test verifying a valid access token.""" - access_token = jwt.encode( - {"token_type": "access"}, JWT_SECRET_KEY, algorithm="HS256" - ) - decoded = verify_token(access_token, "access") - assert decoded["token_type"] == "access" - - -def test_verify_valid_refresh_token(): - """Test verifying a valid refresh token.""" - refresh_token = jwt.encode( - {"token_type": "refresh"}, JWT_SECRET_KEY, algorithm="HS256" - ) - decoded = verify_token(refresh_token, "refresh") - assert decoded["token_type"] == "refresh" - - -def test_verify_token_invalid_type(): - """Test verifying a token with an incorrect type.""" - token = jwt.encode({"token_type": "invalid"}, JWT_SECRET_KEY, algorithm="HS256") - with pytest.raises(TokenError, match="Invalid token") as excinfo: - verify_token(token, "access") - assert excinfo.value.status_code == 401 - - -def test_verify_expired_token(): - """Test verifying an expired token.""" - expired_token = jwt.encode( - {"token_type": "access", "exp": datetime.now() - timedelta(seconds=1)}, - JWT_SECRET_KEY, - algorithm="HS256", - ) - with pytest.raises(TokenError, match="Token has expired") as excinfo: - verify_token(expired_token, "access") - assert excinfo.value.status_code == 401 - - -def test_verify_invalid_token(): - """Test verifying an invalid token.""" - with pytest.raises(TokenError, match="Invalid token") as excinfo: - verify_token("invalid_token", "access") - assert excinfo.value.status_code == 401 diff --git a/tests/test_jwtoken/__init__.py b/tests/test_jwtoken/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_jwt/conftest.py b/tests/test_jwtoken/conftest.py similarity index 71% rename from tests/test_jwt/conftest.py rename to tests/test_jwtoken/conftest.py index b401401..884c3c3 100644 --- a/tests/test_jwt/conftest.py +++ b/tests/test_jwtoken/conftest.py @@ -1,7 +1,8 @@ +import jwt import pytest from flask import Flask -from jwt_helper import generate_access_token +from jwtoken.tokens import generate_access_token @pytest.fixture @@ -21,3 +22,8 @@ def sample_token(): def sample_access_token(sample_person_id): """Provide a sample access token for testing""" return generate_access_token(sample_person_id) + + +@pytest.fixture +def sample_kid(sample_access_token): + return jwt.get_unverified_header(sample_access_token).get("kid") diff --git a/tests/test_jwt/test_extract_token_from_header.py b/tests/test_jwtoken/test_extract_token_from_header.py similarity index 93% rename from tests/test_jwt/test_extract_token_from_header.py rename to tests/test_jwtoken/test_extract_token_from_header.py index e83eb92..eee6a57 100644 --- a/tests/test_jwt/test_extract_token_from_header.py +++ b/tests/test_jwtoken/test_extract_token_from_header.py @@ -1,7 +1,8 @@ import pytest from flask import Flask -from jwt_helper import TokenError, extract_token_from_header +from jwtoken.exceptions import TokenError +from utility.jwtoken.common import extract_token_from_header app = Flask(__name__) diff --git a/tests/test_jwt/test_generate_access_token.py b/tests/test_jwtoken/test_generate_access_token.py similarity index 71% rename from tests/test_jwt/test_generate_access_token.py rename to tests/test_jwtoken/test_generate_access_token.py index ab0c57d..bae0c63 100644 --- a/tests/test_jwt/test_generate_access_token.py +++ b/tests/test_jwtoken/test_generate_access_token.py @@ -1,6 +1,6 @@ import jwt -from jwt_helper import JWT_ACCESS_TOKEN_EXPIRY, JWT_SECRET_KEY +from jwtoken.tokens import JWT_ACCESS_TOKEN_EXPIRY, load_public_key def test_access_token_type(sample_access_token): @@ -8,28 +8,32 @@ def test_access_token_type(sample_access_token): assert isinstance(sample_access_token, str) -def test_decoded_access_token(sample_person_id, sample_access_token): +def test_decoded_access_token(sample_person_id, sample_access_token, sample_kid): """ Ensure the generated access token can be decoded and contains the correct payload - Check if the payload contains the correct person ID - Check if the token has an expiration time - Check if the token type is 'access' """ - payload = jwt.decode(sample_access_token, JWT_SECRET_KEY, algorithms=["HS256"]) + public_key = load_public_key(sample_kid) + + payload = jwt.decode(sample_access_token, public_key, algorithms=["RS256"]) assert payload["person_id"] == sample_person_id assert "exp" in payload assert payload["token_type"] == "access" -def test_access_token_expiration(sample_access_token): +def test_access_token_expiration(sample_access_token, sample_kid): """ Ensure the generated access token has a valid expiration time - Check if the expiration time is greater than 0 - Check if the expiration time is greater than the issued at time - Check if the token is not expired """ - payload = jwt.decode(sample_access_token, JWT_SECRET_KEY, algorithms=["HS256"]) + public_key = load_public_key(sample_kid) + + payload = jwt.decode(sample_access_token, public_key, algorithms=["RS256"]) assert payload["exp"] > 0 assert payload["exp"] > payload["iat"] diff --git a/tests/test_jwtoken/test_invalid_token.py b/tests/test_jwtoken/test_invalid_token.py new file mode 100644 index 0000000..8e57802 --- /dev/null +++ b/tests/test_jwtoken/test_invalid_token.py @@ -0,0 +1,84 @@ +from datetime import datetime, timedelta, timezone + +import jwt +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from jwt.exceptions import InvalidKeyError + +from jwtoken.exceptions import TokenError +from jwtoken.tokens import generate_access_token, verify_token +from utility.jwtoken.common import get_active_kid + + +@pytest.fixture +def active_kid(): + return get_active_kid() + + +@pytest.fixture +def fake_private_key(): + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + return key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + +@pytest.fixture +def public_key(active_kid): + with open(f"keys/{active_kid}/public.pem", "rb") as f: + return f.read() + + +@pytest.fixture +def sample_payload(sample_person_id): + current_time = datetime.now(timezone.utc) + return { + "person_id": sample_person_id, + "exp": current_time + timedelta(minutes=10), + "iat": current_time, + "token_type": "access", + } + + +def test_hs256_forged_token_rejected(sample_person_id, public_key, active_kid): + headers = {"kid": active_kid, "alg": "HS256"} + current_time = datetime.now(timezone.utc) + payload = { + "person_id": sample_person_id, + "exp": current_time + timedelta(minutes=10), + "iat": current_time, + "token_type": "access", + } + + with pytest.raises(InvalidKeyError, match="asymmetric key.*HMAC"): + jwt.encode(payload, public_key, algorithm="HS256", headers=headers) + + +def test_tampered_token_expiry_extension(sample_person_id, fake_private_key): + original_token = generate_access_token(sample_person_id) + original_payload = jwt.decode(original_token, options={"verify_signature": False}) + original_header = jwt.get_unverified_header(original_token) + + modified_payload = original_payload.copy() + modified_payload["exp"] = datetime.now(timezone.utc) + timedelta(days=365) + + forged_token = jwt.encode( + modified_payload, fake_private_key, algorithm="RS256", headers=original_header + ) + + with pytest.raises(TokenError, match="Invalid token"): + verify_token(forged_token, "access") + + +def test_unknown_kid_rejected(fake_private_key, sample_payload): + headers = {"kid": "fake123456"} + + fake_token = jwt.encode( + sample_payload, fake_private_key, algorithm="RS256", headers=headers + ) + + with pytest.raises(TokenError, match="Unknown key ID"): + verify_token(fake_token, "access") diff --git a/tests/test_jwt/test_refresh_token.py b/tests/test_jwtoken/test_refresh_token.py similarity index 68% rename from tests/test_jwt/test_refresh_token.py rename to tests/test_jwtoken/test_refresh_token.py index ffd9a4a..6dd6889 100644 --- a/tests/test_jwt/test_refresh_token.py +++ b/tests/test_jwtoken/test_refresh_token.py @@ -1,7 +1,11 @@ import jwt import pytest -from jwt_helper import JWT_REFRESH_TOKEN_EXPIRY, JWT_SECRET_KEY, generate_refresh_token +from jwtoken.tokens import ( + JWT_REFRESH_TOKEN_EXPIRY, + generate_refresh_token, + load_public_key, +) @pytest.fixture @@ -15,28 +19,35 @@ def test_refresh_token_type(sample_refresh_token): assert isinstance(sample_refresh_token, str) -def test_decoded_refresh_token_decoded(sample_person_id, sample_refresh_token): +def test_decoded_refresh_token_decoded( + sample_person_id, sample_refresh_token, sample_kid +): """ Ensure the generated refresh token can be decoded and contains the correct payload - Check if the payload contains the correct person ID - Check if the token has an expiration time - Check if the token type is 'refresh' """ - payload = jwt.decode(sample_refresh_token, JWT_SECRET_KEY, algorithms=["HS256"]) + + public_key = load_public_key(sample_kid) + + payload = jwt.decode(sample_refresh_token, public_key, algorithms=["RS256"]) assert payload["person_id"] == sample_person_id assert "exp" in payload assert payload["token_type"] == "refresh" -def test_refresh_token_expiration(sample_refresh_token): +def test_refresh_token_expiration(sample_refresh_token, sample_kid): """ Ensure the generated refresh token has a valid expiration time - Check if the expiration time is greater than 0 - Check if the expiration time is greater than the issued at time - Check if the token is not expired """ - payload = jwt.decode(sample_refresh_token, JWT_SECRET_KEY, algorithms=["HS256"]) + public_key = load_public_key(sample_kid) + + payload = jwt.decode(sample_refresh_token, public_key, algorithms=["RS256"]) assert payload["exp"] > 0 assert payload["exp"] > payload["iat"] diff --git a/tests/test_jwt/test_token_required.py b/tests/test_jwtoken/test_token_required.py similarity index 96% rename from tests/test_jwt/test_token_required.py rename to tests/test_jwtoken/test_token_required.py index c74862b..0fefd0a 100644 --- a/tests/test_jwt/test_token_required.py +++ b/tests/test_jwtoken/test_token_required.py @@ -1,6 +1,6 @@ from flask import Flask, jsonify -from jwt_helper import token_required +from jwtoken.decorators import token_required app = Flask(__name__) diff --git a/tests/test_jwtoken/test_verify_token.py b/tests/test_jwtoken/test_verify_token.py new file mode 100644 index 0000000..03efe84 --- /dev/null +++ b/tests/test_jwtoken/test_verify_token.py @@ -0,0 +1,86 @@ +from datetime import datetime, timedelta, timezone + +import jwt +import pytest +from flask import Flask + +from jwtoken.exceptions import TokenError +from jwtoken.tokens import ( + generate_access_token, + generate_refresh_token, + load_private_key, + verify_token, +) +from utility.jwtoken.common import get_active_kid + +app = Flask(__name__) + + +def test_verify_valid_access_token(sample_person_id): + """Test verifying a valid access token.""" + token = generate_access_token(sample_person_id) + decoded = verify_token(token, "access") + assert decoded["person_id"] == sample_person_id + assert decoded["token_type"] == "access" + + +def test_verify_valid_refresh_token(sample_person_id): + """Test verifying a valid refresh token.""" + token = generate_refresh_token(sample_person_id) + decoded = verify_token(token, "refresh") + assert decoded["person_id"] == sample_person_id + assert decoded["token_type"] == "refresh" + + +def test_verify_token_invalid_type(sample_person_id): + """Test verifying a token with an incorrect type.""" + kid = get_active_kid() + private_key = load_private_key(kid) + + # Create a token with token_type = "invalid" + current_time = datetime.now(timezone.utc) + token = jwt.encode( + { + "person_id": sample_person_id, + "token_type": "invalid", + "exp": current_time + timedelta(minutes=5), + "iat": current_time, + }, + private_key, + algorithm="RS256", + headers={"kid": kid}, + ) + + with pytest.raises(TokenError, match="Invalid token") as excinfo: + verify_token(token, "access") + assert excinfo.value.status_code == 401 + + +def test_verify_expired_token(sample_person_id): + """Test verifying an expired token.""" + kid = get_active_kid() + private_key = load_private_key(kid) + + current_time = datetime.now(timezone.utc) + expired_token = jwt.encode( + { + "person_id": sample_person_id, + "token_type": "access", + "exp": current_time - timedelta(seconds=1), + "iat": current_time - timedelta(hours=1), + }, + private_key, + algorithm="RS256", + headers={"kid": kid}, + ) + + with pytest.raises(TokenError, match="Token has expired") as excinfo: + verify_token(expired_token, "access") + assert excinfo.value.status_code == 401 + + +def test_verify_invalid_token(): + """Test verifying a malformed or tampered token.""" + with pytest.raises(TokenError, match="Invalid token") as excinfo: + verify_token("not_a_real_token", "access") + assert excinfo.value.status_code == 401 diff --git a/tests/test_routes/conftest.py b/tests/test_routes/conftest.py index 21350e9..edd390d 100644 --- a/tests/test_routes/conftest.py +++ b/tests/test_routes/conftest.py @@ -1,7 +1,7 @@ import pytest from flask import Flask -from jwt_helper import generate_access_token +from jwtoken.tokens import generate_access_token @pytest.fixture diff --git a/tests/test_routes/test_authentication/test_refresh.py b/tests/test_routes/test_authentication/test_refresh.py index a8e5b83..e7d745c 100644 --- a/tests/test_routes/test_authentication/test_refresh.py +++ b/tests/test_routes/test_authentication/test_refresh.py @@ -1,16 +1,11 @@ -import datetime +from datetime import datetime, timedelta, timezone import jwt import pytest from flask.testing import FlaskClient -from jwt_helper import JWT_SECRET_KEY, generate_refresh_token - - -@pytest.fixture -def sample_person_id() -> int: - """Provide a sample person ID for testing""" - return 12345 +from jwtoken.tokens import generate_refresh_token, load_private_key +from utility.jwtoken.common import get_active_kid @pytest.fixture @@ -22,23 +17,36 @@ def sample_refresh_token(sample_person_id: int) -> str: @pytest.fixture def sample_expired_token(sample_person_id) -> str: """Generate a deliberately expired JWT refresh token.""" + kid = get_active_kid() + private_key = load_private_key(kid) + + current_time = datetime.now(timezone.utc) payload = { "person_id": sample_person_id, - "exp": datetime.datetime.now(datetime.timezone.utc) - - datetime.timedelta(seconds=1), # Already expired - "iat": datetime.datetime.now(datetime.timezone.utc) - - datetime.timedelta(hours=1), + "exp": current_time - timedelta(seconds=1), # Already expired + "iat": current_time - timedelta(hours=1), "token_type": "refresh", } - return jwt.encode(payload, JWT_SECRET_KEY, algorithm="HS256") + + return jwt.encode(payload, private_key, algorithm="RS256", headers={"kid": kid}) -def test_refresh_token_success(client, sample_refresh_token): +def test_refresh_token_success(client: FlaskClient, sample_refresh_token): """Test the refresh token endpoint with a valid token""" headers = {"Authorization": f"Bearer {sample_refresh_token}"} response = client.post("/refresh", headers=headers) + + assert response.status_code == 200 encoded_token = response.json["access_token"] - token = jwt.decode(encoded_token, JWT_SECRET_KEY, algorithms=["HS256"]) + + # Decode the returned access token using RS256 + unverified_header = jwt.get_unverified_header(encoded_token) + kid = unverified_header["kid"] + public_key_path = f"keys/{kid}/public.pem" + with open(public_key_path, "rb") as f: + public_key = f.read() + + token = jwt.decode(encoded_token, public_key, algorithms=["RS256"]) assert "person_id" in token assert "exp" in token @@ -66,7 +74,6 @@ def test_refresh_token_missing_token(client: FlaskClient): def test_refresh_token_expired_token(client: FlaskClient, sample_expired_token): """Test the refresh token endpoint with an expired token""" headers = {"Authorization": f"Bearer {sample_expired_token}"} - response = client.post("/refresh", headers=headers) assert response.status_code == 401 diff --git a/utility/jwtoken/__init__.py b/utility/jwtoken/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utility/jwtoken/common.py b/utility/jwtoken/common.py new file mode 100644 index 0000000..3986e63 --- /dev/null +++ b/utility/jwtoken/common.py @@ -0,0 +1,29 @@ +from flask import request + +from jwtoken.exceptions import TokenError + + +def extract_token_from_header() -> str: + """Extract the Bearer token from the Authorization header.""" + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + raise TokenError("Token is missing or improperly formatted", 401) + return auth_header.split("Bearer ")[1] + + +def get_active_kid(): + with open("keys/active_kid.txt", "r") as f: + return f.read().strip() + + +def load_private_key(kid): + with open(f"keys/{kid}/private.pem", "rb") as f: + return f.read() + + +def load_public_key(kid): + try: + with open(f"keys/{kid}/public.pem", "rb") as f: + return f.read() + except FileNotFoundError: + raise TokenError("Unknown key ID", 401) diff --git a/utility/jwtoken/keys_cleanup.py b/utility/jwtoken/keys_cleanup.py new file mode 100644 index 0000000..56be1fc --- /dev/null +++ b/utility/jwtoken/keys_cleanup.py @@ -0,0 +1,39 @@ +""" +To run this script ensure you're in the root directory of the project. +Run the script with: `python3 -m utility.jwtoken.keys_cleanup` +""" + +from datetime import datetime, timedelta, timezone + +from config.jwtoken import CREATED_AT_FILE, KEYS_DIR +from jwtoken.tokens import JWT_REFRESH_TOKEN_EXPIRY +from utility.jwtoken.common import get_active_kid + +EXPIRY_DAYS = JWT_REFRESH_TOKEN_EXPIRY.days + 1 + + +def cleanup_old_keys(): + current_time = datetime.now(timezone.utc) + active_kid = get_active_kid() + + for kid_dir in KEYS_DIR.iterdir(): + if not kid_dir.is_dir(): + continue + if kid_dir.name == active_kid: + continue # Don't delete active key + created_at_file = kid_dir / CREATED_AT_FILE + if not created_at_file.exists(): + continue # Skip keys without metadata + + with open(created_at_file, "r") as f: + created_at = datetime.fromisoformat(f.read().strip()) + + if (current_time - created_at) > timedelta(days=EXPIRY_DAYS): + print(f"Deleting expired key: {kid_dir.name}") + for item in kid_dir.iterdir(): + item.unlink() + kid_dir.rmdir() + + +if __name__ == "__main__": + cleanup_old_keys() diff --git a/utility/jwtoken/keys_rotation.py b/utility/jwtoken/keys_rotation.py new file mode 100644 index 0000000..41759ec --- /dev/null +++ b/utility/jwtoken/keys_rotation.py @@ -0,0 +1,65 @@ +""" +Rotate the keys used for JWT signing. +Run the script with: `python3 -m utility.jwtoken.keys_rotation` +""" + +import secrets +import subprocess +from datetime import datetime, timezone + +from config.jwtoken import ( + ACTIVE_KID_FILE, + CREATED_AT_FILE, + KEYS_DIR, + PRIVATE_KEY_FILE, + PUBLIC_KEY_FILE, +) + + +def rotate_keys(): + new_kid = secrets.token_hex(8) + new_key_dir = KEYS_DIR / new_kid + new_key_dir.mkdir(parents=True, exist_ok=False) + + private_key_path = new_key_dir / PRIVATE_KEY_FILE + public_key_path = new_key_dir / PUBLIC_KEY_FILE + created_at_path = new_key_dir / CREATED_AT_FILE + + # Generate private key + subprocess.run( + [ + "openssl", + "genpkey", + "-algorithm", + "RSA", + "-out", + str(private_key_path), + "-pkeyopt", + "rsa_keygen_bits:2048", + ], + check=True, + ) + + # Extract public key + subprocess.run( + [ + "openssl", + "rsa", + "-pubout", + "-in", + str(private_key_path), + "-out", + str(public_key_path), + ], + check=True, + ) + + # Set new active KID + ACTIVE_KID_FILE.write_text(new_kid) + + # Save creation time + created_at_path.write_text(datetime.now(timezone.utc).isoformat()) + + +if __name__ == "__main__": + rotate_keys()