Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cytetype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .main import CyteType

__all__ = ["CyteType"]
__version__ = "0.9.2"
__version__ = "0.10.0"
2 changes: 2 additions & 0 deletions cytetype/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def poll_for_results(

# Extract cluster status for all cases
raw_response = status_response.get("raw_response", {})
if not raw_response:
raise CyteTypeAPIError("No response from API")
current_cluster_status = raw_response.get("clusterStatus", {})

if status == "completed":
Expand Down
29 changes: 24 additions & 5 deletions cytetype/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from typing import Any
from typing import Any, cast
import json

import anndata
import pandas as pd
from natsort import natsorted
from pydantic import ValidationError


from .config import logger, DEFAULT_API_URL, DEFAULT_POLL_INTERVAL, DEFAULT_TIMEOUT
from .client import submit_job, poll_for_results, check_job_status
from .server_schema import LLMModelConfig
from .server_schema import LLMModelConfig, InputData
from .anndata_helpers import (
_validate_adata,
_calculate_pcent,
Expand Down Expand Up @@ -245,6 +246,7 @@ def run(
study_context: str,
llm_configs: list[dict[str, Any]] | None = None,
metadata: dict[str, Any] | None = None,
n_parallel_clusters: int = 2,
results_prefix: str = "cytetype",
poll_interval_seconds: int = DEFAULT_POLL_INTERVAL,
timeout_seconds: int = DEFAULT_TIMEOUT,
Expand All @@ -267,6 +269,8 @@ def run(
metadata (dict[str, Any] | None, optional): Custom metadata tags to include in the report header.
Values that look like URLs will be made clickable in the report.
Defaults to None.
n_parallel_clusters (int, optional): Number of parallel requests to make to the model. Maximum is 50. Note than high values can lead to rate limit errors.
Defaults to 2.
results_prefix (str, optional): Prefix for keys added to `adata.obs` and `adata.uns` to
store results. The final annotation column will be
`adata.obs[f"{results_key}_{group_key}"]`. Defaults to "cytetype".
Expand Down Expand Up @@ -306,8 +310,25 @@ def run(
"markerGenes": self.marker_genes,
"visualizationData": self.visualization_data,
"expressionData": self.expression_percentages,
"nParallelClusters": n_parallel_clusters,
}

try:
validated_input = InputData(**cast(dict[str, Any], input_data))
input_data = validated_input.model_dump()
except ValidationError as e:
logger.error(f"Validation error: {e}")
raise e

if llm_configs:
try:
llm_configs = [LLMModelConfig(**x).model_dump() for x in llm_configs]
except ValidationError as e:
logger.error(f"Validation error: {e}")
raise e
else:
llm_configs = []

if save_query:
with open(query_filename, "w") as f:
json.dump(input_data, f)
Expand All @@ -316,9 +337,7 @@ def run(
job_id = submit_job(
{
"input_data": input_data,
"llm_configs": [LLMModelConfig(**x).model_dump() for x in llm_configs]
if llm_configs
else None,
"llm_configs": llm_configs if llm_configs else None,
},
api_url,
auth_token=auth_token,
Expand Down
7 changes: 7 additions & 0 deletions cytetype/server_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ class InputData(BaseModel):
default_factory=dict,
description="Dictionary mapping gene names to their expression percentages across clusters",
)
nParallelClusters: int = Field(
default=2,
ge=1,
le=50,
description="Number of parallel requests to make to the model",
)

@classmethod
def get_example(cls) -> "InputData":
Expand Down Expand Up @@ -151,4 +157,5 @@ def get_example(cls) -> "InputData":
"Cluster3": 5.2,
},
},
nParallelClusters=5,
)
18 changes: 9 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ description = "Python client for characterization of clusters from single-cell R
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"anndata~=0.11.4",
"anndata>=0.9.2",
"loguru~=0.7.3",
"natsort~=8.4.0",
"requests~=2.32.3",
"requests~=2.32.5",
"pydantic~=2.11.7",
"session-info~=1.0.1",
"python-dotenv~=1.1.1",
Expand All @@ -26,12 +26,12 @@ Repository = "https://github.com/NygenAnalytics/cytetype"

[dependency-groups]
dev = [
"jupyterlab>=4.4.2",
"mypy>=1.15.0",
"pytest>=8.3.5",
"ruff>=0.11.7",
"scanpy>=1.11.1",
"types-requests>=2.32.0.20250328",
"jupyterlab>=4.4.6",
"mypy>=1.17.1",
"pytest>=8.4.1",
"ruff>=0.12.11",
"scanpy>=1.11.4",
"types-requests>=2.32.4.20250809",
]

[build-system]
Expand All @@ -47,7 +47,7 @@ version = {attr = "cytetype.__version__"}
[tool.mypy]
strict = false

python_version = "3.12"
python_version = "3.11"

disable_error_code = ["import-untyped"]

Expand Down
Loading