From f7ad92bd6dc1f325daa754923f2f53787dec7208 Mon Sep 17 00:00:00 2001 From: VincentAdamNemessis Date: Fri, 10 Jan 2025 17:26:57 +0800 Subject: [PATCH 1/4] fix: rename param --- core/chat.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/chat.py b/core/chat.py index 5c25e35..184007f 100644 --- a/core/chat.py +++ b/core/chat.py @@ -50,11 +50,11 @@ async def websocket_endpoint( clients.append(websocket) # 获取客户端 IP 和生成随机用户名 - client_ip = str(websocket.client) + client_real_ip = str(websocket.client) username = generate_random_username() # 保存映射关系 - client_map[websocket] = {"username": username, "ip": client_ip} + client_map[websocket] = {"username": username, "ip": client_real_ip} # 向客户端发送历史消息 chat_history = await get_chat_history(rds_session) @@ -68,10 +68,10 @@ async def websocket_endpoint( # 获取客户端的用户名和 IP 地址 client_username = client_map[websocket]["username"] - client_ip = client_map[websocket]["ip"] + client_real_ip = client_map[websocket]["ip"] # 将消息存储到 Redis - await store_message_in_redis(message, client_username, client_ip, rds_session) + await store_message_in_redis(message, client_username, client_real_ip, rds_session) # 向所有连接的客户端广播消息 for client in clients: From 81cc5b4be91d92cfe3349d7de77b2bd49ba12783 Mon Sep 17 00:00:00 2001 From: VincentAdamNemessis Date: Sat, 11 Jan 2025 15:46:25 +0800 Subject: [PATCH 2/4] feat: - add file upload with websocket and post - add file upload support with fastapi(storage directory and third party package) - add basic logger record --- .gitignore | 6 +- app.py | 15 +++- core/chat.py | 105 ++++++++++++++++++------- poetry.lock | 51 +++++++++++- pyproject.toml | 6 +- templates/simple_chatroom.html | 137 +++++++++++++++++++++++++++++---- util/custom_logger.py | 0 util/logger.py | 95 ----------------------- 8 files changed, 269 insertions(+), 146 deletions(-) create mode 100644 util/custom_logger.py delete mode 100644 util/logger.py diff --git a/.gitignore b/.gitignore index 4be3071..e139717 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,8 @@ cython_debug/ **/*-local.* **/*-local **/*_local.* -**/*_local \ No newline at end of file +**/*_local + +# static and media files +static/** +media/** \ No newline at end of file diff --git a/app.py b/app.py index d798365..aba4562 100644 --- a/app.py +++ b/app.py @@ -1,10 +1,11 @@ +import os.path from contextlib import asynccontextmanager import uvicorn from fastapi import FastAPI from fastapi.exceptions import RequestValidationError, ResponseValidationError from fastapi.middleware.cors import CORSMiddleware -from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.staticfiles import StaticFiles from util.logger import init_logger @@ -17,7 +18,7 @@ def register_router(_app: FastAPI): _app.include_router(api_router.router, prefix="/api") -def add_custom_exception_handlers(_app: FastAPI): +def register_custom_exception_handlers(_app: FastAPI): from util.exception_util import request_validation_exception_handler, response_validation_exception_handler _app.add_exception_handler(RequestValidationError, request_validation_exception_handler) _app.add_exception_handler(ResponseValidationError, response_validation_exception_handler) @@ -33,11 +34,19 @@ def register_middlewares(_app: FastAPI): ) +def register_mounter(_app: FastAPI): + os.makedirs("static", exist_ok=True) + os.makedirs("media", exist_ok=True) + _app.mount("/static", StaticFiles(directory="static"), name="static") + _app.mount("/media", StaticFiles(directory="media"), name="media") + + def create_app(span) -> FastAPI: _app = FastAPI(lifespan=span) register_router(_app) register_middlewares(_app) - add_custom_exception_handlers(_app) + register_custom_exception_handlers(_app) + register_mounter(_app) return _app diff --git a/core/chat.py b/core/chat.py index 184007f..de310b5 100644 --- a/core/chat.py +++ b/core/chat.py @@ -1,33 +1,34 @@ +import hashlib import json +import os import random +import shutil import string -from datetime import timedelta -from typing import List, Dict +from datetime import datetime, timedelta + +import starlette.datastructures from aioredis import Redis -from fastapi import APIRouter, Depends -from starlette.websockets import WebSocket, WebSocketDisconnect +from fastapi.exceptions import RequestValidationError +from loguru import logger +from fastapi import APIRouter, Depends, File, UploadFile, WebSocket, WebSocketDisconnect + from .deps import get_redis_session -from datetime import datetime -clients: List[WebSocket] = [] -client_map: Dict[str, Dict[str, str]] = {} # {websocket_id: {"username": username, "ip": ip}} +# 存储 WebSocket 连接和映射关系 +clients: list[WebSocket] = [] +client_map: dict[WebSocket, dict[str, str]] = {} router = APIRouter() # 生成随机用户名 def generate_random_username() -> str: - return ''.join(random.choices(string.ascii_letters + string.digits, k=8)) + return "".join(random.choices(string.ascii_letters + string.digits, k=8)) # 消息存储到 Redis -async def store_message_in_redis(message: str, username: str, ip: str, session: Redis): - message_data = { - "username": username, - "ip": ip, - "timestamp": datetime.now().isoformat(), - "message": message - } +async def store_message_in_redis(message: str | dict, username: str, ip: str, session: Redis): + message_data = {"username": username, "ip": ip, "timestamp": datetime.now().isoformat(), "message": message} # 将消息数据以 JSON 字符串的形式存储到 Redis 列表 await session.lpush("chat_history", json.dumps(message_data)) # 设置消息过期时间为7天 @@ -40,12 +41,33 @@ async def get_chat_history(session: Redis): return [json.loads(msg) for msg in reversed(messages)] +# 使用 SHA-256 对文件名进行加密 +def encrypt_filename(filename: str) -> str: + # 使用 SHA-256 对文件名进行哈希加密 + sha256_hash = hashlib.sha256() + sha256_hash.update(filename.encode('utf-8')) # 对文件名进行编码 + return sha256_hash.hexdigest() # 返回加密后的文件名 + + +# 上传文件处理,文件大小小于50MB时通过WebSocket直接传输,超过50MB时使用POST上传 +async def handle_file_upload(file: UploadFile | bytes, file_name: str = None) -> str: + # 获取文件的扩展名 + file_name, ext = os.path.splitext(file_name or file.filename) + file_name = encrypt_filename(file_name) or encrypt_filename(file.filename) + filename = f"{file_name}{ext}" + file_location = f"/media/uploads/{filename}" + + # 保存文件到指定目录 + os.makedirs(os.path.dirname(file_location.lstrip("/")), exist_ok=True) + with open(file_location.lstrip("/"), "wb") as f: + shutil.copyfileobj(file.file, f) if isinstance(file, starlette.datastructures.UploadFile) else f.write(file) + + return file_location + + # WebSocket 连接处理 @router.websocket("/ws/chat") -async def websocket_endpoint( - websocket: WebSocket, - rds_session: Redis = Depends(get_redis_session) -): +async def websocket_endpoint(websocket: WebSocket, rds_session: Redis = Depends(get_redis_session)): await websocket.accept() clients.append(websocket) @@ -59,29 +81,58 @@ async def websocket_endpoint( # 向客户端发送历史消息 chat_history = await get_chat_history(rds_session) for message in chat_history: - await websocket.send_text(f"{message['timestamp']} - {message['username']} ({message['ip']}): {message['message']}") + await websocket.send_json(message) try: while True: # 接收客户端发送的消息 - message = await websocket.receive_text() + origin_message = await websocket.receive_json() + + if origin_message.get("type") == "file": + if origin_message.get("filename") and origin_message.get("fileSize") < 50 * 1024 * 1024: + # 接收文件二进制数据 + file_data = await websocket.receive_bytes() + if file_data: + file_location = await handle_file_upload( + file_data, + file_name=origin_message.get("filename") + ) + origin_message["url"] = file_location + logger.info(f"File saved: {file_location}") + logger.warning("File size exceeds 50MB, please use POST to upload.") + raise RequestValidationError("File size exceeds 50MB, please use POST to upload.") # 获取客户端的用户名和 IP 地址 client_username = client_map[websocket]["username"] client_real_ip = client_map[websocket]["ip"] # 将消息存储到 Redis - await store_message_in_redis(message, client_username, client_real_ip, rds_session) + await store_message_in_redis(origin_message, client_username, client_real_ip, rds_session) # 向所有连接的客户端广播消息 for client in clients: - # 获取发送者的用户名 - sender_username = client_map[client]["username"] - sender_ip = client_map[client]["ip"] - # 发送时,附加用户名、IP 和时间戳 - await client.send_text(f"{datetime.now().isoformat()} - {sender_username} ({sender_ip}): {message}") + await client.send_json({ + "username": client_map[client]["username"], + "ip": client_map[client]["ip"], + "timestamp": datetime.now().isoformat(), + "message": origin_message + }) except WebSocketDisconnect: # 断开连接时,清除客户端映射关系 del client_map[websocket] clients.remove(websocket) + + +# 处理文件上传的HTTP POST请求 +@router.post("/upload") +async def upload_file(file: UploadFile = File(...)): + try: + # 处理文件上传,并返回文件存储位置 + file_location = await handle_file_upload(file) + + # 返回文件的 URL + return {"fileUrl": f"{file_location}"} + + except ValueError as e: + return {"error": str(e)} diff --git a/poetry.lock b/poetry.lock index 4d3ae6a..6707a08 100644 --- a/poetry.lock +++ b/poetry.lock @@ -399,6 +399,25 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "loguru" +version = "0.7.3" +description = "Python logging made (stupidly) simple" +optional = false +python-versions = "<4.0,>=3.5" +groups = ["main"] +files = [ + {file = "loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c"}, + {file = "loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6"}, +] + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["Sphinx (==8.1.3)", "build (==1.2.2)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.5.0)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.13.0)", "mypy (==v1.4.1)", "myst-parser (==4.0.0)", "pre-commit (==4.0.1)", "pytest (==6.1.2)", "pytest (==8.3.2)", "pytest-cov (==2.12.1)", "pytest-cov (==5.0.0)", "pytest-cov (==6.0.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.1.0)", "sphinx-rtd-theme (==3.0.2)", "tox (==3.27.1)", "tox (==4.23.2)", "twine (==6.0.1)"] + [[package]] name = "markupsafe" version = "3.0.2" @@ -640,6 +659,18 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "python-multipart" +version = "0.0.20" +description = "A streaming multipart parser for Python" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104"}, + {file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"}, +] + [[package]] name = "pyyaml" version = "6.0.2" @@ -1086,7 +1117,23 @@ files = [ {file = "websockets-14.1.tar.gz", hash = "sha256:398b10c77d471c0aab20a845e7a60076b6390bfdaac7a6d2edb0d2c59d75e8d8"}, ] +[[package]] +name = "win32-setctime" +version = "1.2.0" +description = "A small Python utility to set file creation time on Windows" +optional = false +python-versions = ">=3.5" +groups = ["main"] +markers = "sys_platform == \"win32\"" +files = [ + {file = "win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390"}, + {file = "win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0"}, +] + +[package.extras] +dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] + [metadata] lock-version = "2.1" -python-versions = ">=3.10" -content-hash = "7500d0c011e40c611a9e8125d032d8742da34a1416bd8f20a00d67a72dbcfb9c" +python-versions = ">=3.10,<4.0" +content-hash = "e7319ed2a033ba23113526679d8989cc2536794f2c46dedc035acffef1f9ccd7" diff --git a/pyproject.toml b/pyproject.toml index b80889b..3c832d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] license = {text = "MIT"} readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.10,<4.0" dependencies = [ "fastapi (>=0.115.6,<0.116.0)", "uvicorn[standard] (>=0.34.0,<0.35.0)", @@ -16,7 +16,9 @@ dependencies = [ "sqlmodel (>=0.0.22,<0.0.23)", "asyncpg (>=0.30.0,<0.31.0)", "aiofiles (>=24.1.0,<25.0.0)", - "jinja2 (>=3.1.5,<4.0.0)" + "jinja2 (>=3.1.5,<4.0.0)", + "python-multipart (>=0.0.20,<0.0.21)", + "loguru (>=0.7.3,<0.8.0)" ] diff --git a/templates/simple_chatroom.html b/templates/simple_chatroom.html index a0a79be..445f83f 100644 --- a/templates/simple_chatroom.html +++ b/templates/simple_chatroom.html @@ -3,31 +3,136 @@ - FastAPI WebSocket Chat + WebSocket Chat with File Upload + -

FastAPI WebSocket Chat

+

WebSocket Chat with File Upload

- - + + + +
diff --git a/util/custom_logger.py b/util/custom_logger.py new file mode 100644 index 0000000..e69de29 diff --git a/util/logger.py b/util/logger.py deleted file mode 100644 index f035dc8..0000000 --- a/util/logger.py +++ /dev/null @@ -1,95 +0,0 @@ -import logging -import os -from collections.abc import Callable -from enum import Enum -from functools import wraps -from logging.handlers import TimedRotatingFileHandler - -from base.config import config -from exc.database_exc import DatabaseErr -from exc.service_exc import ServiceErr - - -class ColoredFormatter(logging.Formatter): - COLORS = { - "DEBUG": "\033[94m", # 蓝色 - "INFO": "\033[92m", # 绿色 - "WARNING": "\033[93m", # 黄色 - "ERROR": "\033[91m", # 红色 - "CRITICAL": "\033[91m", # 红色 - } - RESET = "\033[0m" - - def format(self, record): - level_name = record.levelname - msg = super().format(record) - colored_level_name = self.COLORS.get(level_name, self.RESET) + level_name + self.RESET - return f"{colored_level_name}: {msg}" - - -def init_logger(): - # 创建一个输出到控制台的处理器 - console_handler = logging.StreamHandler() - console_formatter = ColoredFormatter("%(message)s") - console_handler.setFormatter(console_formatter) - - # 创建一个按天分割输出到文件的日志处理器 - log_file = os.path.abspath(config.base.LOG_PATH) - log_dir = os.path.dirname(log_file) - if not os.path.exists(log_dir): - os.makedirs(log_dir) - file_handler = TimedRotatingFileHandler(log_file, when="midnight", interval=1, backupCount=7, encoding="utf-8") - file_formatter = logging.Formatter("%(asctime)s [%(levelname)s] [%(name)s]: %(message)s") - file_handler.setFormatter(file_formatter) - - # 使用basicConfig设置全局配置 - logging.basicConfig(level=config.base.LOG_LEVEL, handlers=[console_handler, file_handler]) - - # uvicorn启动日志输出到文件 - logging.getLogger("uvicorn").addHandler(file_handler) - # fastapi请求日志输出到文件 - logging.getLogger("uvicorn.access").addHandler(file_handler) - - -class LogType(Enum): - INFO = logging.Logger.info - ERROR = logging.Logger.error - WARNING = logging.Logger.warning - DEBUG = logging.Logger.debug - CRITICAL = logging.Logger.critical - - -def log_and_raise( - log_type: Callable = LogType.ERROR, - logging_logger: logging.Logger = None, - log_msg: str = None, - raise_exc: Exception = None, -): - """ - Decorator for logging and raising exceptions - """ - - def decorator(func): - @wraps(func) - async def wrapper(*args, **kwargs): - async def log_or_raise(log: bool = True, exc: Exception = None): - """ - Log or raise exceptions by given parameters - """ - (await kwargs["session"].rollback()) if kwargs.get("session", None) else None - log_type(logging_logger, f"{func.__name__} Error: {log_msg or exc}") if log else None - if exc: - if raise_exc: - raise raise_exc from exc - raise exc - - try: - return await func(*args, **kwargs) - except (ServiceErr, DatabaseErr) as e: - raise await log_or_raise(log=False, exc=e) - except Exception as e: - raise await log_or_raise(exc=e) - - return wrapper - - return decorator From 05fdb44b0008b77e72153f3bee4cba551b0a2f36 Mon Sep 17 00:00:00 2001 From: VincentAdamNemessis Date: Sat, 11 Jan 2025 15:50:30 +0800 Subject: [PATCH 3/4] style: add pre-commit --- app.py | 17 +++--- core/chat.py | 30 +++++----- core/page_router.py | 1 + exc/database_exc.py | 6 +- exc/service_exc.py | 6 +- poetry.lock | 127 ++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 116 ++++++++++++++++++++++++++++++++++++- util/exception_util.py | 19 +++--- 8 files changed, 285 insertions(+), 37 deletions(-) diff --git a/app.py b/app.py index aba4562..f98686b 100644 --- a/app.py +++ b/app.py @@ -7,19 +7,20 @@ from fastapi.middleware.cors import CORSMiddleware from starlette.staticfiles import StaticFiles -from util.logger import init_logger - -init_logger() - def register_router(_app: FastAPI): from core import api_router, page_router + _app.include_router(page_router.router) _app.include_router(api_router.router, prefix="/api") def register_custom_exception_handlers(_app: FastAPI): - from util.exception_util import request_validation_exception_handler, response_validation_exception_handler + from util.exception_util import ( + request_validation_exception_handler, + response_validation_exception_handler, + ) + _app.add_exception_handler(RequestValidationError, request_validation_exception_handler) _app.add_exception_handler(ResponseValidationError, response_validation_exception_handler) @@ -51,10 +52,12 @@ def create_app(span) -> FastAPI: @asynccontextmanager -async def lifespan(application: FastAPI): # noqa +async def lifespan(application: FastAPI): from base.connector import database_connector + """ - Use context manager to manage the lifespan of the application instead of using the startup and shutdown events. + Use context manager to manage the lifespan of the application instead of + using the startup and shutdown events. """ yield await database_connector.engine.dispose() diff --git a/core/chat.py b/core/chat.py index de310b5..137b0e9 100644 --- a/core/chat.py +++ b/core/chat.py @@ -8,9 +8,9 @@ import starlette.datastructures from aioredis import Redis +from fastapi import APIRouter, Depends, File, UploadFile, WebSocket, WebSocketDisconnect from fastapi.exceptions import RequestValidationError from loguru import logger -from fastapi import APIRouter, Depends, File, UploadFile, WebSocket, WebSocketDisconnect from .deps import get_redis_session @@ -28,7 +28,12 @@ def generate_random_username() -> str: # 消息存储到 Redis async def store_message_in_redis(message: str | dict, username: str, ip: str, session: Redis): - message_data = {"username": username, "ip": ip, "timestamp": datetime.now().isoformat(), "message": message} + message_data = { + "username": username, + "ip": ip, + "timestamp": datetime.now().isoformat(), + "message": message, + } # 将消息数据以 JSON 字符串的形式存储到 Redis 列表 await session.lpush("chat_history", json.dumps(message_data)) # 设置消息过期时间为7天 @@ -45,7 +50,7 @@ async def get_chat_history(session: Redis): def encrypt_filename(filename: str) -> str: # 使用 SHA-256 对文件名进行哈希加密 sha256_hash = hashlib.sha256() - sha256_hash.update(filename.encode('utf-8')) # 对文件名进行编码 + sha256_hash.update(filename.encode("utf-8")) # 对文件名进行编码 return sha256_hash.hexdigest() # 返回加密后的文件名 @@ -93,10 +98,7 @@ async def websocket_endpoint(websocket: WebSocket, rds_session: Redis = Depends( # 接收文件二进制数据 file_data = await websocket.receive_bytes() if file_data: - file_location = await handle_file_upload( - file_data, - file_name=origin_message.get("filename") - ) + file_location = await handle_file_upload(file_data, file_name=origin_message.get("filename")) origin_message["url"] = file_location logger.info(f"File saved: {file_location}") logger.warning("File size exceeds 50MB, please use POST to upload.") @@ -111,12 +113,14 @@ async def websocket_endpoint(websocket: WebSocket, rds_session: Redis = Depends( # 向所有连接的客户端广播消息 for client in clients: - await client.send_json({ - "username": client_map[client]["username"], - "ip": client_map[client]["ip"], - "timestamp": datetime.now().isoformat(), - "message": origin_message - }) + await client.send_json( + { + "username": client_map[client]["username"], + "ip": client_map[client]["ip"], + "timestamp": datetime.now().isoformat(), + "message": origin_message, + } + ) except WebSocketDisconnect: # 断开连接时,清除客户端映射关系 diff --git a/core/page_router.py b/core/page_router.py index 4c40e61..3ad9d46 100644 --- a/core/page_router.py +++ b/core/page_router.py @@ -1,4 +1,5 @@ from fastapi import APIRouter + from .page import router as page_router router = APIRouter() diff --git a/exc/database_exc.py b/exc/database_exc.py index d4d8242..744d7a8 100644 --- a/exc/database_exc.py +++ b/exc/database_exc.py @@ -1,16 +1,16 @@ from typing import Any -class DatabaseErr(Exception): +class DatabaseError(Exception): def __init__(self, message: str = ""): super().__init__(f"{message}") -class NotFoundRecordsErr(DatabaseErr): +class NotFoundRecordsError(DatabaseError): def __init__(self, reason: Any = None): super().__init__(f"{f'{reason}' if reason else '.'}") -class IntegrityErr(DatabaseErr): +class IntegrityError(DatabaseError): def __init__(self, reason: Any = None): super().__init__(f"Record(s) Integrity Error{f': `{reason}`' if reason else '.'}") diff --git a/exc/service_exc.py b/exc/service_exc.py index 4887f01..26daa30 100644 --- a/exc/service_exc.py +++ b/exc/service_exc.py @@ -1,16 +1,16 @@ from builtins import Exception -class ServiceErr(Exception): +class ServiceError(Exception): def __init__(self, message: str | dict = ""): super().__init__(message) -class BadRequestErr(ServiceErr): +class BadRequestError(ServiceError): def __init__(self, message: str | dict = ""): super().__init__(message) -class NotFoundErr(ServiceErr): +class NotFoundError(ServiceError): def __init__(self, message: str | dict = ""): super().__init__(message) diff --git a/poetry.lock b/poetry.lock index 6707a08..536c8a5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -145,6 +145,18 @@ docs = ["Sphinx (>=8.1.3,<8.2.0)", "sphinx-rtd-theme (>=1.2.2)"] gssauth = ["gssapi", "sspilib"] test = ["distro (>=1.9.0,<1.10.0)", "flake8 (>=6.1,<7.0)", "flake8-pyi (>=24.1.0,<24.2.0)", "gssapi", "k5test", "mypy (>=1.8.0,<1.9.0)", "sspilib", "uvloop (>=0.15.3)"] +[[package]] +name = "cfgv" +version = "3.4.0" +description = "Validate configuration and produce human readable error messages." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + [[package]] name = "click" version = "8.1.8" @@ -173,6 +185,18 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "distlib" +version = "0.3.9" +description = "Distribution utilities" +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87"}, + {file = "distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403"}, +] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -210,6 +234,23 @@ typing-extensions = ">=4.8.0" all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] +[[package]] +name = "filelock" +version = "3.16.1" +description = "A platform independent file lock." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"}, + {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"}, +] + +[package.extras] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] +typing = ["typing-extensions (>=4.12.2)"] + [[package]] name = "greenlet" version = "3.1.1" @@ -366,6 +407,21 @@ files = [ [package.extras] test = ["Cython (>=0.29.24)"] +[[package]] +name = "identify" +version = "2.6.5" +description = "File identification library for Python" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "identify-2.6.5-py2.py3-none-any.whl", hash = "sha256:14181a47091eb75b337af4c23078c9d09225cd4c48929f521f3bf16b09d02566"}, + {file = "identify-2.6.5.tar.gz", hash = "sha256:c10b33f250e5bba374fae86fb57f3adcebf1161bce7cdf92031915fd480c13bc"}, +] + +[package.extras] +license = ["ukkonen"] + [[package]] name = "idna" version = "3.10" @@ -489,6 +545,54 @@ files = [ {file = "markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0"}, ] +[[package]] +name = "nodeenv" +version = "1.9.1" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["main"] +files = [ + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, +] + +[[package]] +name = "platformdirs" +version = "4.3.6" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, + {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, +] + +[package.extras] +docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] +type = ["mypy (>=1.11.2)"] + +[[package]] +name = "pre-commit" +version = "4.0.1" +description = "A framework for managing and maintaining multi-language pre-commit hooks." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "pre_commit-4.0.1-py2.py3-none-any.whl", hash = "sha256:efde913840816312445dc98787724647c65473daefe420785f885e8ed9a06878"}, + {file = "pre_commit-4.0.1.tar.gz", hash = "sha256:80905ac375958c0444c65e9cebebd948b3cdb518f335a091a670a89d652139d2"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + [[package]] name = "pydantic" version = "2.10.5" @@ -954,6 +1058,27 @@ dev = ["Cython (>=3.0,<4.0)", "setuptools (>=60)"] docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] test = ["aiohttp (>=3.10.5)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=23.0.0,<23.1.0)", "pycodestyle (>=2.9.0,<2.10.0)"] +[[package]] +name = "virtualenv" +version = "20.28.1" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "virtualenv-20.28.1-py3-none-any.whl", hash = "sha256:412773c85d4dab0409b83ec36f7a6499e72eaf08c80e81e9576bca61831c71cb"}, + {file = "virtualenv-20.28.1.tar.gz", hash = "sha256:5d34ab240fdb5d21549b76f9e8ff3af28252f5499fb6d6f031adac4e5a8c5329"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<5" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + [[package]] name = "watchfiles" version = "1.0.3" @@ -1136,4 +1261,4 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "e7319ed2a033ba23113526679d8989cc2536794f2c46dedc035acffef1f9ccd7" +content-hash = "1d16d22a79596ea34f6cadff7cba8958db59ee24b04f88fe8fdfdd17dcf0d70e" diff --git a/pyproject.toml b/pyproject.toml index 3c832d7..a6fc046 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,10 +18,124 @@ dependencies = [ "aiofiles (>=24.1.0,<25.0.0)", "jinja2 (>=3.1.5,<4.0.0)", "python-multipart (>=0.0.20,<0.0.21)", - "loguru (>=0.7.3,<0.8.0)" + "loguru (>=0.7.3,<0.8.0)", + "pre-commit (>=4.0.1,<5.0.0)" ] [build-system] requires = ["poetry-core>=2.0.0,<3.0.0"] build-backend = "poetry.core.masonry.api" + + +[tool.ruff] +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", + "alembic", +] + + +line-length = 120 +indent-width = 4 + +target-version = "py310" + +[tool.ruff.lint] +select = [ + "A", # flake8-annotations + "B", # flake8-bugbear rules + "F", # pyflakes rules + "N", # name style rules + "I", # isort rules + "UP", # pyupgrade rules + "E101", # mixed-spaces-and-tabs + "E111", # indentation-with-invalid-multiple + "E112", # no-indented-block + "E113", # unexpected-indentation + "E115", # no-indented-block-comment + "E116", # unexpected-indentation-comment + "E117", # over-indented + "RUF019", # unnecessary-key-check + "RUF100", # unused-noqa + "RUF101", # redirected-noqa + "S506", # unsafe-yaml-load + "W191", # tab-indentation + "W605", # invalid-escape-sequence +] +ignore = [ + "B904", # raise-without-from-inside-except + "F811", # redefinition-of-unused +] + +fixable = ["ALL"] +unfixable = [] +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_[a-zA-Z0-9_]*|)$" + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + + +[tool.ruff.lint.flake8-bugbear] +extend-immutable-calls = [ + "fastapi.Depends", + "fastapi.params.Depends", + "fastapi.Query", + "fastapi.params.Query", + "fastapi.Path", + "fastapi.params.Path", + "fastapi.Body", + "fastapi.params.Body", + "fastapi.Form", + "fastapi.params.Form", + "fastapi.Header", + "fastapi.params.Header", + "fastapi.File", + "fastapi.params.File", + "fastapi.Cookie", + "fastapi.params.Cookie", + "fastapi.Security", + "fastapi.params.Security", +] + +[tool.ruff.lint.mccabe] +max-complexity = 5 + +[tool.pylint] + +disable = ["all"] # diable all rule first +enable = ["too-many-statements"] # then enable too-many-statements rule +max-statements = 50 # function max statement \ No newline at end of file diff --git a/util/exception_util.py b/util/exception_util.py index b4ac034..5a7ecdd 100644 --- a/util/exception_util.py +++ b/util/exception_util.py @@ -1,18 +1,19 @@ import logging -from exc.database_exc import DatabaseErr, NotFoundRecordsErr -from exc.http_exc import ( - BadRequestExc, - BaseHttpExc, - NotFoundExc, -) -from exc.service_exc import BadRequestErr, NotFoundErr from fastapi import Request, Response from fastapi.exceptions import RequestValidationError, ResponseValidationError from fastapi.responses import JSONResponse from sqlmodel.ext.asyncio.session import AsyncSession from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY +from exc.database_exc import DatabaseError, NotFoundRecordsError +from exc.http_exc import ( + BadRequestExc, + BaseHttpExc, + NotFoundExc, +) +from exc.service_exc import BadRequestError, NotFoundError + logger = logging.getLogger(__name__) @@ -23,9 +24,9 @@ async def handle_exception(e: Exception, session: AsyncSession | None = None) -> def raise_bad_request_exception(detail: str) -> None: raise BadRequestExc(detail=detail) - if isinstance(e, NotFoundErr | NotFoundRecordsErr): + if isinstance(e, NotFoundError | NotFoundRecordsError): raise NotFoundExc(detail=str(e)) - elif isinstance(e, BadRequestErr) or isinstance(e, DatabaseErr): + elif isinstance(e, BadRequestError) or isinstance(e, DatabaseError): raise_bad_request_exception(str(e)) elif isinstance(e, BaseHttpExc) or isinstance(e, RequestValidationError): logger.error(f"Error: {str(e)}") From 52ed04e2d7cd40ceece905871ab6403d73de87a9 Mon Sep 17 00:00:00 2001 From: VincentAdamNemessis Date: Mon, 17 Feb 2025 14:57:58 +0800 Subject: [PATCH 4/4] fix: some incorrect logic and bug --- core/chat.py | 21 ++++++++++++--------- templates/simple_chatroom.html | 2 +- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/core/chat.py b/core/chat.py index 137b0e9..16f1a1b 100644 --- a/core/chat.py +++ b/core/chat.py @@ -9,7 +9,6 @@ import starlette.datastructures from aioredis import Redis from fastapi import APIRouter, Depends, File, UploadFile, WebSocket, WebSocketDisconnect -from fastapi.exceptions import RequestValidationError from loguru import logger from .deps import get_redis_session @@ -92,17 +91,21 @@ async def websocket_endpoint(websocket: WebSocket, rds_session: Redis = Depends( while True: # 接收客户端发送的消息 origin_message = await websocket.receive_json() - if origin_message.get("type") == "file": if origin_message.get("filename") and origin_message.get("fileSize") < 50 * 1024 * 1024: - # 接收文件二进制数据 file_data = await websocket.receive_bytes() - if file_data: - file_location = await handle_file_upload(file_data, file_name=origin_message.get("filename")) - origin_message["url"] = file_location - logger.info(f"File saved: {file_location}") - logger.warning("File size exceeds 50MB, please use POST to upload.") - raise RequestValidationError("File size exceeds 50MB, please use POST to upload.") + print(encrypt_filename(origin_message.get("filename"))) + if not os.path.exists(f"/media/uploads/{encrypt_filename(origin_message.get('filename'))}"): + # 接收文件二进制数据 + if file_data: + file_location = await handle_file_upload( + file_data, file_name=origin_message.get("filename") + ) + origin_message["url"] = file_location + logger.info(f"File saved: {file_location}") + else: + logger.warning("File size exceeds 50MB, please use POST to upload.") + await websocket.send_json({"error": "File size exceeds 50MB, please use POST to upload."}) # 获取客户端的用户名和 IP 地址 client_username = client_map[websocket]["username"] diff --git a/templates/simple_chatroom.html b/templates/simple_chatroom.html index 445f83f..72f2050 100644 --- a/templates/simple_chatroom.html +++ b/templates/simple_chatroom.html @@ -24,7 +24,7 @@

WebSocket Chat with File Upload