Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 155 additions & 12 deletions src/cosdata/api/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def dense(
Returns:
Search results
"""
url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/dense"
url = f"{self.collection.client.base_url}/collections/{
self.collection.name
}/search/dense"
data = {
"query_vector": query_vector,
"top_k": top_k,
Expand Down Expand Up @@ -68,7 +70,9 @@ def batch_dense(
Returns:
List of search results
"""
url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/batch-dense"
url = f"{self.collection.client.base_url}/collections/{
self.collection.name
}/search/batch-dense"

# Validate that each query has a "vector" field
for i, query in enumerate(queries):
Expand Down Expand Up @@ -108,7 +112,9 @@ def sparse(
Returns:
Search results
"""
url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/sparse"
url = f"{self.collection.client.base_url}/collections/{
self.collection.name
}/search/sparse"
data = {
"query_terms": query_terms,
"top_k": top_k,
Expand Down Expand Up @@ -147,7 +153,9 @@ def batch_sparse(
Returns:
List of search results
"""
url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/batch-sparse"
url = f"{self.collection.client.base_url}/collections/{
self.collection.name
}/search/batch-sparse"
data = {
"query_terms_list": query_terms_list,
"top_k": top_k,
Expand Down Expand Up @@ -181,7 +189,9 @@ def text(
Returns:
Search results
"""
url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/tf-idf"
url = f"{self.collection.client.base_url}/collections/{
self.collection.name
}/search/tf-idf"
data = {"query": query_text, "top_k": top_k, "return_raw_text": return_raw_text}

response = requests.post(
Expand Down Expand Up @@ -210,7 +220,9 @@ def batch_text(
Returns:
List of search results
"""
url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/batch-tf-idf"
url = f"{self.collection.client.base_url}/collections/{
self.collection.name
}/search/batch-tf-idf"
data = {
"queries": query_texts, # Changed from "query_texts" to "queries"
"top_k": top_k,
Expand All @@ -229,26 +241,158 @@ def batch_text(

return response.json()

def hybrid_search(self, queries):
def hybrid(
self,
query_vector: List[float],
query_terms: List[Dict[str, Union[int, float]]],
query_text: str,
top_k: int = 10,
fusion_constant_k: float = 60.0,
return_raw_text: bool = False,
) -> Dict[str, Any]:
"""
Perform a hybrid search on this collection.
Perform hybrid search combining dense, sparse, and TF-IDF.

Args:
query_vector: Query vector for dense search
query_terms: Query terms for sparse search
query_text: Query text for TF-IDF search
top_k: Number of results to return
fusion_constant_k: Fusion constant for RRF
return_raw_text: Whether to return raw text

Returns:
Hybrid search results
"""
url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/hybrid"
url = f"{self.collection.client.base_url}/collections/{
self.collection.name
}/search/hybrid"
data = {
"query": {
"dense": {
"query_vector": query_vector,
"top_k": top_k,
},
"sparse": {
"query_terms": query_terms,
"top_k": top_k,
},
"tfidf": {
"query_text": query_text,
"top_k": top_k,
},
},
"fusion_constant_k": fusion_constant_k,
"return_raw_text": return_raw_text,
}

response = requests.post(
url,
headers=self.collection.client._get_headers(),
data=json.dumps(queries),
data=json.dumps(data),
verify=self.collection.client.verify_ssl,
)

if response.status_code != 200:
raise Exception(f"Failed to perform hybrid search: {response.text}")

return response.json()

def batch_hybrid(
self,
queries: List[Dict[str, Any]],
top_k: int = 10,
fusion_constant_k: float = 60.0,
return_raw_text: bool = False,
) -> Dict[str, Any]:
"""
Perform batch hybrid search.

Args:
queries: List of hybrid query dictionaries with dense, sparse, and tfidf components
top_k: Maximum number of results to return per query
fusion_constant_k: Reciprocal rank fusion constant
return_raw_text: Whether to include raw text in the response

Returns:
Batch hybrid search results
"""
url = f"{self.collection.client.base_url}/collections/{
self.collection.name
}/search/batch-hybrid"
data = {
"queries": queries,
"top_k": top_k,
"fusion_constant_k": fusion_constant_k,
"return_raw_text": return_raw_text,
}

response = requests.post(
url,
headers=self.collection.client._get_headers(),
data=json.dumps(data),
verify=self.collection.client.verify_ssl,
)

if response.status_code != 200:
raise Exception(f"Failed to perform batch hybrid search: {response.text}")

return response.json()

def batch_hybrid_parallel(
self,
queries: List[Dict[str, Any]],
top_k: int = 10,
fusion_constant_k: float = 60.0,
return_raw_text: bool = False,
) -> Dict[str, Any]:
"""
Perform batch hybrid search using parallel processing on this collection.

This endpoint processes each query individually using the hybrid_search function,
which can be useful for debugging or when you want to ensure each query is
processed independently.

Args:
queries: List of hybrid query objects
top_k: Maximum number of results to return per query
fusion_constant_k: Reciprocal rank fusion constant
return_raw_text: Whether to include raw text in the response

Returns:
Batch search results
"""
url = f"{self.collection.client.base_url}/collections/{
self.collection.name
}/search/batch-hybrid-parallel"
data = {
"queries": queries,
"top_k": top_k,
"fusion_constant_k": fusion_constant_k,
"return_raw_text": return_raw_text,
}

response = requests.post(
url,
headers=self.collection.client._get_headers(),
data=json.dumps(data),
verify=self.collection.client.verify_ssl,
)

if response.status_code != 200:
raise Exception(
f"Failed to perform batch hybrid search parallel: {response.text}"
)

return response.json()

def batch_tf_idf_search(self, queries, top_k=10, return_raw_text=False):
"""
Perform batch tf-idf search on this collection.
"""
url = f"{self.collection.client.base_url}/collections/{self.collection.name}/search/batch-tf-idf"
url = f"{self.collection.client.base_url}/collections/{
self.collection.name
}/search/batch-tf-idf"
data = {"queries": queries, "top_k": top_k, "return_raw_text": return_raw_text}
response = requests.post(
url,
Expand All @@ -259,4 +403,3 @@ def batch_tf_idf_search(self, queries, top_k=10, return_raw_text=False):
if response.status_code != 200:
raise Exception(f"Failed to perform batch tf-idf search: {response.text}")
return response.json()