12 changes: 12 additions & 0 deletions src/memos/llms/openai.py
@@ -57,6 +57,10 @@ def generate(self, messages: MessageList, **kwargs) -> str:
             f"Request body: {request_body}, Response from OpenAI: {response.model_dump_json()}, Cost time: {cost_time}"
         )

+        if not response.choices:
+            logger.warning("OpenAI response has no choices")
+            return ""
+
         tool_calls = getattr(response.choices[0].message, "tool_calls", None)
         if isinstance(tool_calls, list) and len(tool_calls) > 0:
             return self.tool_call_parser(tool_calls)
@@ -99,6 +103,8 @@ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
         reasoning_started = False

         for chunk in response:
+            if not chunk.choices:
+                continue
             delta = chunk.choices[0].delta

             # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen, DeepSeek)
@@ -153,6 +159,10 @@ def generate(self, messages: MessageList, **kwargs) -> str:
             extra_body=kwargs.get("extra_body", self.config.extra_body),
         )
         logger.info(f"Response from Azure OpenAI: {response.model_dump_json()}")
+        if not response.choices:
+            logger.warning("Azure OpenAI response has no choices")
+            return ""
+
         if response.choices[0].message.tool_calls:
             return self.tool_call_parser(response.choices[0].message.tool_calls)
         response_content = response.choices[0].message.content
@@ -180,6 +190,8 @@ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
         reasoning_started = False

         for chunk in response:
+            if not chunk.choices:
+                continue
             delta = chunk.choices[0].delta

             # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen, DeepSeek)
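All four hunks in this file apply the same guard, once for the OpenAI client and once for Azure OpenAI. Below is a minimal sketch of the non-streaming case, using hypothetical stub classes rather than the openai SDK's response types: some OpenAI-compatible backends can return a completion whose `choices` list is empty (content filtering, upstream gateway errors), and indexing `choices[0]` then raises `IndexError`, whereas the patched `generate()` logs a warning and returns an empty string.

```python
# Sketch only: StubMessage/StubChoice/StubResponse are hypothetical stand-ins
# for the SDK's ChatCompletion types, used to show the guard in isolation.
from dataclasses import dataclass, field


@dataclass
class StubMessage:
    content: str = ""


@dataclass
class StubChoice:
    message: StubMessage = field(default_factory=StubMessage)


@dataclass
class StubResponse:
    choices: list[StubChoice] = field(default_factory=list)


def generate(response: StubResponse) -> str:
    if not response.choices:  # the added guard
        print("warning: OpenAI response has no choices")
        return ""
    return response.choices[0].message.content


print(repr(generate(StubResponse())))                                 # ''
print(repr(generate(StubResponse([StubChoice(StubMessage("hi"))]))))  # 'hi'
```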
6 changes: 6 additions & 0 deletions src/memos/llms/vllm.py
@@ -125,6 +125,10 @@ def _generate_with_api_client(self, messages: list[MessageDict], **kwargs) -> str:

         response = self.client.chat.completions.create(**completion_kwargs)

+        if not response.choices:
+            logger.warning("VLLM response has no choices")
+            return ""
+
         if response.choices[0].message.tool_calls:
             return self.tool_call_parser(response.choices[0].message.tool_calls)

@@ -184,6 +188,8 @@ def generate_stream(self, messages: list[MessageDict], **kwargs):

         reasoning_started = False
         for chunk in stream:
+            if not chunk.choices:
+                continue
             delta = chunk.choices[0].delta
             if hasattr(delta, "reasoning") and delta.reasoning:
                 if not reasoning_started and not self.config.remove_think_prefix:
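The streaming hunks in both files share a second pattern: skip chunks that carry no choices instead of indexing into them. Here is a sketch with `SimpleNamespace` stand-ins for chunk objects (assumption: real chunks expose `.choices[0].delta.content`); some servers emit housekeeping chunks with an empty `choices` list, for example a trailing usage-only chunk when usage reporting is enabled.

```python
# Sketch of the streaming guard; SimpleNamespace objects stand in for the
# SDK's ChatCompletionChunk type.
from types import SimpleNamespace


def make_chunk(text: str) -> SimpleNamespace:
    return SimpleNamespace(
        choices=[SimpleNamespace(delta=SimpleNamespace(content=text))]
    )


chunks = [make_chunk("Hel"), SimpleNamespace(choices=[]), make_chunk("lo")]

out = []
for chunk in chunks:
    if not chunk.choices:  # the added guard: skip choice-less chunks
        continue
    delta = chunk.choices[0].delta
    if delta.content:
        out.append(delta.content)

print("".join(out))  # "Hello" -- the empty chunk is skipped, not an IndexError
```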
48 changes: 3 additions & 45 deletions src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -290,51 +290,6 @@ def _parse_task(

         return parsed_goal, query_embedding, context, query

-    @timed
-    def _retrieve_simple(
-        self,
-        query: str,
-        top_k: int,
-        search_filter: dict | None = None,
-        user_name: str | None = None,
-        **kwargs,
-    ):
-        """Retrieve from by keywords and embedding"""
-        query_words = []
-        if self.tokenizer:
-            query_words = self.tokenizer.tokenize_mixed(query)
-        else:
-            query_words = query.strip().split()
-        query_words = [query, *query_words]
-        logger.info(f"[SIMPLESEARCH] Query words: {query_words}")
-        query_embeddings = self.embedder.embed(query_words)
-
-        items = self.graph_retriever.retrieve_from_mixed(
-            top_k=top_k * 2,
-            memory_scope=None,
-            query_embedding=query_embeddings,
-            search_filter=search_filter,
-            user_name=user_name,
-            use_fast_graph=self.use_fast_graph,
-        )
-        logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
-        documents = [getattr(item, "memory", "") for item in items]
-        if not documents:
-            return []
-        documents_embeddings = self.embedder.embed(documents)
-        similarity_matrix = cosine_similarity_matrix(documents_embeddings)
-        selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
-        selected_items = [items[i] for i in selected_indices]
-        logger.info(
-            f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
-        )
-        return self.reranker.rerank(
-            query=query,
-            query_embedding=query_embeddings[0],
-            graph_results=selected_items,
-            top_k=top_k,
-        )
-
     @timed
     def _retrieve_paths(
         self,
@@ -722,6 +677,9 @@ def _retrieve_simple(
         if not documents:
             return []
         documents_embeddings = self.embedder.embed(documents)
+        if not documents_embeddings:
+            logger.info("[SIMPLESEARCH] Documents embeddings is empty")
+            return []
         similarity_matrix = cosine_similarity_matrix(documents_embeddings)
         selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
         selected_items = [items[i] for i in selected_indices]
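The deletion hunk above appears to remove a duplicate definition of `_retrieve_simple` near line 290, while the guard is added to the copy that remains later in the file. That guard covers a different edge than the LLM fixes: `documents` is non-empty but `self.embedder.embed(documents)` comes back empty. A sketch under that assumption (the `embed` stub below is hypothetical, not MemOS's embedder API) of why `_retrieve_simple` now bails out early rather than handing an empty matrix to the similarity and subgroup-selection helpers:

```python
# Hypothetical embedder stub that returns no vectors (e.g. a backend failure
# swallowed upstream). The guard mirrors the three added lines in the diff.
def embed(documents: list[str]) -> list[list[float]]:
    return []  # hypothetical failure mode: inputs silently dropped


documents = ["memory one", "memory two"]
documents_embeddings = embed(documents)
if not documents_embeddings:
    print("[SIMPLESEARCH] Documents embeddings is empty")
    results = []  # _retrieve_simple returns [] instead of proceeding
else:
    results = documents_embeddings  # similarity matrix + rerank would run here

print(results)  # []
```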