Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
run: uv build

- name: Run tests for extractor ${{ matrix.extractor }}
run: uv run pytest -s tests/test_feature_extractors.py -k "${{ matrix.extractor }}" --verbose
run: uv run --extra cpu pytest -s tests/test_feature_extractors.py -k "${{ matrix.extractor }}" --verbose
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}

Expand Down Expand Up @@ -83,7 +83,7 @@ jobs:
run: uv build

- name: Run other tests
run: uv run pytest -s tests/ -k "not test_feature_extractors.py" --verbose
run: uv run --extra cpu pytest -s tests/ -k "not test_feature_extractors.py" --verbose
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}

Expand Down
46 changes: 27 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,30 +75,15 @@ uv sync --extra cpu
source .venv/bin/activate
```

> [!CAUTION]
> In the next step we will build [flash-attn](https://github.com/dao-ailab/flash-attention), this might take an extended amount of time and consume a lot of RAM and CPU time!
>
> Please make sure you have [Nvidia CUDA Toolkit 13.0](https://developer.nvidia.com/cuda-13-0-2-download-archive) installed! You must use Nvidia Driver version 580 or newer!
>
> The `nvcc --version` command must indicate that 13.0 is installed and is currently in PATH: `Cuda compilation tools, release 13.0, V13.0.88`.
>
> If you get another version or `Command 'nvcc' not found`, add it to the PATH:
> ```bash
> export CUDA_HOME=/usr/local/cuda-13.0
> export PATH="${CUDA_HOME}/bin:$PATH"
> ```
>
> Run `nvcc --version` to ensure flash-attn will be built for CUDA 13.0

For the full GPU stack (`conchv1_5`, `gigapath`, `musk`), install with the prebuilt flash-attn wheel — no compile required. Supported on Linux x86_64, Linux aarch64, or Windows x86_64, with Python 3.13, CUDA 13.0, and torch 2.10. Wheels are hosted on the [STAMP releases](https://github.com/KatherLab/STAMP/releases) page.

```bash
# GPU (CUDA) Installation - building flash-attn for supporting conchv1_5, gigapath and musk

MAX_JOBS=2 uv sync --extra gpu_all # to speed up the build time increase max_jobs! This might use more RAM!
# GPU (CUDA) Installation - prebuilt flash-attn wheel, no compile
uv sync --extra gpu_prebuilt
source .venv/bin/activate
```

If you encounter errors during installation please read Installation Troubleshooting [below](#installation-troubleshooting).
If you encounter errors during installation please read Installation Troubleshooting [below](#installation-troubleshooting). If the prebuilt wheel does not fit your platform or you need a different flash-attn version, see [Advanced: Build flash-attn from source](#advanced-build-flash-attn-from-source).

### Additional Dependencies

Expand All @@ -116,6 +101,29 @@ If you encounter errors during installation please read Installation Troubleshoo
> apt update && apt install -y libgl1 libglx-mesa0 libglib2.0-0
> ```

### Advanced: Build flash-attn from source

> [!CAUTION]
> Building flash-attn can take an extended amount of time and consume a lot of RAM and CPU time!
>
> You must have [Nvidia CUDA Toolkit 13.0](https://developer.nvidia.com/cuda-13-0-2-download-archive) installed and Nvidia Driver version 580 or newer.
>
> The `nvcc --version` command must indicate that 13.0 is installed and is currently in PATH: `Cuda compilation tools, release 13.0, V13.0.88`.
>
> If you get another version or `Command 'nvcc' not found`, add it to the PATH:
> ```bash
> export CUDA_HOME=/usr/local/cuda-13.0
> export PATH="${CUDA_HOME}/bin:$PATH"
> ```
>
> Run `nvcc --version` to ensure flash-attn will be built for CUDA 13.0.

```bash
# GPU (CUDA) Installation - building flash-attn for supporting conchv1_5, gigapath and musk
MAX_JOBS=2 uv sync --extra gpu_all # to speed up the build time increase max_jobs! This might use more RAM!
source .venv/bin/activate
```

## Basic Usage

If the installation was successful, running `stamp` in your terminal should yield the following output:
Expand Down
22 changes: 18 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ build = [
]
flash-attention = [
"flash-attn==2.8.3",
]
]
conch = [
"huggingface-hub>=0.26.2",
"conch @ git+https://github.com/KatherLab/CONCH",
Expand Down Expand Up @@ -150,6 +150,14 @@ gpu_all = [
"torchvision~=0.25.0",
"stamp[conch,ctranspath,uni,virchow2,chief_ctranspath,musk,gigapath,conch1_5,prism,madeleine,plip,cobra]"
]
gpu_prebuilt = [
"torch~=2.10.0",
"torchvision~=0.25.0",
"stamp[conch,ctranspath,uni,virchow2,chief_ctranspath,musk,gigapath,conch1_5,prism,madeleine,plip,cobra]",
"flash-attn @ https://github.com/KatherLab/STAMP/releases/download/2.5.0/flash_attn-2.8.3+cu130torch2.10-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl ; sys_platform == 'linux' and platform_machine == 'x86_64' and python_version == '3.13'",
"flash-attn @ https://github.com/KatherLab/STAMP/releases/download/2.5.0/flash_attn-2.8.3+cu130torch2.10-cp313-cp313-manylinux_2_34_aarch64.whl ; sys_platform == 'linux' and platform_machine == 'aarch64' and python_version == '3.13'",
"flash-attn @ https://github.com/KatherLab/STAMP/releases/download/2.5.0/flash_attn-2.8.3+cu130torch2.10-cp313-cp313-win_amd64.whl ; sys_platform == 'win32' and platform_machine == 'AMD64' and python_version == '3.13'",
]
all = ["stamp[cpu]"]

[project.scripts]
Expand Down Expand Up @@ -194,6 +202,10 @@ conflicts = [
{ extra = "cpu" },
{ extra = "gpu_all" },
],
[
{ extra = "cpu" },
{ extra = "gpu_prebuilt" },
],
]
build-constraint-dependencies = [
"torch==2.10.0",
Expand All @@ -206,11 +218,13 @@ torch = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu130", extra = "gpu" },
{ index = "pytorch-cu130", extra = "gpu_all" },
{ index = "pytorch-cu130", extra = "gpu_prebuilt" },
]
torchvision = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu130", extra = "gpu" },
{ index = "pytorch-cu130", extra = "gpu_all" },
{ index = "pytorch-cu130", extra = "gpu_prebuilt" },
]

[[tool.uv.index]]
Expand Down Expand Up @@ -248,6 +262,6 @@ requires-dist = [
]

[tool.uv.extra-build-dependencies]
flash-attn = [{ requirement = "torch", match-runtime = true }]
gigapath = [{ requirement = "torch", match-runtime = true }]
conch = [{ requirement = "torch", match-runtime = true }]
flash-attn = ["torch"]
gigapath = ["torch"]
conch = ["torch"]
3 changes: 2 additions & 1 deletion src/stamp/heatmaps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import logging
from collections.abc import Collection, Iterable
from pathlib import Path
from typing import List, Optional, Tuple, cast
from typing import cast

import h5py
import matplotlib.pyplot as plt
import numpy as np
import openslide
import torch
from beartype.typing import List, Optional, Tuple
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.patches import Patch
Expand Down
8 changes: 7 additions & 1 deletion src/stamp/modeling/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,13 @@ def patient_feature_dataloader(
)
one_hot = torch.tensor(raw_ground_truths.reshape(-1, 1) == categories)
ds = PatientFeatureDataset(feature_files, one_hot, transform=transform)
dl = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
dl = DataLoader(
ds,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
persistent_workers=(num_workers > 0),
)
return dl, categories


Expand Down
4 changes: 2 additions & 2 deletions src/stamp/statistics/prc.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ def plot_multiple_decorated_precision_recall_curves(
)

# limit conf bounds to [0,1] in case of low sample numbers
lower = max(0, lower)
upper = min(1, upper)
lower = float(max(0.0, lower))
upper = float(min(1.0, upper))

auc_str = f"PRC = {np.mean(aucs):0.2f} [{lower:0.2f}-{upper:0.2f}]"

Expand Down
2 changes: 1 addition & 1 deletion tests/test_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_to_prediction_df(task: str) -> None:
use_alibi=False,
total_steps=1000,
max_lr=1e-4,
div_factor=25,
div_factor=25.0,
)
if task == "classification":
preds_df = _to_prediction_df(
Expand Down
12 changes: 10 additions & 2 deletions tests/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@
}


def _is_gated_repo_error(error: BaseException) -> bool:
while error is not None:
if isinstance(error, GatedRepoError) or "gated repo" in str(error):
return True
error = error.__cause__ or error.__context__
return False


@pytest.mark.slow
@pytest.mark.parametrize("encoder", EncoderName)
@pytest.mark.filterwarnings("ignore:Importing from timm.models.layers is deprecated")
Expand Down Expand Up @@ -177,8 +185,8 @@ def test_if_encoding_crashes(*, tmp_path: Path, encoder: EncoderName):
pytest.skip(f"dependencies for {encoder} not installed")
except GatedRepoError:
pytest.skip(f"cannot access gated repo for {encoder}")
except OSError as e:
if "gated repo" in str(e):
except (OSError, ValueError) as e:
if _is_gated_repo_error(e):
pytest.skip(f"cannot access gated repo for {encoder}")
raise

Expand Down
Loading
Loading