Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
run: uv build

- name: Run tests for extractor ${{ matrix.extractor }}
run: uv run pytest -s tests/test_feature_extractors.py -k "${{ matrix.extractor }}" --verbose
run: uv run --extra cpu pytest -s tests/test_feature_extractors.py -k "${{ matrix.extractor }}" --verbose
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}

Expand Down Expand Up @@ -83,7 +83,7 @@ jobs:
run: uv build

- name: Run other tests
run: uv run pytest -s tests/ -k "not test_feature_extractors.py" --verbose
run: uv run --extra cpu pytest -s tests/ -k "not test_feature_extractors.py" --verbose
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}

Expand Down
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,20 @@ source .venv/bin/activate
> Run `nvcc --version` to ensure flash-attn will be built for CUDA 13.0


For the full GPU stack (`conchv1_5`, `gigapath`, `musk`), pick **one** of the two options below:

**Option A — Prebuilt flash-attn wheel (fast, no compile).** Recommended if your environment matches: Linux x86_64, Linux aarch64, or Windows x86_64, with Python 3.13, CUDA 13.0, and torch 2.10. Wheels are hosted on the [STAMP releases](https://github.com/KatherLab/STAMP/releases) page.

```bash
# GPU (CUDA) Installation - building flash-attn for supporting conchv1_5, gigapath and musk
# GPU (CUDA) Installation - prebuilt flash-attn wheel, no compile
uv sync --extra gpu_prebuilt
source .venv/bin/activate
```

**Option B — Build flash-attn from source.** Use this on macOS, or whenever the prebuilt wheel markers do not match your platform. The `nvcc` build can take a long time and use a lot of RAM.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the flash_attn does not run on MacOS at all, right?


```bash
# GPU (CUDA) Installation - building flash-attn for supporting conchv1_5, gigapath and musk
MAX_JOBS=2 uv sync --extra gpu_all # to speed up the build time increase max_jobs! This might use more RAM!
source .venv/bin/activate
```
Expand Down
16 changes: 15 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ build = [
]
flash-attention = [
"flash-attn==2.8.3",
]
]
conch = [
"huggingface-hub>=0.26.2",
"conch @ git+https://github.com/KatherLab/CONCH",
Expand Down Expand Up @@ -150,6 +150,14 @@ gpu_all = [
"torchvision~=0.25.0",
"stamp[conch,ctranspath,uni,virchow2,chief_ctranspath,musk,gigapath,conch1_5,prism,madeleine,plip,cobra]"
]
gpu_prebuilt = [
"torch~=2.10.0",
"torchvision~=0.25.0",
"stamp[conch,ctranspath,uni,virchow2,chief_ctranspath,musk,gigapath,conch1_5,prism,madeleine,plip,cobra]",
"flash-attn @ https://github.com/KatherLab/STAMP/releases/download/2.5.0/flash_attn-2.8.3+cu130torch2.10-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl ; sys_platform == 'linux' and platform_machine == 'x86_64' and python_version == '3.13'",
"flash-attn @ https://github.com/KatherLab/STAMP/releases/download/2.5.0/flash_attn-2.8.3+cu130torch2.10-cp313-cp313-manylinux_2_34_aarch64.whl ; sys_platform == 'linux' and platform_machine == 'aarch64' and python_version == '3.13'",
"flash-attn @ https://github.com/KatherLab/STAMP/releases/download/2.5.0/flash_attn-2.8.3+cu130torch2.10-cp313-cp313-win_amd64.whl ; sys_platform == 'win32' and platform_machine == 'AMD64' and python_version == '3.13'",
]
all = ["stamp[cpu]"]

[project.scripts]
Expand Down Expand Up @@ -194,6 +202,10 @@ conflicts = [
{ extra = "cpu" },
{ extra = "gpu_all" },
],
[
{ extra = "cpu" },
{ extra = "gpu_prebuilt" },
],
]
build-constraint-dependencies = [
"torch==2.10.0",
Expand All @@ -206,11 +218,13 @@ torch = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu130", extra = "gpu" },
{ index = "pytorch-cu130", extra = "gpu_all" },
{ index = "pytorch-cu130", extra = "gpu_prebuilt" },
]
torchvision = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu130", extra = "gpu" },
{ index = "pytorch-cu130", extra = "gpu_all" },
{ index = "pytorch-cu130", extra = "gpu_prebuilt" },
]

[[tool.uv.index]]
Expand Down
3 changes: 2 additions & 1 deletion src/stamp/heatmaps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import logging
from collections.abc import Collection, Iterable
from pathlib import Path
from typing import List, Optional, Tuple, cast
from typing import cast

import h5py
import matplotlib.pyplot as plt
import numpy as np
import openslide
import torch
from beartype.typing import List, Optional, Tuple
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.patches import Patch
Expand Down
8 changes: 7 additions & 1 deletion src/stamp/modeling/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,13 @@ def patient_feature_dataloader(
)
one_hot = torch.tensor(raw_ground_truths.reshape(-1, 1) == categories)
ds = PatientFeatureDataset(feature_files, one_hot, transform=transform)
dl = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
dl = DataLoader(
ds,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
persistent_workers=(num_workers > 0),
)
return dl, categories


Expand Down
2 changes: 1 addition & 1 deletion tests/test_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_to_prediction_df(task: str) -> None:
use_alibi=False,
total_steps=1000,
max_lr=1e-4,
div_factor=25,
div_factor=25.0,
)
if task == "classification":
preds_df = _to_prediction_df(
Expand Down
12 changes: 10 additions & 2 deletions tests/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@
}


def _is_gated_repo_error(error: BaseException) -> bool:
while error is not None:
if isinstance(error, GatedRepoError) or "gated repo" in str(error):
return True
error = error.__cause__ or error.__context__
return False


@pytest.mark.slow
@pytest.mark.parametrize("encoder", EncoderName)
@pytest.mark.filterwarnings("ignore:Importing from timm.models.layers is deprecated")
Expand Down Expand Up @@ -177,8 +185,8 @@ def test_if_encoding_crashes(*, tmp_path: Path, encoder: EncoderName):
pytest.skip(f"dependencies for {encoder} not installed")
except GatedRepoError:
pytest.skip(f"cannot access gated repo for {encoder}")
except OSError as e:
if "gated repo" in str(e):
except (OSError, ValueError) as e:
if _is_gated_repo_error(e):
pytest.skip(f"cannot access gated repo for {encoder}")
raise

Expand Down
Loading
Loading