Skip to content
Open

Logs #130

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 6 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,6 @@ curl -LsSf https://astral.sh/uv/install.sh | sh
uv self update
```

### Install STAMP in a Virtual Environment:

```bash
uv venv --python=3.12
source .venv/bin/activate

# For a GPU (CUDA) installation:
uv pip install "git+https://github.com/KatherLab/STAMP.git[gpu]"

# For a CPU-only installation:
uv pip install "git+https://github.com/KatherLab/STAMP.git[cpu]" --torch-backend=cpu
```

### Install STAMP from the Repository:

```bash
Expand Down Expand Up @@ -210,11 +197,13 @@ uv cache clean causal_conv1d

# Now it should re-build the packages with the correct torch version

# With uv pip install
uv pip install "git+https://github.com/KatherLab/STAMP.git[build]"
uv pip install "git+https://github.com/KatherLab/STAMP.git[build,gpu] --no-build-isolation"

# With uv sync in the cloned repository
uv sync --extra build
uv sync --extra build --extra gpu
```

## Reproducibility

We use a central `Seed` utility to set seeds for PyTorch, NumPy, and Python’s `random`. This makes data loading and model initialization reproducible. Always call `Seed.set(seed)` once at startup.

We do not enable [`torch.use_deterministic_algorithms()`](https://pytorch.org/docs/stable/notes/randomness.html#reproducibility) because it can cause large performance drops. Expect runs with the same seed to follow the same training trajectory, but not bit-for-bit identical low-level kernels.
24 changes: 16 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ build = [
"ninja"
]
flash-attention = [
"flash-attn>=2.8.2",
"flash-attn>=2.8.3",
]
conch = [
"huggingface-hub>=0.26.2",
Expand Down Expand Up @@ -99,12 +99,13 @@ virchow2 = [
]
cobra = [
"stamp[flash-attention]",
"causal-conv1d @ git+https://github.com/KatherLab/causal-conv1d.git@b73d1ca0e0726ba6520c38d342bd411bb5850064",
"mamba-ssm @ git+https://github.com/KatherLab/mamba.git@d0d4192621889b26f9669ea4a8e6fe79cc84e8d9",
"cobra @ git+http://github.com/KatherLab/COBRA.git@c8aa71ce691e1279f2bb797f536355d6be47b6ac",
"causal-conv1d",
"mamba-ssm",
# "causal-conv1d @ git+https://github.com/KatherLab/causal-conv1d.git@dededae18d0258ccec833ab950d45279f1616fd1",
# "mamba-ssm @ git+https://github.com/KatherLab/mamba.git@3dad301098b721ee5c93d9ad16aafbbc1dc42cfd",
"cobra @ git+http://github.com/KatherLab/COBRA.git@73712e9ffa4d1bdecf9be9826d66094bd2b17534",
"jinja2>=3.1.4",
"triton"
# "triton==3.2.0", # Fix triton to 3.2.0 (also makes torch==2.6.0) until this is solved: https://github.com/pytorch/pytorch/issues/153737
]
prism_cpu = [
"sacremoses==0.1.1",
Expand Down Expand Up @@ -179,6 +180,16 @@ conflicts = [
]
]

[tool.uv.sources]
torch = { index = "pytorch-cu128" }
torchvision = { index = "pytorch-cu128" }

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true


[[tool.uv.dependency-metadata]]
name = "uni"
version = "v0.1.0"
Expand All @@ -197,7 +208,6 @@ requires-dist = [

[[tool.uv.dependency-metadata]]
name = "flash-attn"
version = "2.8.2"
requires-dist = [
"torch",
"einops",
Expand All @@ -206,14 +216,12 @@ requires-dist = [

[[tool.uv.dependency-metadata]]
name = "mamba-ssm"
version = "v2.2.4"
requires-dist = [
"setuptools",
]

[[tool.uv.dependency-metadata]]
name = "causal-conv1d"
version = "v1.5.0.post8"
requires-dist = [
"setuptools",
]
Expand Down
57 changes: 41 additions & 16 deletions src/stamp/__main__.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ModelParams,
VitModelParams,
)
from stamp.seed import Seed

STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml")

Expand Down Expand Up @@ -49,6 +50,16 @@ def _run_cli(args: argparse.Namespace) -> None:
with open(args.config_file_path, "r") as config_yaml:
config = StampConfig.model_validate(yaml.safe_load(config_yaml))

# use default advanced config in case none is provided
if config.advanced_config is None:
config.advanced_config = AdvancedConfig(
model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams())
)

# Set global random seed
if config.advanced_config.seed is not None:
Seed.set(config.advanced_config.seed)

match args.command:
case "init":
raise RuntimeError("this case should be handled above")
Expand All @@ -67,7 +78,8 @@ def _run_cli(args: argparse.Namespace) -> None:
"using the following configuration:\n"
f"{yaml.dump(config.preprocessing.model_dump(mode='json'))}"
)
extract_(
_logger.info("Starting preprocessing...")
summary = extract_(
output_dir=config.preprocessing.output_dir,
wsi_dir=config.preprocessing.wsi_dir,
wsi_list=config.preprocessing.wsi_list,
Expand All @@ -83,6 +95,11 @@ def _run_cli(args: argparse.Namespace) -> None:
cache_tiles_ext=config.preprocessing.cache_tiles_ext,
generate_hash=config.preprocessing.generate_hash,
)
_logger.info("preprocessing finished.")
_logger.info(
f"Slides processed: {summary['processed']}, "
f"failed: {summary['failed']}, skipped: {summary['skipped']}"
)

case "encode_slides":
from stamp.encoding import init_slide_encoder_
Expand All @@ -95,14 +112,20 @@ def _run_cli(args: argparse.Namespace) -> None:
"using the following configuration:\n"
f"{yaml.dump(config.slide_encoding.model_dump(mode='json'))}"
)
init_slide_encoder_(
_logger.info("Starting slide encoding...")
summary = init_slide_encoder_(
encoder=config.slide_encoding.encoder,
output_dir=config.slide_encoding.output_dir,
feat_dir=config.slide_encoding.feat_dir,
device=config.slide_encoding.device,
agg_feat_dir=config.slide_encoding.agg_feat_dir,
generate_hash=config.slide_encoding.generate_hash,
)
_logger.info("slide encoding finished.")
_logger.info(
f"Slides processed: {summary['processed']}, "
f"failed: {summary['failed']}, skipped: {summary['skipped']}"
)

case "encode_patients":
from stamp.encoding import init_patient_encoder_
Expand All @@ -115,7 +138,8 @@ def _run_cli(args: argparse.Namespace) -> None:
"using the following configuration:\n"
f"{yaml.dump(config.patient_encoding.model_dump(mode='json'))}"
)
init_patient_encoder_(
_logger.info("Starting patient encoding...")
summary = init_patient_encoder_(
encoder=config.patient_encoding.encoder,
output_dir=config.patient_encoding.output_dir,
feat_dir=config.patient_encoding.feat_dir,
Expand All @@ -126,28 +150,29 @@ def _run_cli(args: argparse.Namespace) -> None:
agg_feat_dir=config.patient_encoding.agg_feat_dir,
generate_hash=config.patient_encoding.generate_hash,
)
_logger.info("patient encoding finished.")
_logger.info(
f"Patients processed: {summary['processed']}, "
f"failed: {summary['failed']}, skipped: {summary['skipped']}"
)

case "train":
from stamp.modeling.train import train_categorical_model_

if config.training is None:
raise ValueError("no training configuration supplied")

# use default advanced config in case none is provided
if config.advanced_config is None:
config.advanced_config = AdvancedConfig(
model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams())
)

_add_file_handle_(_logger, output_dir=config.training.output_dir)
_logger.info(
"using the following configuration:\n"
f"{yaml.dump(config.training.model_dump(mode='json'))}"
)
_logger.info("Starting training...")

train_categorical_model_(
config=config.training, advanced=config.advanced_config
)
_logger.info("Training finished.")

case "deploy":
from stamp.modeling.deploy import deploy_categorical_model_
Expand All @@ -160,6 +185,7 @@ def _run_cli(args: argparse.Namespace) -> None:
"using the following configuration:\n"
f"{yaml.dump(config.deployment.model_dump(mode='json'))}"
)
_logger.info("Starting deployment...")
deploy_categorical_model_(
output_dir=config.deployment.output_dir,
checkpoint_paths=config.deployment.checkpoint_paths,
Expand All @@ -172,6 +198,7 @@ def _run_cli(args: argparse.Namespace) -> None:
num_workers=config.deployment.num_workers,
accelerator=config.deployment.accelerator,
)
_logger.info("Deployment finished...")

case "crossval":
from stamp.modeling.crossval import categorical_crossval_
Expand All @@ -184,17 +211,13 @@ def _run_cli(args: argparse.Namespace) -> None:
"using the following configuration:\n"
f"{yaml.dump(config.crossval.model_dump(mode='json'))}"
)

# use default advanced config in case none is provided
if config.advanced_config is None:
config.advanced_config = AdvancedConfig(
model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams())
)


_logger.info("Starting crossval...")
categorical_crossval_(
config=config.crossval,
advanced=config.advanced_config,
)
_logger.info("Crossval finished...")

case "statistics":
from stamp.statistics import compute_stats_
Expand Down Expand Up @@ -225,6 +248,7 @@ def _run_cli(args: argparse.Namespace) -> None:
"using the following configuration:\n"
f"{yaml.dump(config.heatmaps.model_dump(mode='json'))}"
)
_logger.info("Starting heatmap generation...")
heatmaps_(
feature_dir=config.heatmaps.feature_dir,
wsi_dir=config.heatmaps.wsi_dir,
Expand All @@ -237,6 +261,7 @@ def _run_cli(args: argparse.Namespace) -> None:
default_slide_mpp=config.heatmaps.default_slide_mpp,
opacity=config.heatmaps.opacity,
)
_logger.info("Heatmaps finished...")

case _:
raise RuntimeError(
Expand Down
2 changes: 2 additions & 0 deletions src/stamp/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ patient_encoding:


advanced_config:
# Optional random seed
# seed: 42
max_epochs: 32
patience: 16
batch_size: 64
Expand Down
53 changes: 35 additions & 18 deletions src/stamp/encoding/encoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,12 @@ def encode_slides_(
device: DeviceLikeType,
generate_hash: bool,
**kwargs,
) -> None:
) -> dict:
"""General method for encoding slide-level features. Called by init_slide_encoder_.
Override this function if coords are required. See init_slide_encoder_ for full description"""
processed = 0
failed = 0
skipped = 0
# generate the name for the folder containing the feats
if generate_hash:
encode_dir = f"{self.identifier}-slide-{get_processing_code_hash(Path(__file__))[:8]}"
Expand All @@ -71,12 +74,14 @@ def encode_slides_(
_logger.info(
f"skipping {str(slide_name)} because {output_path} already exists"
)
skipped += 1
continue

try:
feats, coords = self._validate_and_read_features(h5_path)
except ValueError as e:
tqdm.write(s=str(e))
failed += 1
continue

slide_embedding = self._generate_slide_embedding(
Expand All @@ -85,6 +90,8 @@ def encode_slides_(
self._save_features_(
output_path=output_path, feats=slide_embedding, feat_type="slide"
)
processed += 1
return {"processed": processed, "failed": failed, "skipped": skipped}

def encode_patients_(
self,
Expand All @@ -96,9 +103,12 @@ def encode_patients_(
device: DeviceLikeType,
generate_hash: bool,
**kwargs,
) -> None:
) -> dict:
"""General method for encoding patient-level features. Called by init_patient_encoder_.
Override this function if coords are required. See init_patient_encoder_ for full description"""
processed = 0
failed = 0
skipped = 0
# generate the name for the folder containing the feats
if generate_hash:
encode_dir = (
Expand Down Expand Up @@ -126,27 +136,34 @@ def encode_patients_(
_logger.info(
f"skipping {str(patient_id)} because {output_path} already exists"
)
skipped += 1
continue

feats_list = []

for _, row in group.iterrows():
slide_filename = row[filename_label]
h5_path = os.path.join(feat_dir, slide_filename)
feats, _ = self._validate_and_read_features(h5_path)
feats_list.append(feats)

if not feats_list:
tqdm.write(f"No features found for patient {patient_id}, skipping.")
try:
for _, row in group.iterrows():
slide_filename = row[filename_label]
h5_path = os.path.join(feat_dir, slide_filename)
feats, _ = self._validate_and_read_features(h5_path)
feats_list.append(feats)

if not feats_list:
tqdm.write(f"No features found for patient {patient_id}, skipping.")
skipped += 1
continue

patient_embedding = self._generate_patient_embedding(
feats_list, device, **kwargs
)
self._save_features_(
output_path=output_path, feats=patient_embedding, feat_type="patient"
)
processed += 1
except Exception as e:
tqdm.write(f"Failed to process patient {patient_id}: {e}")
failed += 1
continue

patient_embedding = self._generate_patient_embedding(
feats_list, device, **kwargs
)
self._save_features_(
output_path=output_path, feats=patient_embedding, feat_type="patient"
)

@abstractmethod
def _generate_slide_embedding(
self, feats: torch.Tensor, device, **kwargs
Expand Down
Loading