diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py
index f49f1d7d1..93dac42fb 100644
--- a/src/memos/llms/openai.py
+++ b/src/memos/llms/openai.py
@@ -57,6 +57,10 @@ def generate(self, messages: MessageList, **kwargs) -> str:
             f"Request body: {request_body}, Response from OpenAI: {response.model_dump_json()}, Cost time: {cost_time}"
         )
 
+        if not response.choices:
+            logger.warning("OpenAI response has no choices")
+            return ""
+
         tool_calls = getattr(response.choices[0].message, "tool_calls", None)
         if isinstance(tool_calls, list) and len(tool_calls) > 0:
             return self.tool_call_parser(tool_calls)
@@ -99,6 +103,8 @@ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
         reasoning_started = False
 
         for chunk in response:
+            if not chunk.choices:
+                continue
             delta = chunk.choices[0].delta
 
             # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen, DeepSeek)
@@ -153,6 +159,10 @@ def generate(self, messages: MessageList, **kwargs) -> str:
             extra_body=kwargs.get("extra_body", self.config.extra_body),
         )
         logger.info(f"Response from Azure OpenAI: {response.model_dump_json()}")
+        if not response.choices:
+            logger.warning("Azure OpenAI response has no choices")
+            return ""
+
         if response.choices[0].message.tool_calls:
             return self.tool_call_parser(response.choices[0].message.tool_calls)
         response_content = response.choices[0].message.content
@@ -180,6 +190,8 @@ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
         reasoning_started = False
 
         for chunk in response:
+            if not chunk.choices:
+                continue
             delta = chunk.choices[0].delta
 
             # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen, DeepSeek)
diff --git a/src/memos/llms/vllm.py b/src/memos/llms/vllm.py
index 1cf8d4f39..362112f11 100644
--- a/src/memos/llms/vllm.py
+++ b/src/memos/llms/vllm.py
@@ -125,6 +125,10 @@ def _generate_with_api_client(self, messages: list[MessageDict], **kwargs) -> str:
 
         response = self.client.chat.completions.create(**completion_kwargs)
 
+        if not response.choices:
+            logger.warning("VLLM response has no choices")
+            return ""
+
         if response.choices[0].message.tool_calls:
             return self.tool_call_parser(response.choices[0].message.tool_calls)
 
@@ -184,6 +188,8 @@ def generate_stream(self, messages: list[MessageDict], **kwargs):
         reasoning_started = False
 
         for chunk in stream:
+            if not chunk.choices:
+                continue
             delta = chunk.choices[0].delta
             if hasattr(delta, "reasoning") and delta.reasoning:
                 if not reasoning_started and not self.config.remove_think_prefix:
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index 3612d37eb..8c30d74f3 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -290,51 +290,6 @@ def _parse_task(
 
         return parsed_goal, query_embedding, context, query
 
-    @timed
-    def _retrieve_simple(
-        self,
-        query: str,
-        top_k: int,
-        search_filter: dict | None = None,
-        user_name: str | None = None,
-        **kwargs,
-    ):
-        """Retrieve from by keywords and embedding"""
-        query_words = []
-        if self.tokenizer:
-            query_words = self.tokenizer.tokenize_mixed(query)
-        else:
-            query_words = query.strip().split()
-        query_words = [query, *query_words]
-        logger.info(f"[SIMPLESEARCH] Query words: {query_words}")
-        query_embeddings = self.embedder.embed(query_words)
-
-        items = self.graph_retriever.retrieve_from_mixed(
-            top_k=top_k * 2,
-            memory_scope=None,
-            query_embedding=query_embeddings,
-            search_filter=search_filter,
-            user_name=user_name,
-            use_fast_graph=self.use_fast_graph,
-        )
-        logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
-        documents = [getattr(item, "memory", "") for item in items]
-        if not documents:
-            return []
-        documents_embeddings = self.embedder.embed(documents)
-        similarity_matrix = cosine_similarity_matrix(documents_embeddings)
-        selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
-        selected_items = [items[i] for i in selected_indices]
-        logger.info(
-            f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
-        )
-        return self.reranker.rerank(
-            query=query,
-            query_embedding=query_embeddings[0],
-            graph_results=selected_items,
-            top_k=top_k,
-        )
-
     @timed
     def _retrieve_paths(
         self,
@@ -722,6 +677,9 @@ def _retrieve_simple(
         if not documents:
             return []
         documents_embeddings = self.embedder.embed(documents)
+        if not documents_embeddings:
+            logger.info("[SIMPLESEARCH] Documents embeddings is empty")
+            return []
         similarity_matrix = cosine_similarity_matrix(documents_embeddings)
         selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
         selected_items = [items[i] for i in selected_indices]
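
Reviewer note: the `choices` guards added across `openai.py` and `vllm.py` all share one defensive shape: check `response.choices` before indexing `[0]`, returning `""` on the non-stream path and skipping the chunk on the stream path. Below is a minimal, self-contained sketch of that pattern; `FakeChoice`/`FakeResponse` are hypothetical stand-ins for the OpenAI SDK objects, and the empty-`choices` triggers named in the comments (content filtering, proxy errors, keep-alive chunks) are assumptions about OpenAI-compatible servers, not documented SDK behavior.

```python
from dataclasses import dataclass, field


@dataclass
class FakeChoice:
    content: str


@dataclass
class FakeResponse:
    # OpenAI-compatible servers can reply with `choices: []`, e.g. under
    # content filtering or proxy-side errors (assumed trigger conditions).
    choices: list[FakeChoice] = field(default_factory=list)


def generate(response: FakeResponse) -> str:
    # Non-stream guard: check before indexing, mirroring the diff's `return ""`.
    if not response.choices:
        return ""
    return response.choices[0].content


def generate_stream(chunks: list[FakeResponse]):
    for chunk in chunks:
        # Stream guard: some providers interleave usage-only or keep-alive
        # chunks with empty `choices`; skip them instead of raising IndexError.
        if not chunk.choices:
            continue
        yield chunk.choices[0].content


assert generate(FakeResponse()) == ""  # no IndexError on an empty reply
assert list(generate_stream([FakeResponse(), FakeResponse([FakeChoice("hi")])])) == ["hi"]
```

One asymmetry worth noting: the non-stream paths log a warning before returning, while the stream paths skip silently, so a stream consisting entirely of empty chunks now yields nothing without leaving a log line.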
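The `searcher.py` change applies the same early-return idea one layer up: if `self.embedder.embed(documents)` comes back empty, `_retrieve_simple` now bails out before building the similarity matrix. Here is a sketch of why the guard matters, under the assumption that the embedder can legitimately return `[]`; this `cosine_similarity_matrix` is a naive illustrative stand-in, not the memos implementation.

```python
import math


def cosine_similarity_matrix(vectors: list[list[float]]) -> list[list[float]]:
    # Naive stand-in: pairwise cosine similarity over all vectors.
    def cos(a: list[float], b: list[float]) -> float:
        dot = sum(x * y for x, y in zip(a, b))
        na = math.sqrt(sum(x * x for x in a))
        nb = math.sqrt(sum(y * y for y in b))
        return dot / (na * nb) if na and nb else 0.0

    return [[cos(a, b) for b in vectors] for a in vectors]


def similarity_for(documents: list[str], embed) -> list[list[float]]:
    if not documents:
        return []
    embeddings = embed(documents)
    # The added guard: return early rather than hand an empty 0x0 matrix
    # to the downstream subgroup selection.
    if not embeddings:
        return []
    return cosine_similarity_matrix(embeddings)


# An embedder that yields nothing exercises the new branch.
assert similarity_for(["doc a", "doc b"], lambda docs: []) == []
```

The diff also removes a duplicate `_retrieve_simple` definition earlier in the file; since Python binds the last definition in a class body, the removed copy was already dead code, and the surviving definition is the one that gains the new guard.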