diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 3c1ad959b..adbd04e3c 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -821,7 +821,12 @@ def get_product_default_config() -> dict[str, Any]: "oss_config": APIConfig.get_oss_config(), "skills_dir_config": { "skills_oss_dir": os.getenv("SKILLS_OSS_DIR", "skill_memory/"), - "skills_local_dir": os.getenv("SKILLS_LOCAL_DIR", "/tmp/skill_memory/"), + "skills_local_tmp_dir": os.getenv( + "SKILLS_LOCAL_TMP_DIR", "/tmp/skill_memory/" + ), + "skills_local_dir": os.getenv( + "SKILLS_LOCAL_DIR", "/tmp/upload_skill_memory/" + ), }, }, }, diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py index ac9ed8d88..529a709a4 100644 --- a/src/memos/api/server_api.py +++ b/src/memos/api/server_api.py @@ -1,13 +1,18 @@ import logging +import os +from dotenv import load_dotenv from fastapi import FastAPI, HTTPException from fastapi.exceptions import RequestValidationError +from starlette.staticfiles import StaticFiles from memos.api.exceptions import APIExceptionHandler from memos.api.middleware.request_context import RequestContextMiddleware from memos.api.routers.server_router import router as server_router +load_dotenv() + # Configure logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -18,6 +23,8 @@ version="1.0.1", ) +app.mount("/download", StaticFiles(directory=os.getenv("FILE_LOCAL_PATH")), name="static_mapping") + app.add_middleware(RequestContextMiddleware, source="server_api") # Include routers app.include_router(server_router) diff --git a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py index 3ac45c99a..791f0ca4b 100644 --- a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py +++ b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py @@ -3,6 +3,7 @@ import os import shutil import uuid +import warnings import zipfile from concurrent.futures import as_completed @@ -10,6 +11,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from dotenv import load_dotenv + from memos.context.context import ContextThreadPoolExecutor from memos.dependency import require_python_package from memos.embedders.base import BaseEmbedder @@ -36,6 +39,8 @@ from memos.types import MessageList +load_dotenv() + if TYPE_CHECKING: from memos.types.general_types import UserContext @@ -653,42 +658,88 @@ def _rewrite_query(task_type: str, messages: MessageList, llm: BaseLLM, rewrite_ import_name="alibabacloud_oss_v2", install_command="pip install alibabacloud-oss-v2", ) -def _upload_skills_to_oss(local_file_path: str, oss_file_path: str, client: Any) -> str: - import alibabacloud_oss_v2 as oss - - result = client.put_object_from_file( - request=oss.PutObjectRequest( - bucket=os.getenv("OSS_BUCKET_NAME"), - key=oss_file_path, - ), - filepath=local_file_path, - ) +def _upload_skills( + skills_repo_backend: str, + skills_oss_dir: dict[str, Any] | None, + local_tmp_file_path: str, + local_save_file_path: str, + client: Any, + user_id: str, +) -> str: + if skills_repo_backend == "OSS": + zip_filename = Path(local_tmp_file_path).name + oss_path = (Path(skills_oss_dir) / user_id / zip_filename).as_posix() + + import alibabacloud_oss_v2 as oss + + result = client.put_object_from_file( + request=oss.PutObjectRequest( + bucket=os.getenv("OSS_BUCKET_NAME"), + key=oss_path, + ), + filepath=local_tmp_file_path, + ) - if result.status_code != 200: - logger.warning("[PROCESS_SKILLS] Failed to upload skill to OSS") - return "" + if result.status_code != 200: + logger.warning("[PROCESS_SKILLS] Failed to upload skill to OSS") + return "" + + # Construct and return the URL + bucket_name = os.getenv("OSS_BUCKET_NAME") + endpoint = os.getenv("OSS_ENDPOINT").replace("https://", "").replace("http://", "") + url = f"https://{bucket_name}.{endpoint}/{oss_path}" + return url + else: + import sys + + args = sys.argv + port = ( + int(args[args.index("--port") + 1]) + if "--port" in args and args.index("--port") + 1 < len(args) + else "8000" + ) - # Construct and return the URL - bucket_name = os.getenv("OSS_BUCKET_NAME") - endpoint = os.getenv("OSS_ENDPOINT").replace("https://", "").replace("http://", "") - url = f"https://{bucket_name}.{endpoint}/{oss_file_path}" - return url + zip_path = str(local_tmp_file_path) + os.makedirs(local_save_file_path, exist_ok=True) + file_name = os.path.basename(zip_path) + target_full_path = os.path.join(local_save_file_path, file_name) + shutil.copy2(zip_path, target_full_path) + return f"http://localhost:{port}/download/{file_name}" @require_python_package( import_name="alibabacloud_oss_v2", install_command="pip install alibabacloud-oss-v2", ) -def _delete_skills_from_oss(oss_file_path: str, client: Any) -> Any: - import alibabacloud_oss_v2 as oss - - result = client.delete_object( - oss.DeleteObjectRequest( - bucket=os.getenv("OSS_BUCKET_NAME"), - key=oss_file_path, +def _delete_skills( + skills_repo_backend: str, + zip_filename: str, + client: Any, + skills_oss_dir: dict[str, Any] | None, + local_save_file_path: str, + user_id: str, +) -> Any: + if skills_repo_backend == "OSS": + old_path = (Path(skills_oss_dir) / user_id / zip_filename).as_posix() + import alibabacloud_oss_v2 as oss + + return client.delete_object( + oss.DeleteObjectRequest( + bucket=os.getenv("OSS_BUCKET_NAME"), + key=old_path, + ) ) - ) - return result + else: + target_full_path = os.path.join(local_save_file_path, zip_filename) + target_path = Path(target_full_path) + try: + if target_path.is_file(): + target_path.unlink() + logger.info(f"本地文件 {target_path} 已成功删除") + else: + print(f"本地文件 {target_path} 不存在,无需删除") + except Exception as e: + print(f"删除本地文件时出错:{e}") def _write_skills_to_file( @@ -698,7 +749,7 @@ def _write_skills_to_file( skill_name = skill_memory.get("name", "unnamed_skill").replace(" ", "_").lower() # Create tmp directory for user if it doesn't exist - tmp_dir = Path(skills_dir_config["skills_local_dir"]) / user_id + tmp_dir = Path(skills_dir_config["skills_local_tmp_dir"]) / user_id tmp_dir.mkdir(parents=True, exist_ok=True) # Create skill directory directly in tmp_dir @@ -889,6 +940,54 @@ def create_skill_memory_item( return TextualMemoryItem(id=item_id, memory=memory_content, metadata=metadata) +def _skill_init(skills_repo_backend, oss_config, skills_dir_config): + if skills_repo_backend == "OSS": + # Validate required configurations + if not oss_config: + logger.warning( + "[PROCESS_SKILLS] OSS configuration is required for skill memory processing" + ) + return None, None, False + + if not skills_dir_config: + logger.warning( + "[PROCESS_SKILLS] Skills directory configuration is required for skill memory processing" + ) + return None, None, False + + # Validate skills_dir has required keys + required_keys = ["skills_local_tmp_dir", "skills_local_dir", "skills_oss_dir"] + missing_keys = [key for key in required_keys if key not in skills_dir_config] + if missing_keys: + logger.warning( + f"[PROCESS_SKILLS] Skills directory configuration missing required keys: {', '.join(missing_keys)}" + ) + return None, None, False + + oss_client = create_oss_client(oss_config) + if not oss_client: + logger.warning("[PROCESS_SKILLS] Failed to create OSS client") + return None, None, False + return oss_client, missing_keys, True + else: + return None, None, True + + +def _get_skill_file_storage_location() -> str: + # SKILLS_REPO_BACKEND: Skill 文件保存地址 OSS/LOCAL + allowed_backends = {"OSS", "LOCAL"} + raw_backend = os.getenv("SKILLS_REPO_BACKEND") + if raw_backend in allowed_backends: + return raw_backend + else: + warnings.warn( + "环境变量【SKILLS_REPO_BACKEND】赋值错误,本次使用 LOCAL 存储 skill", + UserWarning, + stacklevel=1, + ) + return "LOCAL" + + def process_skill_memory_fine( fast_memory_items: list[TextualMemoryItem], info: dict[str, Any], @@ -902,15 +1001,9 @@ def process_skill_memory_fine( complete_skill_memory: bool = True, **kwargs, ) -> list[TextualMemoryItem]: - # Validate required configurations - if not oss_config: - logger.warning("[PROCESS_SKILLS] OSS configuration is required for skill memory processing") - return [] - - if not skills_dir_config: - logger.warning( - "[PROCESS_SKILLS] Skills directory configuration is required for skill memory processing" - ) + skills_repo_backend = _get_skill_file_storage_location() + oss_client, missing_keys, flag = _skill_init(skills_repo_backend, oss_config, skills_dir_config) + if not flag: return [] chat_history = kwargs.get("chat_history") @@ -918,20 +1011,6 @@ def process_skill_memory_fine( chat_history = [] logger.warning("[PROCESS_SKILLS] History is None in Skills") - # Validate skills_dir has required keys - required_keys = ["skills_local_dir", "skills_oss_dir"] - missing_keys = [key for key in required_keys if key not in skills_dir_config] - if missing_keys: - logger.warning( - f"[PROCESS_SKILLS] Skills directory configuration missing required keys: {', '.join(missing_keys)}" - ) - return [] - - oss_client = create_oss_client(oss_config) - if not oss_client: - logger.warning("[PROCESS_SKILLS] Failed to create OSS client") - return [] - messages = _reconstruct_messages_from_memory_items(fast_memory_items) chat_history, messages = _preprocess_extract_messages(chat_history, messages) @@ -1060,23 +1139,27 @@ def _full_extract(): old_memory = old_memories_map.get(old_memory_id) if old_memory: - # Get old OSS path from the old memory's metadata - old_oss_path = getattr(old_memory.metadata, "url", None) + # Get old path from the old memory's metadata + old_path = getattr(old_memory.metadata, "url", None) - if old_oss_path: + if old_path: try: # delete old skill from OSS - zip_filename = Path(old_oss_path).name - old_oss_path = ( - Path(skills_dir_config["skills_oss_dir"]) / user_id / zip_filename - ).as_posix() - _delete_skills_from_oss(old_oss_path, oss_client) + zip_filename = Path(old_path).name + _delete_skills( + skills_repo_backend=skills_repo_backend, + zip_filename=zip_filename, + client=oss_client, + skills_oss_dir=skills_dir_config["skills_oss_dir"], + local_save_file_path=skills_dir_config["skills_local_dir"], + user_id=user_id, + ) logger.info( - f"[PROCESS_SKILLS] Deleted old skill from OSS: {old_oss_path}" + f"[PROCESS_SKILLS] Deleted old skill from {skills_repo_backend}: {old_path}" ) except Exception as e: logger.warning( - f"[PROCESS_SKILLS] Failed to delete old skill from OSS: {e}" + f"[PROCESS_SKILLS] Failed to delete old skill from {skills_repo_backend}: {e}" ) # delete old skill from graph db @@ -1086,24 +1169,23 @@ def _full_extract(): f"[PROCESS_SKILLS] Deleted old skill from graph db: {old_memory_id}" ) - # Upload new skill to OSS + # Upload new skill # Use the same filename as the local zip file - zip_filename = Path(zip_path).name - oss_path = ( - Path(skills_dir_config["skills_oss_dir"]) / user_id / zip_filename - ).as_posix() - - # _upload_skills_to_oss returns the URL - url = _upload_skills_to_oss( - local_file_path=str(zip_path), oss_file_path=oss_path, client=oss_client + url = _upload_skills( + skills_repo_backend=skills_repo_backend, + skills_oss_dir=skills_dir_config["skills_oss_dir"], + local_tmp_file_path=zip_path, + local_save_file_path=skills_dir_config["skills_local_dir"], + client=oss_client, + user_id=user_id, ) # Set URL directly to skill_memory skill_memory["url"] = url - logger.info(f"[PROCESS_SKILLS] Uploaded skill to OSS: {url}") + logger.info(f"[PROCESS_SKILLS] Uploaded skill to {skills_repo_backend}: {url}") except Exception as e: - logger.warning(f"[PROCESS_SKILLS] Error uploading skill to OSS: {e}") + logger.warning(f"[PROCESS_SKILLS] Error uploading skill to {skills_repo_backend}: {e}") skill_memory["url"] = "" # Set to empty string if upload fails finally: # Clean up local files after upload