Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ def add_node(
# Flatten info fields to top level (for Neo4j flat structure)
metadata = _flatten_info_fields(metadata)

# Initialize delete_time and delete_record_id fields
metadata.setdefault("delete_time", "")
metadata.setdefault("delete_record_id", "")

# Merge node and set metadata
created_at = metadata.pop("created_at")
updated_at = metadata.pop("updated_at")
Expand Down Expand Up @@ -251,6 +255,7 @@ def add_nodes_batch(
- metadata: dict[str, Any] - Node metadata
user_name: Optional user name (will use config default if not provided)
"""
logger.info("neo4j [add_nodes_batch] staring")
if not nodes:
logger.warning("[add_nodes_batch] Empty nodes list, skipping")
return
Expand Down Expand Up @@ -280,6 +285,10 @@ def add_nodes_batch(
# Flatten info fields to top level (for Neo4j flat structure)
metadata = _flatten_info_fields(metadata)

# Initialize delete_time and delete_record_id fields
metadata.setdefault("delete_time", "")
metadata.setdefault("delete_record_id", "")

# Merge node and set metadata
created_at = metadata.pop("created_at")
updated_at = metadata.pop("updated_at")
Expand Down
9 changes: 9 additions & 0 deletions src/memos/graph_dbs/neo4j_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def add_node(
# Safely process metadata
metadata = _prepare_node_metadata(metadata)

# Initialize delete_time and delete_record_id fields
metadata.setdefault("delete_time", "")
metadata.setdefault("delete_record_id", "")

# serialization
if metadata["sources"]:
for idx in range(len(metadata["sources"])):
Expand Down Expand Up @@ -105,6 +109,7 @@ def add_node(
)

def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = None) -> None:
print("neo4j_community add_nodes_batch:")
if not nodes:
logger.warning("[add_nodes_batch] Empty nodes list, skipping")
return
Expand All @@ -130,6 +135,10 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N
metadata = _prepare_node_metadata(metadata)
metadata = _flatten_info_fields(metadata)

# Initialize delete_time and delete_record_id fields
metadata.setdefault("delete_time", "")
metadata.setdefault("delete_record_id", "")

embedding = metadata.pop("embedding", None)
if embedding is None:
raise ValueError(f"Missing 'embedding' in metadata for node {node_id}")
Expand Down
42 changes: 37 additions & 5 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,11 @@ def create_edge(self):
def add_edge(
self, source_id: str, target_id: str, type: str, user_name: str | None = None
) -> None:
logger.info(
f"polardb [add_edge] source_id: {source_id}, target_id: {target_id}, type: {type},user_name:{user_name}"
)

start_time = time.time()
if not source_id or not target_id:
logger.warning(f"Edge '{source_id}' and '{target_id}' are both None")
raise ValueError("[add_edge] source_id and target_id must be provided")
Expand Down Expand Up @@ -864,13 +869,16 @@ def add_edge(
AND end_id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{target_id}'::text::cstring)
);
"""

logger.info(f"polardb [add_edge] query: {query}, properties: {json.dumps(properties)}")
conn = None
try:
conn = self._get_connection()
with conn.cursor() as cursor:
cursor.execute(query, (source_id, target_id, type, json.dumps(properties)))
logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}")

elapsed_time = time.time() - start_time
logger.info(f" polardb [add_edge] insert completed time in {elapsed_time:.2f}s")
except Exception as e:
logger.error(f"Failed to insert edge: {e}", exc_info=True)
raise
Expand Down Expand Up @@ -1033,7 +1041,10 @@ def get_node(
Returns:
dict: Node properties as key-value pairs, or None if not found.
"""

logger.info(
f"polardb [get_node] id: {id}, include_embedding: {include_embedding}, user_name: {user_name}"
)
start_time = time.time()
select_fields = "id, properties, embedding" if include_embedding else "id, properties"

query = f"""
Expand All @@ -1048,6 +1059,7 @@ def get_node(
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
params.append(self.format_param_value(user_name))

logger.info(f"polardb [get_node] query: {query},params: {params}")
conn = None
try:
conn = self._get_connection()
Expand Down Expand Up @@ -1084,6 +1096,10 @@ def get_node(
except (json.JSONDecodeError, TypeError):
logger.warning(f"Failed to parse embedding for node {id}")

elapsed_time = time.time() - start_time
logger.info(
f" polardb [get_node] get_node completed time in {elapsed_time:.2f}s"
)
return self._parse_node(
{
"id": id,
Expand Down Expand Up @@ -1879,7 +1895,7 @@ def search_by_fulltext(
filter: dict | None = None,
knowledgebase_ids: list[str] | None = None,
tsvector_field: str = "properties_tsvector_zh",
tsquery_config: str = "jiebaqry",
tsquery_config: str = "jiebacfg",
**kwargs,
) -> list[dict]:
"""
Expand All @@ -1902,7 +1918,11 @@ def search_by_fulltext(
Returns:
list[dict]: result list containing id and score
"""
logger.info(
f"[search_by_fulltext] query_words: {query_words},top_k:{top_k},scope:{scope},status:{status},threshold:{threshold},search_filter:{search_filter},user_name:{user_name},knowledgebase_ids:{knowledgebase_ids},filter:{filter}"
)
# Build WHERE clause dynamically, same as search_by_embedding
start_time = time.time()
where_clauses = []

if scope:
Expand All @@ -1924,6 +1944,7 @@ def search_by_fulltext(
knowledgebase_ids=knowledgebase_ids,
default_user_name=self.config.user_name,
)
logger.info(f"[search_by_fulltext] user_name_conditions: {user_name_conditions}")

# Add OR condition if we have any user_name conditions
if user_name_conditions:
Expand All @@ -1946,6 +1967,8 @@ def search_by_fulltext(

# Build filter conditions using common method
filter_conditions = self._build_filter_conditions_sql(filter)
logger.info(f"[search_by_fulltext] filter_conditions: {filter_conditions}")

where_clauses.extend(filter_conditions)
# Add fulltext search condition
# Convert query_text to OR query format: "word1 | word2 | word3"
Expand All @@ -1955,6 +1978,8 @@ def search_by_fulltext(

where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

logger.info(f"[search_by_fulltext] where_clause: {where_clause}")

# Build fulltext search query
query = f"""
SELECT
Expand Down Expand Up @@ -1986,7 +2011,10 @@ def search_by_fulltext(
# Apply threshold filter if specified
if threshold is None or score_val >= threshold:
output.append({"id": id_val, "score": score_val})

elapsed_time = time.time() - start_time
logger.info(
f" polardb [search_by_fulltext] query completed time in {elapsed_time:.2f}s"
)
return output[:top_k]
finally:
self._return_connection(conn)
Expand Down Expand Up @@ -3394,6 +3422,8 @@ def add_node(
"memory": memory,
"created_at": created_at,
"updated_at": updated_at,
"delete_time": "",
"delete_record_id": "",
**metadata,
}

Expand Down Expand Up @@ -3535,6 +3565,8 @@ def add_nodes_batch(
"memory": memory,
"created_at": created_at,
"updated_at": updated_at,
"delete_time": "",
"delete_record_id": "",
**metadata,
}

Expand Down Expand Up @@ -4281,7 +4313,7 @@ def _build_user_name_and_kb_ids_conditions_sql(
user_name_conditions = []
effective_user_name = user_name if user_name else default_user_name

if effective_user_name:
if effective_user_name and default_user_name != "xxx":
user_name_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{effective_user_name}\"'::agtype"
)
Expand Down