Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue
from memos.mem_scheduler.schemas.general_schemas import DIRECT_EXCHANGE_TYPE, FANOUT_EXCHANGE_TYPE
from memos.mem_scheduler.utils.misc_utils import is_cloud_env


logger = get_logger(__name__)
Expand Down Expand Up @@ -132,6 +131,15 @@ def initialize_rabbitmq(
self.rabbitmq_exchange_type = self.rabbitmq_config.exchange_type
logger.info(f"Using configured exchange type: {self.rabbitmq_exchange_type}")

env_exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME")
env_exchange_type = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_TYPE")
if env_exchange_name:
self.rabbitmq_exchange_name = env_exchange_name
logger.info(f"Using env exchange name override: {self.rabbitmq_exchange_name}")
if env_exchange_type:
self.rabbitmq_exchange_type = env_exchange_type
logger.info(f"Using env exchange type override: {self.rabbitmq_exchange_type}")

# Start connection process
parameters = self.get_rabbitmq_connection_param()
self.rabbitmq_connection = SelectConnection(
Expand Down Expand Up @@ -313,15 +321,16 @@ def rabbitmq_publish_message(self, message: dict):
if label == "knowledgeBaseUpdate":
routing_key = ""

# Cloud environment override: applies to specific message types if MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME is set
# Env override: apply to all message types when MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME is set
env_exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME")
if is_cloud_env() and env_exchange_name and label in ["taskStatus", "knowledgeBaseUpdate"]:
env_routing_key = os.getenv("MEMSCHEDULER_RABBITMQ_ROUTING_KEY")
if env_exchange_name:
exchange_name = env_exchange_name
routing_key = "" # Routing key is always empty in cloud environment for these types

# Specific diagnostic logging for messages affected by cloud environment settings
routing_key = (
env_routing_key if env_routing_key is not None and env_routing_key != "" else ""
)
logger.info(
f"[DIAGNOSTIC] Publishing {label} message in Cloud Env. "
f"[DIAGNOSTIC] Publishing {label} message with env exchange override. "
f"Exchange: {exchange_name}, Routing Key: '{routing_key}'."
)
logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=False)}")
Expand Down