From d0fd03ba3e3963e82bd7e8ac6c1aeb6457e77e11 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Tue, 12 Aug 2025 22:18:12 +0530 Subject: [PATCH] feat(hybrid-search): add hybrid search & batch hybrid search methods --- src/cosdata/api/search.py | 167 +++++++++++++++++++++++++++++++++++--- 1 file changed, 155 insertions(+), 12 deletions(-) diff --git a/src/cosdata/api/search.py b/src/cosdata/api/search.py index 0f0c7c7..15be128 100644 --- a/src/cosdata/api/search.py +++ b/src/cosdata/api/search.py @@ -32,7 +32,9 @@ def dense( Returns: Search results """ - url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/dense" + url = f"{self.collection.client.base_url}/collections/{ + self.collection.name + }/search/dense" data = { "query_vector": query_vector, "top_k": top_k, @@ -68,7 +70,9 @@ def batch_dense( Returns: List of search results """ - url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/batch-dense" + url = f"{self.collection.client.base_url}/collections/{ + self.collection.name + }/search/batch-dense" # Validate that each query has a "vector" field for i, query in enumerate(queries): @@ -108,7 +112,9 @@ def sparse( Returns: Search results """ - url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/sparse" + url = f"{self.collection.client.base_url}/collections/{ + self.collection.name + }/search/sparse" data = { "query_terms": query_terms, "top_k": top_k, @@ -147,7 +153,9 @@ def batch_sparse( Returns: List of search results """ - url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/batch-sparse" + url = f"{self.collection.client.base_url}/collections/{ + self.collection.name + }/search/batch-sparse" data = { "query_terms_list": query_terms_list, "top_k": top_k, @@ -181,7 +189,9 @@ def text( Returns: Search results """ - url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/tf-idf" + url = f"{self.collection.client.base_url}/collections/{ + self.collection.name + }/search/tf-idf" data = {"query": query_text, "top_k": top_k, "return_raw_text": return_raw_text} response = requests.post( @@ -210,7 +220,9 @@ def batch_text( Returns: List of search results """ - url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/batch-tf-idf" + url = f"{self.collection.client.base_url}/collections/{ + self.collection.name + }/search/batch-tf-idf" data = { "queries": query_texts, # Changed from "query_texts" to "queries" "top_k": top_k, @@ -229,26 +241,158 @@ def batch_text( return response.json() - def hybrid_search(self, queries): + def hybrid( + self, + query_vector: List[float], + query_terms: List[Dict[str, Union[int, float]]], + query_text: str, + top_k: int = 10, + fusion_constant_k: float = 60.0, + return_raw_text: bool = False, + ) -> Dict[str, Any]: """ - Perform a hybrid search on this collection. + Perform hybrid search combining dense, sparse, and TF-IDF. + + Args: + query_vector: Query vector for dense search + query_terms: Query terms for sparse search + query_text: Query text for TF-IDF search + top_k: Number of results to return + fusion_constant_k: Fusion constant for RRF + return_raw_text: Whether to return raw text + + Returns: + Hybrid search results """ - url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/hybrid" + url = f"{self.collection.client.base_url}/collections/{ + self.collection.name + }/search/hybrid" + data = { + "query": { + "dense": { + "query_vector": query_vector, + "top_k": top_k, + }, + "sparse": { + "query_terms": query_terms, + "top_k": top_k, + }, + "tfidf": { + "query_text": query_text, + "top_k": top_k, + }, + }, + "fusion_constant_k": fusion_constant_k, + "return_raw_text": return_raw_text, + } + response = requests.post( url, headers=self.collection.client._get_headers(), - data=json.dumps(queries), + data=json.dumps(data), verify=self.collection.client.verify_ssl, ) + if response.status_code != 200: raise Exception(f"Failed to perform hybrid search: {response.text}") + + return response.json() + + def batch_hybrid( + self, + queries: List[Dict[str, Any]], + top_k: int = 10, + fusion_constant_k: float = 60.0, + return_raw_text: bool = False, + ) -> Dict[str, Any]: + """ + Perform batch hybrid search. + + Args: + queries: List of hybrid query dictionaries with dense, sparse, and tfidf components + top_k: Maximum number of results to return per query + fusion_constant_k: Reciprocal rank fusion constant + return_raw_text: Whether to include raw text in the response + + Returns: + Batch hybrid search results + """ + url = f"{self.collection.client.base_url}/collections/{ + self.collection.name + }/search/batch-hybrid" + data = { + "queries": queries, + "top_k": top_k, + "fusion_constant_k": fusion_constant_k, + "return_raw_text": return_raw_text, + } + + response = requests.post( + url, + headers=self.collection.client._get_headers(), + data=json.dumps(data), + verify=self.collection.client.verify_ssl, + ) + + if response.status_code != 200: + raise Exception(f"Failed to perform batch hybrid search: {response.text}") + + return response.json() + + def batch_hybrid_parallel( + self, + queries: List[Dict[str, Any]], + top_k: int = 10, + fusion_constant_k: float = 60.0, + return_raw_text: bool = False, + ) -> Dict[str, Any]: + """ + Perform batch hybrid search using parallel processing on this collection. + + This endpoint processes each query individually using the hybrid_search function, + which can be useful for debugging or when you want to ensure each query is + processed independently. + + Args: + queries: List of hybrid query objects + top_k: Maximum number of results to return per query + fusion_constant_k: Reciprocal rank fusion constant + return_raw_text: Whether to include raw text in the response + + Returns: + Batch search results + """ + url = f"{self.collection.client.base_url}/collections/{ + self.collection.name + }/search/batch-hybrid-parallel" + data = { + "queries": queries, + "top_k": top_k, + "fusion_constant_k": fusion_constant_k, + "return_raw_text": return_raw_text, + } + + response = requests.post( + url, + headers=self.collection.client._get_headers(), + data=json.dumps(data), + verify=self.collection.client.verify_ssl, + ) + + if response.status_code != 200: + raise Exception( + f"Failed to perform batch hybrid search parallel: {response.text}" + ) + return response.json() def batch_tf_idf_search(self, queries, top_k=10, return_raw_text=False): """ Perform batch tf-idf search on this collection. """ - url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/batch-tf-idf" + url = f"{self.collection.client.base_url}/collections/{ + self.collection.name + }/search/batch-tf-idf" data = {"queries": queries, "top_k": top_k, "return_raw_text": return_raw_text} response = requests.post( url, @@ -259,4 +403,3 @@ def batch_tf_idf_search(self, queries, top_k=10, return_raw_text=False): if response.status_code != 200: raise Exception(f"Failed to perform batch tf-idf search: {response.text}") return response.json() -