Skip to content
129 changes: 112 additions & 17 deletions src/memos/api/handlers/memory_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from memos.api.handlers.formatters_handler import (
format_memory_item,
post_process_pref_mem,
post_process_textual_mem,
)
from memos.api.product_models import (
DeleteMemoryRequest,
Expand Down Expand Up @@ -251,22 +250,68 @@ def handle_get_memories(
get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube
) -> GetMemoryResponse:
results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": [], "skill_mem": []}
memories = naive_mem_cube.text_mem.get_all(
text_memory_type = ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"]
text_memories_info = naive_mem_cube.text_mem.get_all(
user_name=get_mem_req.mem_cube_id,
user_id=get_mem_req.user_id,
page=get_mem_req.page,
page_size=get_mem_req.page_size,
filter=get_mem_req.filter,
)["nodes"]
memory_type=text_memory_type,
)
text_memories, total_text_nodes = text_memories_info["nodes"], text_memories_info["total_nodes"]
results["text_mem"] = [
{
"cube_id": get_mem_req.mem_cube_id,
"memories": text_memories,
"total_nodes": total_text_nodes,
}
]

results = post_process_textual_mem(results, memories, get_mem_req.mem_cube_id)
if get_mem_req.include_tool_memory:
tool_memories_info = naive_mem_cube.text_mem.get_all(
user_name=get_mem_req.mem_cube_id,
user_id=get_mem_req.user_id,
page=get_mem_req.page,
page_size=get_mem_req.page_size,
filter=get_mem_req.filter,
memory_type=["ToolSchemaMemory", "ToolTrajectoryMemory"],
)
tool_memories, total_tool_nodes = (
tool_memories_info["nodes"],
tool_memories_info["total_nodes"],
)

if not get_mem_req.include_tool_memory:
results["tool_mem"] = []
if not get_mem_req.include_skill_memory:
results["skill_mem"] = []
results["tool_mem"] = [
{
"cube_id": get_mem_req.mem_cube_id,
"memories": tool_memories,
"total_nodes": total_tool_nodes,
}
]
if get_mem_req.include_skill_memory:
skill_memories_info = naive_mem_cube.text_mem.get_all(
user_name=get_mem_req.mem_cube_id,
user_id=get_mem_req.user_id,
page=get_mem_req.page,
page_size=get_mem_req.page_size,
filter=get_mem_req.filter,
memory_type=["SkillMemory"],
)
skill_memories, total_skill_nodes = (
skill_memories_info["nodes"],
skill_memories_info["total_nodes"],
)

results["skill_mem"] = [
{
"cube_id": get_mem_req.mem_cube_id,
"memories": skill_memories,
"total_nodes": total_skill_nodes,
}
]
preferences: list[TextualMemoryItem] = []
total_preference_nodes = 0

format_preferences = []
if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None:
Expand Down Expand Up @@ -295,14 +340,16 @@ def handle_get_memories(

filter_params.update(filter_copy)

preferences, _ = naive_mem_cube.pref_mem.get_memory_by_filter(
preferences, total_preference_nodes = naive_mem_cube.pref_mem.get_memory_by_filter(
filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size
)
format_preferences = [format_memory_item(item, save_sources=False) for item in preferences]

results = post_process_pref_mem(
results, format_preferences, get_mem_req.mem_cube_id, get_mem_req.include_preference
)
if total_preference_nodes > 0 and results.get("pref_mem", []):
results["pref_mem"][0]["total_nodes"] = total_preference_nodes

# Filter to only keep text_mem, pref_mem, tool_mem
filtered_results = {
Expand Down Expand Up @@ -365,22 +412,68 @@ def handle_get_memories_dashboard(
get_mem_req: GetMemoryDashboardRequest, naive_mem_cube: NaiveMemCube
) -> GetMemoryResponse:
results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": [], "skill_mem": []}
memories = naive_mem_cube.text_mem.get_all(
text_memory_type = ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"]
text_memories_info = naive_mem_cube.text_mem.get_all(
user_name=get_mem_req.mem_cube_id,
user_id=get_mem_req.user_id,
page=get_mem_req.page,
page_size=get_mem_req.page_size,
filter=get_mem_req.filter,
)["nodes"]
memory_type=text_memory_type,
)
text_memories, total_text_nodes = text_memories_info["nodes"], text_memories_info["total_nodes"]
results["text_mem"] = [
{
"cube_id": get_mem_req.mem_cube_id,
"memories": text_memories,
"total_nodes": total_text_nodes,
}
]

results = post_process_textual_mem(results, memories, get_mem_req.mem_cube_id)
if get_mem_req.include_tool_memory:
tool_memories_info = naive_mem_cube.text_mem.get_all(
user_name=get_mem_req.mem_cube_id,
user_id=get_mem_req.user_id,
page=get_mem_req.page,
page_size=get_mem_req.page_size,
filter=get_mem_req.filter,
memory_type=["ToolSchemaMemory", "ToolTrajectoryMemory"],
)
tool_memories, total_tool_nodes = (
tool_memories_info["nodes"],
tool_memories_info["total_nodes"],
)

if not get_mem_req.include_tool_memory:
results["tool_mem"] = []
if not get_mem_req.include_skill_memory:
results["skill_mem"] = []
results["tool_mem"] = [
{
"cube_id": get_mem_req.mem_cube_id,
"memories": tool_memories,
"total_nodes": total_tool_nodes,
}
]
if get_mem_req.include_skill_memory:
skill_memories_info = naive_mem_cube.text_mem.get_all(
user_name=get_mem_req.mem_cube_id,
user_id=get_mem_req.user_id,
page=get_mem_req.page,
page_size=get_mem_req.page_size,
filter=get_mem_req.filter,
memory_type=["SkillMemory"],
)
skill_memories, total_skill_nodes = (
skill_memories_info["nodes"],
skill_memories_info["total_nodes"],
)

results["skill_mem"] = [
{
"cube_id": get_mem_req.mem_cube_id,
"memories": skill_memories,
"total_nodes": total_skill_nodes,
}
]
preferences: list[TextualMemoryItem] = []
total_preference_nodes = 0

format_preferences = []
if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None:
Expand Down Expand Up @@ -409,14 +502,16 @@ def handle_get_memories_dashboard(

filter_params.update(filter_copy)

preferences, _ = naive_mem_cube.pref_mem.get_memory_by_filter(
preferences, total_preference_nodes = naive_mem_cube.pref_mem.get_memory_by_filter(
filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size
)
format_preferences = [format_memory_item(item, save_sources=False) for item in preferences]

results = post_process_pref_mem(
results, format_preferences, get_mem_req.mem_cube_id, get_mem_req.include_preference
)
if total_preference_nodes > 0 and results.get("pref_mem", []):
results["pref_mem"][0]["total_nodes"] = total_preference_nodes

# Filter to only keep text_mem, pref_mem, tool_mem
filtered_results = {
Expand Down
8 changes: 7 additions & 1 deletion src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,13 +368,19 @@ def get_all(
page: int | None = None,
page_size: int | None = None,
filter: dict | None = None,
memory_type: list[str] | None = None,
) -> dict:
"""Get all memories.
Returns:
list[TextualMemoryItem]: List of all memories.
"""
graph_output = self.graph_store.export_graph(
user_name=user_name, user_id=user_id, page=page, page_size=page_size, filter=filter
user_name=user_name,
user_id=user_id,
page=page,
page_size=page_size,
filter=filter,
memory_type=memory_type,
)
return graph_output

Expand Down