Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/memos/api/handlers/chat_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
query=chat_req.query,
full_response=response,
async_mode="async",
manager_user_id=chat_req.manager_user_id,
project_id=chat_req.project_id,
)
end = time.time()
self.logger.info(f"[Cloud Service] Chat Add Time: {end - start} seconds")
Expand Down Expand Up @@ -382,6 +384,8 @@ def generate_chat_response() -> Generator[str, None, None]:
query=chat_req.query,
full_response=full_response,
async_mode="async",
manager_user_id=chat_req.manager_user_id,
project_id=chat_req.project_id,
)
end = time.time()
self.logger.info(
Expand Down Expand Up @@ -563,6 +567,8 @@ def generate_chat_response() -> Generator[str, None, None]:
query=chat_req.query,
full_response=None,
async_mode="sync",
manager_user_id=chat_req.manager_user_id,
project_id=chat_req.project_id,
)

# Extract memories from search results (second search)
Expand Down Expand Up @@ -731,6 +737,8 @@ def generate_chat_response() -> Generator[str, None, None]:
query=chat_req.query,
full_response=full_response,
async_mode="sync",
manager_user_id=chat_req.manager_user_id,
project_id=chat_req.project_id,
)

except Exception as e:
Expand Down Expand Up @@ -917,6 +925,8 @@ def generate_chat_response() -> Generator[str, None, None]:
query=chat_req.query,
full_response=full_response,
async_mode="async",
manager_user_id=chat_req.manager_user_id,
project_id=chat_req.project_id,
)
end = time.time()
self.logger.info(
Expand Down Expand Up @@ -1309,6 +1319,8 @@ async def _add_conversation_to_memory(
writable_cube_ids: list[str],
session_id: str,
query: str,
manager_user_id: str | None = None,
project_id: str | None = None,
clean_response: str | None = None,
async_mode: Literal["async", "sync"] = "sync",
) -> None:
Expand All @@ -1333,6 +1345,8 @@ async def _add_conversation_to_memory(
session_id=session_id,
messages=messages,
async_mode=async_mode,
manager_user_id=manager_user_id,
project_id=project_id,
)

self.add_handler.handle_add_memories(add_req)
Expand Down Expand Up @@ -1540,6 +1554,8 @@ def _start_add_to_memory(
query: str,
full_response: str | None = None,
async_mode: Literal["async", "sync"] = "sync",
manager_user_id: str | None = None,
project_id: str | None = None,
) -> None:
def run_async_in_thread():
try:
Expand All @@ -1557,6 +1573,8 @@ def run_async_in_thread():
query=query,
clean_response=clean_response,
async_mode=async_mode,
manager_user_id=manager_user_id,
project_id=project_id,
)
)
finally:
Expand All @@ -1580,6 +1598,8 @@ def run_async_in_thread():
query=query,
clean_response=clean_response,
async_mode=async_mode,
manager_user_id=manager_user_id,
project_id=project_id,
)
)
task.add_done_callback(
Expand Down
93 changes: 75 additions & 18 deletions src/memos/api/handlers/memory_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,13 +421,27 @@ def handle_get_memories_dashboard(
filter=get_mem_req.filter,
memory_type=text_memory_type,
)
text_memories, total_text_nodes = text_memories_info["nodes"], text_memories_info["total_nodes"]
text_memories, _ = text_memories_info["nodes"], text_memories_info["total_nodes"]

# Group text memories by cube_id from metadata.user_name
text_mem_by_cube: dict[str, list] = {}
for memory in text_memories:
cube_id = memory.get("metadata", {}).get("user_name", get_mem_req.mem_cube_id)
if cube_id not in text_mem_by_cube:
text_mem_by_cube[cube_id] = []
text_mem_by_cube[cube_id].append(memory)

# If no memories found, create a default entry with the requested cube_id
if not text_mem_by_cube and get_mem_req.mem_cube_id:
text_mem_by_cube[get_mem_req.mem_cube_id] = []

results["text_mem"] = [
{
"cube_id": get_mem_req.mem_cube_id,
"memories": text_memories,
"total_nodes": total_text_nodes,
"cube_id": cube_id,
"memories": memories,
"total_nodes": len(memories),
}
for cube_id, memories in text_mem_by_cube.items()
]

if get_mem_req.include_tool_memory:
Expand All @@ -439,18 +453,32 @@ def handle_get_memories_dashboard(
filter=get_mem_req.filter,
memory_type=["ToolSchemaMemory", "ToolTrajectoryMemory"],
)
tool_memories, total_tool_nodes = (
tool_memories, _ = (
tool_memories_info["nodes"],
tool_memories_info["total_nodes"],
)

# Group tool memories by cube_id from metadata.user_name
tool_mem_by_cube: dict[str, list] = {}
for memory in tool_memories:
cube_id = memory.get("metadata", {}).get("user_name", get_mem_req.mem_cube_id)
if cube_id not in tool_mem_by_cube:
tool_mem_by_cube[cube_id] = []
tool_mem_by_cube[cube_id].append(memory)

# If no memories found, create a default entry with the requested cube_id
if not tool_mem_by_cube and get_mem_req.mem_cube_id:
tool_mem_by_cube[get_mem_req.mem_cube_id] = []

results["tool_mem"] = [
{
"cube_id": get_mem_req.mem_cube_id,
"memories": tool_memories,
"total_nodes": total_tool_nodes,
"cube_id": cube_id,
"memories": memories,
"total_nodes": len(memories),
}
for cube_id, memories in tool_mem_by_cube.items()
]

if get_mem_req.include_skill_memory:
skill_memories_info = naive_mem_cube.text_mem.get_all(
user_name=get_mem_req.mem_cube_id,
Expand All @@ -460,18 +488,32 @@ def handle_get_memories_dashboard(
filter=get_mem_req.filter,
memory_type=["SkillMemory"],
)
skill_memories, total_skill_nodes = (
skill_memories, _ = (
skill_memories_info["nodes"],
skill_memories_info["total_nodes"],
)

# Group skill memories by cube_id from metadata.user_name
skill_mem_by_cube: dict[str, list] = {}
for memory in skill_memories:
cube_id = memory.get("metadata", {}).get("user_name", get_mem_req.mem_cube_id)
if cube_id not in skill_mem_by_cube:
skill_mem_by_cube[cube_id] = []
skill_mem_by_cube[cube_id].append(memory)

# If no memories found, create a default entry with the requested cube_id
if not skill_mem_by_cube and get_mem_req.mem_cube_id:
skill_mem_by_cube[get_mem_req.mem_cube_id] = []

results["skill_mem"] = [
{
"cube_id": get_mem_req.mem_cube_id,
"memories": skill_memories,
"total_nodes": total_skill_nodes,
"cube_id": cube_id,
"memories": memories,
"total_nodes": len(memories),
}
for cube_id, memories in skill_mem_by_cube.items()
]

preferences: list[TextualMemoryItem] = []
total_preference_nodes = 0

Expand Down Expand Up @@ -507,13 +549,28 @@ def handle_get_memories_dashboard(
)
format_preferences = [format_memory_item(item, save_sources=False) for item in preferences]

results = post_process_pref_mem(
results, format_preferences, get_mem_req.mem_cube_id, get_mem_req.include_preference
)
if total_preference_nodes > 0 and results.get("pref_mem", []):
results["pref_mem"][0]["total_nodes"] = total_preference_nodes
# Group preferences by cube_id from metadata.mem_cube_id
pref_mem_by_cube: dict[str, list] = {}
for pref in format_preferences:
cube_id = pref.get("metadata", {}).get("mem_cube_id", get_mem_req.mem_cube_id)
if cube_id not in pref_mem_by_cube:
pref_mem_by_cube[cube_id] = []
pref_mem_by_cube[cube_id].append(pref)

# Filter to only keep text_mem, pref_mem, tool_mem
# If no preferences found, create a default entry with the requested cube_id
if not pref_mem_by_cube and get_mem_req.mem_cube_id:
pref_mem_by_cube[get_mem_req.mem_cube_id] = []

results["pref_mem"] = [
{
"cube_id": cube_id,
"memories": memories,
"total_nodes": len(memories),
}
for cube_id, memories in pref_mem_by_cube.items()
]

# Filter to only keep text_mem, pref_mem, tool_mem, skill_mem
filtered_results = {
"text_mem": results.get("text_mem", []),
"pref_mem": results.get("pref_mem", []),
Expand Down
4 changes: 4 additions & 0 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class ChatRequest(BaseRequest):
temperature: float | None = Field(None, description="Temperature for sampling")
top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter")
add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat")
manager_user_id: str | None = Field(None, description="Manager User ID")
project_id: str | None = Field(None, description="Project ID")

# ==== Filter conditions ====
filter: dict[str, Any] | None = Field(
Expand Down Expand Up @@ -771,6 +773,8 @@ class APIChatCompleteRequest(BaseRequest):
temperature: float | None = Field(None, description="Temperature for sampling")
top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter")
add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat")
manager_user_id: str | None = Field(None, description="Manager User ID")
project_id: str | None = Field(None, description="Project ID")

# ==== Filter conditions ====
filter: dict[str, Any] | None = Field(
Expand Down
2 changes: 1 addition & 1 deletion src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5650,7 +5650,7 @@ def recover_memory_by_mem_cube_id(
SET properties = (
properties::jsonb || %s::jsonb
)::text::agtype,
deletetime = NULL
deletetime = NULL
WHERE {where_clause}
"""

Expand Down