diff --git a/autoafids-dev.yml b/autoafids-dev.yml index ba145a6..fd276ae 100644 --- a/autoafids-dev.yml +++ b/autoafids-dev.yml @@ -11,4 +11,4 @@ dependencies: - snakefmt >=0.8.4,<0.9.0 - yamlfix >=1.11.0,<2.0.0 - pygraphviz ==1.7 - - jinja2 >=3.0.3,<4.0.0 + - jinja2 >=3.0.3,<4.0.0 \ No newline at end of file diff --git a/autoafids/config/snakebids.yml b/autoafids/config/snakebids.yml index 0c9b5b8..01eefbc 100644 --- a/autoafids/config/snakebids.yml +++ b/autoafids/config/snakebids.yml @@ -180,6 +180,73 @@ parse_args: default: work type: str + --detect_with_prior: + help: | + (Default) Use MNI-registered prior FCSV to place 5 patches around the + expected AFID location. Fast and accurate when a good MNI registration + is available. + dest: detect_mode + action: store_const + const: prior + default: prior + + --detect_without_prior: + help: | + Use whole-volume sliding-window inference — no prior registration needed. + Slower (scans the full image) but does not depend on MNI registration + quality. Overlap is controlled by --inference-overlap (default 0.5). + dest: detect_mode + action: store_const + const: noprior + + --detect_with_nnlm: + help: | + Use nnLandmark (nnLM) for whole-volume, single-pass AFID detection. + All 32 AFIDs are predicted in one nnU-Net forward pass. + The model is downloaded automatically on first use. + dest: detect_mode + action: store_const + const: nnlm + + --nnlm_fold: + help: "nnLM fold to use for prediction. (default: %(default)s)" + dest: nnlm_fold + default: "0" + type: str + + --nnlm_plans: + help: "nnLM plans identifier. (default: %(default)s)" + dest: nnlm_plans + default: "nnUNetResEncUNetMPlans" + type: str + + --nnlm_checkpoint: + help: "nnLM checkpoint filename inside the model folder. (default: %(default)s)" + dest: nnlm_checkpoint + default: "checkpoint_final.pth" + type: str + + --nnlm_device: + help: "Device for nnLM inference: cuda or cpu. 
(default: %(default)s)" + dest: nnlm_device + default: "cpu" + type: str + + --inference-overlap: + help: | + Sliding-window overlap for --detect_without_prior (0.0 = no overlap, + 0.5 = 50%% overlap, 0.75 = 75%% overlap). Lower = fewer patches = faster. + Overrides afids_inference.overlap in snakebids.yml. (default: 0.5) + dest: inference_overlap + type: float + + --inference-batch-size: + help: | + Number of patches per model forward pass. Larger = faster but more memory. + Overrides afids_inference.batch_size in snakebids.yml. (default: 7) + dest: inference_batch_size + type: int + #--- workflow specific configuration --- # Nifti template @@ -198,12 +265,61 @@ singularity: # It will be downloaded to ~/.cache/autoafids resource_urls: default: 'files.osf.io/v1/resources/9fptg/providers/osfstorage/?zip=' + # nnLM model checkpoint (whole-volume 32-AFID nnLandmark model) + # TODO: Replace with the actual OSF/Zenodo URL when the model is published + nnlm: 'https://zenodo.org/records/18991189/files/nnlm_model.zip' + +# Sequential inference configuration +enable_sequential_inference: False #Stereotaxy models STN: 'resources/stereotaxy/STN.pkl' cZI: 'resources/stereotaxy/STN.pkl' #change to czi model but dummy here for testing template_fcsv: 'resources/stereotaxy/target_template.fcsv' +# ----------------------------------------------------------------------- +# PyTorch nnUNet inference configuration (used by the apply_with_prior_*.py and apply_noprior_*.py scripts) +# Set one checkpoint path per fiducial (afid_01 … afid_32). +# patch_size, device, and overlap are optional (defaults shown); batch_size may also be set here (default 7). 
+# ----------------------------------------------------------------------- +afids_inference: + checkpoints: + afid_01: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' + afid_02: 'resources/afids_cnn_ckpts/afid-02_epoch-964_mae-0.0007.ckpt' + afid_03: 'resources/afids_cnn_ckpts/afid-03_epoch-778_mae-0.0008.ckpt' + afid_04: 'resources/afids_cnn_ckpts/afid-04_epoch-529_mae-0.0007.ckpt' + afid_05: 'resources/afids_cnn_ckpts/afid-05_epoch-628_mae-0.0007.ckpt' + afid_06: 'resources/afids_cnn_ckpts/afid-06_epoch-485_mae-0.0009.ckpt' + afid_07: 'resources/afids_cnn_ckpts/afid-07_epoch-599_mae-0.0008.ckpt' + afid_08: 'resources/afids_cnn_ckpts/afid-08_epoch-455_mae-0.0013.ckpt' + afid_09: 'resources/afids_cnn_ckpts/afid-09_epoch-535_mae-0.0010.ckpt' + afid_10: 'resources/afids_cnn_ckpts/afid-10_epoch-431_mae-0.0009.ckpt' + afid_11: 'resources/afids_cnn_ckpts/afid-11_epoch-735_mae-0.0005.ckpt' + afid_12: 'resources/afids_cnn_ckpts/afid-12_epoch-694_mae-0.0007.ckpt' + afid_13: 'resources/afids_cnn_ckpts/afid-13_epoch-763_mae-0.0006.ckpt' + afid_14: 'resources/afids_cnn_ckpts/afid-14_epoch-405_mae-0.0011.ckpt' + afid_15: 'resources/afids_cnn_ckpts/afid-15_epoch-371_mae-0.0013.ckpt' + afid_16: 'resources/afids_cnn_ckpts/afid-16_epoch-454_mae-0.0012.ckpt' + afid_17: 'resources/afids_cnn_ckpts/afid-17_epoch-291_mae-0.0023.ckpt' + afid_18: 'resources/afids_cnn_ckpts/afid-18_epoch-233_mae-0.0016.ckpt' + afid_19: 'resources/afids_cnn_ckpts/afid-19_epoch-626_mae-0.0011.ckpt' + afid_20: 'resources/afids_cnn_ckpts/afid-20_epoch-720_mae-0.0011.ckpt' + afid_21: 'resources/afids_cnn_ckpts/afid-21_epoch-672_mae-0.0009.ckpt' + afid_22: 'resources/afids_cnn_ckpts/afid-22_epoch-647_mae-0.0008.ckpt' + afid_23: 'resources/afids_cnn_ckpts/afid-23_epoch-656_mae-0.0010.ckpt' + afid_24: 'resources/afids_cnn_ckpts/afid-24_epoch-689_mae-0.0007.ckpt' + afid_25: 'resources/afids_cnn_ckpts/afid-25_epoch-513_mae-0.0013.ckpt' + afid_26: 
'resources/afids_cnn_ckpts/afid-26_epoch-784_mae-0.0009.ckpt' + afid_27: 'resources/afids_cnn_ckpts/afid-27_epoch-300_mae-0.0014.ckpt' + afid_28: 'resources/afids_cnn_ckpts/afid-28_epoch-351_mae-0.0009.ckpt' + afid_29: 'resources/afids_cnn_ckpts/afid-29_epoch-394_mae-0.0014.ckpt' + afid_30: 'resources/afids_cnn_ckpts/afid-30_epoch-618_mae-0.0012.ckpt' + afid_31: 'resources/afids_cnn_ckpts/afid-31_epoch-573_mae-0.0009.ckpt' + afid_32: 'resources/afids_cnn_ckpts/afid-32_epoch-561_mae-0.0007.ckpt' + patch_size: 64 # cubic patch size in voxels + device: cuda:0 # cpu recommended for parallel mode; cuda:0 for sequential + overlap: 0.5 # sliding-window overlap for --detect_without_prior (0.0–0.75) + plugins.validator.skip: False root: results workdir: null \ No newline at end of file diff --git a/autoafids/resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt new file mode 100644 index 0000000..ddd966c Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-02_epoch-964_mae-0.0007.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-02_epoch-964_mae-0.0007.ckpt new file mode 100644 index 0000000..3f0f884 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-02_epoch-964_mae-0.0007.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-03_epoch-778_mae-0.0008.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-03_epoch-778_mae-0.0008.ckpt new file mode 100644 index 0000000..4c88441 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-03_epoch-778_mae-0.0008.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-04_epoch-529_mae-0.0007.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-04_epoch-529_mae-0.0007.ckpt new file mode 100755 index 0000000..8fde4fc Binary files /dev/null and 
b/autoafids/resources/afids_cnn_ckpts/afid-04_epoch-529_mae-0.0007.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-05_epoch-628_mae-0.0007.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-05_epoch-628_mae-0.0007.ckpt new file mode 100755 index 0000000..f23264d Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-05_epoch-628_mae-0.0007.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-06_epoch-485_mae-0.0009.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-06_epoch-485_mae-0.0009.ckpt new file mode 100755 index 0000000..7df451b Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-06_epoch-485_mae-0.0009.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-07_epoch-599_mae-0.0008.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-07_epoch-599_mae-0.0008.ckpt new file mode 100755 index 0000000..50345d7 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-07_epoch-599_mae-0.0008.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-08_epoch-455_mae-0.0013.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-08_epoch-455_mae-0.0013.ckpt new file mode 100755 index 0000000..99e75c1 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-08_epoch-455_mae-0.0013.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-09_epoch-535_mae-0.0010.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-09_epoch-535_mae-0.0010.ckpt new file mode 100755 index 0000000..116f2a8 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-09_epoch-535_mae-0.0010.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-10_epoch-431_mae-0.0009.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-10_epoch-431_mae-0.0009.ckpt new file mode 100755 index 0000000..6df3231 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-10_epoch-431_mae-0.0009.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-11_epoch-735_mae-0.0005.ckpt 
b/autoafids/resources/afids_cnn_ckpts/afid-11_epoch-735_mae-0.0005.ckpt new file mode 100755 index 0000000..83d3e75 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-11_epoch-735_mae-0.0005.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-12_epoch-694_mae-0.0007.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-12_epoch-694_mae-0.0007.ckpt new file mode 100755 index 0000000..6cf7e0e Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-12_epoch-694_mae-0.0007.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-13_epoch-763_mae-0.0006.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-13_epoch-763_mae-0.0006.ckpt new file mode 100755 index 0000000..1a04256 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-13_epoch-763_mae-0.0006.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-14_epoch-405_mae-0.0011.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-14_epoch-405_mae-0.0011.ckpt new file mode 100755 index 0000000..6da8a90 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-14_epoch-405_mae-0.0011.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-15_epoch-371_mae-0.0013.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-15_epoch-371_mae-0.0013.ckpt new file mode 100755 index 0000000..5ea0d3c Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-15_epoch-371_mae-0.0013.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-16_epoch-454_mae-0.0012.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-16_epoch-454_mae-0.0012.ckpt new file mode 100755 index 0000000..9383ad6 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-16_epoch-454_mae-0.0012.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-17_epoch-291_mae-0.0023.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-17_epoch-291_mae-0.0023.ckpt new file mode 100755 index 0000000..b30b6c6 Binary files /dev/null and 
b/autoafids/resources/afids_cnn_ckpts/afid-17_epoch-291_mae-0.0023.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-18_epoch-233_mae-0.0016.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-18_epoch-233_mae-0.0016.ckpt new file mode 100755 index 0000000..d4e586a Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-18_epoch-233_mae-0.0016.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-19_epoch-626_mae-0.0011.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-19_epoch-626_mae-0.0011.ckpt new file mode 100755 index 0000000..8f612b4 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-19_epoch-626_mae-0.0011.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-20_epoch-720_mae-0.0011.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-20_epoch-720_mae-0.0011.ckpt new file mode 100755 index 0000000..99c20d5 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-20_epoch-720_mae-0.0011.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-21_epoch-672_mae-0.0009.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-21_epoch-672_mae-0.0009.ckpt new file mode 100755 index 0000000..8106543 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-21_epoch-672_mae-0.0009.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-22_epoch-647_mae-0.0008.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-22_epoch-647_mae-0.0008.ckpt new file mode 100755 index 0000000..094835b Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-22_epoch-647_mae-0.0008.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-23_epoch-656_mae-0.0010.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-23_epoch-656_mae-0.0010.ckpt new file mode 100755 index 0000000..6a639f2 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-23_epoch-656_mae-0.0010.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-24_epoch-689_mae-0.0007.ckpt 
b/autoafids/resources/afids_cnn_ckpts/afid-24_epoch-689_mae-0.0007.ckpt new file mode 100755 index 0000000..88b5491 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-24_epoch-689_mae-0.0007.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-25_epoch-513_mae-0.0013.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-25_epoch-513_mae-0.0013.ckpt new file mode 100755 index 0000000..6d47c82 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-25_epoch-513_mae-0.0013.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-26_epoch-784_mae-0.0009.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-26_epoch-784_mae-0.0009.ckpt new file mode 100755 index 0000000..417ca96 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-26_epoch-784_mae-0.0009.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-27_epoch-300_mae-0.0014.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-27_epoch-300_mae-0.0014.ckpt new file mode 100755 index 0000000..a0e2cf7 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-27_epoch-300_mae-0.0014.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-28_epoch-351_mae-0.0009.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-28_epoch-351_mae-0.0009.ckpt new file mode 100755 index 0000000..b10247c Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-28_epoch-351_mae-0.0009.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-29_epoch-394_mae-0.0014.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-29_epoch-394_mae-0.0014.ckpt new file mode 100755 index 0000000..4d147fd Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-29_epoch-394_mae-0.0014.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-30_epoch-618_mae-0.0012.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-30_epoch-618_mae-0.0012.ckpt new file mode 100755 index 0000000..4c7124a Binary files /dev/null and 
b/autoafids/resources/afids_cnn_ckpts/afid-30_epoch-618_mae-0.0012.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-31_epoch-573_mae-0.0009.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-31_epoch-573_mae-0.0009.ckpt new file mode 100755 index 0000000..531d1be Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-31_epoch-573_mae-0.0009.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-32_epoch-561_mae-0.0007.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-32_epoch-561_mae-0.0007.ckpt new file mode 100755 index 0000000..2e7b577 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-32_epoch-561_mae-0.0007.ckpt differ diff --git a/autoafids/workflow/Snakefile b/autoafids/workflow/Snakefile index 644c03f..164bf1d 100644 --- a/autoafids/workflow/Snakefile +++ b/autoafids/workflow/Snakefile @@ -1,8 +1,16 @@ import snakebids -from snakebids import bids, generate_inputs, get_wildcard_constraints +from snakebids import bids, generate_inputs, get_wildcard_constraints, set_bids_spec from appdirs import AppDirs import warnings +# Suppress snakebids "unrecognized entity" warnings for the {afid} wildcard +warnings.filterwarnings( + "ignore", + message="Path generated with unrecognized entities", + category=UserWarning, +) +set_bids_spec("v0_0_0") + try: from autoafids.workflow.lib import ( utils as utils, # Works when run as a package @@ -345,7 +353,19 @@ rule mni2subfids: "scripts/tform_script.py" -include: "rules/cnn.smk" +# ── Detection mode: choose rules to include ────────────────────────────── +if config.get("detect_mode", "prior") == "nnlm": + + # nnLandmark mode: single-pass whole-volume detection + include: "rules/nnlm.smk" + +else: + # Legacy CNN modes (prior / noprior) + # If running on GPU, automatically enable sequential inference to avoid spawning 32 parallel jobs + if config.get("afids_inference", {}).get("device") == "cuda:0": + config["enable_sequential_inference"] = True + + include: 
"rules/cnn.smk" if config["fidqc"]: @@ -360,13 +380,24 @@ if config["LEAD_DBS_DIR"] or config["FMRIPREP_DIR"]: include: "rules/regqc.smk" +# Select the correct FCSV output descriptor based on inference mode. +# prior → applyfidmodel_gather → desc="afidscnn" +# noprior → applyfidmodel_noprior_gather → desc="afidscnn-noprior" +# nnlm → nnlm_to_fcsv → desc="afidscnn-nnlm" +_fcsv_desc = { + "prior": "afidscnn", + "noprior": "afidscnn-noprior", + "nnlm": "afidscnn-nnlm", +}.get(config.get("detect_mode", "prior"), "afidscnn") + + rule all: input: models=inputs[config["modality"]].expand( bids( root=root, datatype="afids-cnn", - desc="afidscnn", + desc=_fcsv_desc, suffix="afids.fcsv", **inputs[config["modality"]].wildcards, ), diff --git a/autoafids/workflow/envs/nnlm.yaml b/autoafids/workflow/envs/nnlm.yaml new file mode 100644 index 0000000..f8fd8c2 --- /dev/null +++ b/autoafids/workflow/envs/nnlm.yaml @@ -0,0 +1,15 @@ +--- +name: nnlm +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - python=3.12 + - pytorch=2.5.1 + - torchvision=0.20.1 + - pytorch-cuda=12.4 + - pip + - pip: + - nnlandmark @ git+https://github.com/MIC-DKFZ/nnLandmark.git diff --git a/autoafids/workflow/envs/pytorch.yaml b/autoafids/workflow/envs/pytorch.yaml new file mode 100644 index 0000000..0ce4183 --- /dev/null +++ b/autoafids/workflow/envs/pytorch.yaml @@ -0,0 +1,13 @@ +--- +name: pytorch +channels: + - pytorch + - nvidia + - conda-forge +dependencies: + # Match the known-fast training/inference stack (e.g. 
torch==2.7.1+cu126) + - python=3.10 + - pytorch=2.7.1 + - nibabel # nib.nifti1.load + - numpy # np + - pandas # pd (fcsv loading) diff --git a/autoafids/workflow/rules/cnn.smk b/autoafids/workflow/rules/cnn.smk index 884f011..635e73e 100644 --- a/autoafids/workflow/rules/cnn.smk +++ b/autoafids/workflow/rules/cnn.smk @@ -17,46 +17,353 @@ rule download_cnn_model: " rm model.zip" -rule applyfidmodel: - input: - t1w=lambda wildcards: ( +_AFIDS = [f"{i:02d}" for i in range(1, 33)] # ["01", "02", ..., "32"] + +if config.get("enable_sequential_inference", False): + + rule applyfidmodel_all: + """Run 5-patch inference for ALL 32 AFIDs using MNI-registered prior location""" + input: + t1w=lambda wildcards: ( + bids( + root=work, + datatype="normalize", + desc=chosen_norm_method, + suffix="T1w.nii.gz", + **inputs[config["modality"]].wildcards, + ) + if config["modality"] != "T1w" + else bids( + root=work, + datatype="resample", + desc=chosen_norm_method, + res=config["res"], + suffix="T1w.nii.gz", + **inputs[config["modality"]].wildcards, + ) + ), + prior=bids( + root=work, + datatype="registration", + space="native", + desc="MNI", + suffix="afids.fcsv", + **inputs[config["modality"]].wildcards, + ), + output: + coords=[ + bids( + root=work, + datatype="afids-cnn", + afid=afid, + suffix="coord.txt", + **inputs[config["modality"]].wildcards, + ) + for afid in _AFIDS + ], + probs=[ + bids( + root=work, + datatype="afids-cnn", + afid=afid, + suffix="probmap.nii.gz", + **inputs[config["modality"]].wildcards, + ) + for afid in _AFIDS + ], + fcsv=bids( + root=root, + datatype="afids-cnn", + desc="afidscnn", + suffix="afids.fcsv", + **inputs[config["modality"]].wildcards, + ), + log: bids( + root="logs", + suffix="apply_all_prior.log", + **inputs[config["modality"]].wildcards, + ), + params: + ckpts=lambda wildcards: { + key: str(Path(workflow.basedir).parent / path) + for key, path in config["afids_inference"]["checkpoints"].items() + }, + threads: 1 + conda: + 
"../envs/pytorch.yaml" + script: + "../scripts/apply_with_prior_all.py" + +else: + + # ── WITH PRIOR (SINGLE) ────────────────────────────────────────────────────── + rule applyfidmodel_single: + """Run 5-patch inference for ONE AFID using MNI-registered prior location.""" + input: + t1w=lambda wildcards: ( + bids( + root=work, + datatype="normalize", + desc=chosen_norm_method, + suffix="T1w.nii.gz", + **inputs[config["modality"]].wildcards, + ) + if config["modality"] != "T1w" + else bids( + root=work, + datatype="resample", + desc=chosen_norm_method, + res=config["res"], + suffix="T1w.nii.gz", + **inputs[config["modality"]].wildcards, + ) + ), + prior=bids( root=work, - datatype="normalize", - desc=chosen_norm_method, - suffix="T1w.nii.gz", + datatype="registration", + space="native", + desc="MNI", + suffix="afids.fcsv", **inputs[config["modality"]].wildcards, - ) - if config["modality"] != "T1w" - else bids( + ), + output: + coord=bids( root=work, - datatype="resample", - desc=chosen_norm_method, - res=config["res"], - suffix="T1w.nii.gz", + datatype="afids-cnn", + afid="{afid}", + suffix="coord.txt", **inputs[config["modality"]].wildcards, - ) - ), - prior=bids( - root=work, - datatype="registration", - space="native", - desc="MNI", - suffix="afids.fcsv", - **inputs[config["modality"]].wildcards, - ), - model_dir=Path(download_dir) / "models", - output: - fcsv=bids( - root=root, - datatype="afids-cnn", - desc="afidscnn", - suffix="afids.fcsv", - **inputs[config["modality"]].wildcards, - ), - log: - bids(root="logs", suffix="landmark.log", **inputs[config["modality"]].wildcards), - conda: - "../envs/tensorflow.yaml" - script: - "../scripts/apply.py" + ), + prob=bids( + root=work, + datatype="afids-cnn", + afid="{afid}", + suffix="probmap.nii.gz", + **inputs[config["modality"]].wildcards, + ), + log: + bids( + root="logs", + afid="{afid}", + suffix="landmark.log", + **inputs[config["modality"]].wildcards, + ), + wildcard_constraints: + afid=r"\d{2}", + params: + 
ckpt_path=lambda wildcards: str( + Path(workflow.basedir).parent + / config["afids_inference"]["checkpoints"][ + f"afid_{int(wildcards.afid):02d}" + ] + ), + threads: 1 + conda: + "../envs/pytorch.yaml" + script: + "../scripts/apply_with_prior_single.py" + + rule applyfidmodel_gather: + """Collect all 32 per-AFID coord files and write combined FCSV.""" + input: + coords=lambda wildcards: expand( + bids( + root=work, + datatype="afids-cnn", + afid="{afid}", + suffix="coord.txt", + **{ + k: getattr(wildcards, k) + for k in inputs[config["modality"]].wildcards + }, + ), + afid=_AFIDS, + ), + output: + fcsv=bids( + root=root, + datatype="afids-cnn", + desc="afidscnn", + suffix="afids.fcsv", + **inputs[config["modality"]].wildcards, + ), + log: + bids( + root="logs", + suffix="gather.log", + **inputs[config["modality"]].wildcards, + ), + threads: 1 + conda: + "../envs/pytorch.yaml" + script: + "../scripts/gather_afids.py" + + +# --------------------------------------------------------------------------- + +if config.get("enable_sequential_inference", False): + + rule applyfidmodel_noprior_all: + """Whole-volume sliding-window inference for ALL 32 AFIDs (no prior needed)""" + input: + t1w=lambda wildcards: ( + bids( + root=work, + datatype="normalize", + desc=chosen_norm_method, + suffix="T1w.nii.gz", + **inputs[config["modality"]].wildcards, + ) + if config["modality"] != "T1w" + else bids( + root=work, + datatype="resample", + desc=chosen_norm_method, + res=config["res"], + suffix="T1w.nii.gz", + **inputs[config["modality"]].wildcards, + ) + ), + output: + coords=[ + bids( + root=work, + datatype="afids-cnn-noprior", + afid=afid, + suffix="coord.txt", + **inputs[config["modality"]].wildcards, + ) + for afid in _AFIDS + ], + probs=[ + bids( + root=work, + datatype="afids-cnn-noprior", + afid=afid, + suffix="probmap.nii.gz", + **inputs[config["modality"]].wildcards, + ) + for afid in _AFIDS + ], + fcsv=bids( + root=root, + datatype="afids-cnn", + desc="afidscnn-noprior", 
+ suffix="afids.fcsv", + **inputs[config["modality"]].wildcards, + ), + log: + bids( + root="logs", + suffix="noprior-apply_all.log", + **inputs[config["modality"]].wildcards, + ), + params: + ckpts=lambda wildcards: { + key: str(Path(workflow.basedir).parent / path) + for key, path in config["afids_inference"]["checkpoints"].items() + }, + threads: 1 + conda: + "../envs/pytorch.yaml" + script: + "../scripts/apply_noprior_all.py" + +else: + + # ── WITHOUT PRIOR (SINGLE) ────────────────────────────────────────────────── + rule applyfidmodel_noprior_single: + """Whole-volume sliding-window inference for ONE AFID — no prior needed.""" + input: + t1w=lambda wildcards: ( + bids( + root=work, + datatype="normalize", + desc=chosen_norm_method, + suffix="T1w.nii.gz", + **inputs[config["modality"]].wildcards, + ) + if config["modality"] != "T1w" + else bids( + root=work, + datatype="resample", + desc=chosen_norm_method, + res=config["res"], + suffix="T1w.nii.gz", + **inputs[config["modality"]].wildcards, + ) + ), + output: + coord=bids( + root=work, + datatype="afids-cnn-noprior", + afid="{afid}", + suffix="coord.txt", + **inputs[config["modality"]].wildcards, + ), + prob=bids( + root=work, + datatype="afids-cnn-noprior", + afid="{afid}", + suffix="probmap.nii.gz", + **inputs[config["modality"]].wildcards, + ), + log: + bids( + root="logs", + afid="{afid}", + suffix="noprior-landmark.log", + **inputs[config["modality"]].wildcards, + ), + wildcard_constraints: + afid=r"\d{2}", + params: + ckpt_path=lambda wildcards: str( + Path(workflow.basedir).parent + / config["afids_inference"]["checkpoints"][ + f"afid_{int(wildcards.afid):02d}" + ] + ), + threads: 1 + conda: + "../envs/pytorch.yaml" + script: + "../scripts/apply_noprior_single.py" + + rule applyfidmodel_noprior_gather: + """Collect all 32 no-prior coord files and write the combined FCSV.""" + input: + coords=lambda wildcards: expand( + bids( + root=work, + datatype="afids-cnn-noprior", + afid="{afid}", + 
suffix="coord.txt", + **{ + k: getattr(wildcards, k) + for k in inputs[config["modality"]].wildcards + }, + ), + afid=_AFIDS, + ), + output: + fcsv=bids( + root=root, + datatype="afids-cnn", + desc="afidscnn-noprior", + suffix="afids.fcsv", + **inputs[config["modality"]].wildcards, + ), + log: + bids( + root="logs", + suffix="noprior-gather.log", + **inputs[config["modality"]].wildcards, + ), + threads: 1 + conda: + "../envs/pytorch.yaml" + script: + "../scripts/gather_afids.py" diff --git a/autoafids/workflow/rules/nnlm.smk b/autoafids/workflow/rules/nnlm.smk new file mode 100644 index 0000000..fdbddff --- /dev/null +++ b/autoafids/workflow/rules/nnlm.smk @@ -0,0 +1,154 @@ +# nnlm.smk ───────────────────────────────────────────────────────────────── +# Rules for nnLandmark (nnLM) whole-volume AFID detection. +# Activated when `--detect_with_nnlm` is passed on the CLI. +# +# Rule graph: +# download_nnlm_model → run_nnlm → nnlm_to_fcsv → rule all +# ───────────────────────────────────────────────────────────────────────────── + +NNLM_MODEL_DIR = Path(download_dir) / "models" / "nnlm" + + +# ── 0. 
Download trained model ──────────────────────────────────────────────── +rule download_nnlm_model: + """Download the pre-trained nnLM model zip and extract it to the cache dir.""" + params: + url=config["resource_urls"].get("nnlm", ""), + output: + model_dir=directory(NNLM_MODEL_DIR), + log: + bids( + root="logs", + suffix="download_nnlm_model.log", + ), + shell: + "mkdir -p {output.model_dir} && " + "wget -q {params.url} -O nnlm_model.zip > {log} 2>&1 && " + "unzip -q -d {output.model_dir} nnlm_model.zip >> {log} 2>&1 && " + "rm nnlm_model.zip" + + +# ── Helper: choose the right preprocessed T1w based on modality ────────────── +def _nnlm_t1w(wildcards): + if config["modality"] != "T1w": + return bids( + root=work, + datatype="normalize", + desc=chosen_norm_method, + suffix="T1w.nii.gz", + **inputs[config["modality"]].wildcards, + ) + return bids( + root=work, + datatype="resample", + desc=chosen_norm_method, + res=config["res"], + suffix="T1w.nii.gz", + **inputs[config["modality"]].wildcards, + ) + + +# ── 1. 
nnLM inference (all 32 AFIDs in one forward pass) ───────────────────── +rule run_nnlm: + """Run nnLandmark whole-volume inference — all 32 AFIDs in a single forward pass.""" + input: + t1w=_nnlm_t1w, + model_dir=NNLM_MODEL_DIR, + output: + seg=bids( + root=work, + datatype="afids-nnlm", + suffix="dseg.nii.gz", + **inputs[config["modality"]].wildcards, + ), + coords_json=bids( + root=work, + datatype="afids-nnlm", + suffix="coords.json", + **inputs[config["modality"]].wildcards, + ), + params: + dataset_id="800", + fold=config.get("nnlm_fold", "0"), + plans=config.get("nnlm_plans", "nnUNetResEncUNetMPlans"), + checkpoint=config.get("nnlm_checkpoint", "checkpoint_final.pth"), + device=config.get("nnlm_device", "cuda"), + tmpdir=lambda wildcards: f"/tmp/nnlm_{wildcards.subject}", + conda: + "../envs/nnlm.yaml" + threads: 4 + resources: + gpus=lambda wildcards: 1 if config.get("nnlm_device", "cuda") == "cuda" else 0, + mem_mb=16000, + log: + bids( + root="logs", + suffix="nnlm.log", + **inputs[config["modality"]].wildcards, + ), + shell: + """ + set -euo pipefail + + # Ensure log and output directories exist + mkdir -p $(dirname {log}) + mkdir -p $(dirname {output.seg}) + + TMPDIR={params.tmpdir} + mkdir -p $TMPDIR/input_dir $TMPDIR/output_dir + + # nnLM input naming convention: {{case}}_0000.nii.gz + cp {input.t1w} $TMPDIR/input_dir/{wildcards.subject}_0000.nii.gz + + # Point nnLM env vars to the downloaded model + export nnLM_raw=dummy + export nnLM_preprocessed=dummy + export nnLM_results={input.model_dir} + + python -m nnlandmark.inference.nnLandmark.predict_from_raw_data \ + -i $TMPDIR/input_dir \ + -o $TMPDIR/output_dir \ + -d {params.dataset_id} \ + -c 3d_fullres \ + -tr nnLandmark \ + -p {params.plans} \ + -f {params.fold} \ + -chk {params.checkpoint} \ + -device {params.device} \ + -npp 0 -nps 0 \ + > {log} 2>&1 + + # Collect outputs + cp $TMPDIR/output_dir/{wildcards.subject}.nii.gz {output.seg} + cp $TMPDIR/output_dir/{wildcards.subject}.json 
{output.coords_json} + + # Cleanup temp dir + rm -rf $TMPDIR + """ + + +# ── 2. Convert JSON voxel coords → RAS world coords → Slicer FCSV ──────────── +rule nnlm_to_fcsv: + """Convert nnLM per-subject JSON (voxel space) to a Slicer-compatible FCSV (RAS mm).""" + input: + coords_json=bids( + root=work, + datatype="afids-nnlm", + suffix="coords.json", + **inputs[config["modality"]].wildcards, + ), + t1w=_nnlm_t1w, + output: + fcsv=bids( + root=root, + datatype="afids-cnn", + desc="afidscnn-nnlm", + suffix="afids.fcsv", + **inputs[config["modality"]].wildcards, + ), + params: + fcsv_template=str(Path(workflow.basedir).parent / "resources" / "dummy.fcsv"), + conda: + "../envs/nibabel.yaml" + script: + "../scripts/nnlm_to_fcsv.py" diff --git a/autoafids/workflow/scripts/apply_noprior_all.py b/autoafids/workflow/scripts/apply_noprior_all.py new file mode 100644 index 0000000..185ea17 --- /dev/null +++ b/autoafids/workflow/scripts/apply_noprior_all.py @@ -0,0 +1,444 @@ +# ruff: noqa +""" +apply_noprior_all.py +==================== + +Snakemake script: whole-volume sliding-window PyTorch inference for ALL 32 AFIDs. +Does NOT require a prior FCSV — the model scans the entire image. + +Called by the ``applyfidmodel_noprior_all`` rule (sequential inference mode). + +Snakemake I/O +------------- + input: + t1w – preprocessed NIfTI (.nii.gz) [no prior needed] + output: + coords – one plain-text coordinate file per AFID + probs – one full-volume probability map (.nii.gz) per AFID + fcsv – combined FCSV holding all AFIDs + params: + ckpts – checkpoint path for each configured AFID key +""" + +import time +import warnings +from pathlib import Path +from typing import List, Optional, Tuple + +import nibabel as nib +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import os + +# Each Snakemake job = 1 process. 
def fid_voxel2world(fid_voxel: NDArray, nii_affine: NDArray) -> NDArray:
    """Map voxel indices to world coordinates with a NIfTI affine.

    Parameters
    ----------
    fid_voxel
        A single voxel index of shape ``(3,)`` or a batch of shape ``(N, 3)``;
        any array-like is accepted.
    nii_affine
        The 4x4 voxel-to-world affine. Only its upper 3x3 block and the first
        three entries of its last column are used.

    Returns
    -------
    NDArray
        Float array with the same shape as ``fid_voxel``.

    Raises
    ------
    ValueError
        If the last axis of ``fid_voxel`` does not have length 3.
    """
    vox = np.asarray(fid_voxel, dtype=float)
    affine = np.asarray(nii_affine, dtype=float)
    if vox.shape[-1:] != (3,):
        raise ValueError(
            f"fid_voxel must have shape (3,) or (N, 3), got {vox.shape}"
        )
    # Row-vector form (v @ R.T + t) equals R @ v + t for one point and
    # broadcasts over a leading batch axis.
    return vox @ affine[:3, :3].T + affine[:3, 3]
def _load_model(
    ckpt_path: str, features: List[int], device: torch.device
) -> nnUNet_VanillaUNet:
    """Build the single-channel U-Net and load its weights from a checkpoint.

    Parameters
    ----------
    ckpt_path
        Path to a checkpoint readable by ``torch.load``. It is either a bare
        state dict or a mapping that stores one under ``"state_dict"``.
    features
        Channel widths passed to ``nnUNet_VanillaUNet``.
    device
        Device the model is moved to before it is returned.

    Returns
    -------
    nnUNet_VanillaUNet
        The model in eval mode on ``device``.

    Raises
    ------
    RuntimeError
        If none of the model's parameters are present in the checkpoint
        (for example an unexpected key prefix). Without this check the network
        would silently run with randomly initialised weights.
    """
    model = nnUNet_VanillaUNet(in_channels=1, out_channels=1, features=features)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    raw_sd = ckpt.get("state_dict", ckpt)

    # Strip a single leading "model." so keys line up with this module's names.
    prefix = "model."
    cleaned_sd = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in raw_sd.items()
    }

    # strict=False tolerates extra keys stored by the training wrapper, but the
    # report must still be inspected: missing keys mean untrained parameters.
    report = model.load_state_dict(cleaned_sd, strict=False)
    missing = list(report.missing_keys)
    if missing:
        n_expected = len(model.state_dict())
        if len(missing) == n_expected:
            raise RuntimeError(
                f"Checkpoint {ckpt_path} contains none of the {n_expected} "
                "expected model parameters; check the key prefix/architecture."
            )
        # Module-level warnings are filtered out in this script, so print.
        print(
            f"  [WARN] {len(missing)}/{n_expected} model parameters missing in "
            f"{ckpt_path} (left at random init): {missing[:5]}",
            flush=True,
        )
    return model.to(device).eval()
+ ps, H) + xe = min(x + ps, W) + zs = max(0, ze - ps) + ys = max(0, ye - ps) + xs = max(0, xe - ps) + patch_coords.append((slice(zs, ze), slice(ys, ye), slice(xs, xe))) + + num_patches = len(patch_coords) + + # ---- Extract & normalise all patches ---- + model_inputs = [] + for slcs in patch_coords: + patch = image_tensor[slcs[0], slcs[1], slcs[2]] + pz = max(0, ps - patch.shape[0]) + py = max(0, ps - patch.shape[1]) + px_p = max(0, ps - patch.shape[2]) + if pz or py or px_p: + patch = F.pad(patch, (0, px_p, 0, py, 0, pz)) + lo, hi = patch.min(), patch.max() + patch = (patch - lo) / (hi - lo + 1e-8) if hi > lo else patch + model_inputs.append(patch.unsqueeze(0)) # [1,D,H,W] + t_patch = time.perf_counter() - t0 + + # ---- Inference ---- + t0 = time.perf_counter() + predictions = [] + for i in range(0, len(model_inputs), batch_size): + chunk = torch.stack(model_inputs[i : i + batch_size]).to(device) + with torch.inference_mode(): + preds = model(chunk) + predictions.extend(p.cpu() for p in preds) + del chunk, preds + t_infer = time.perf_counter() - t0 + + # ---- Gaussian-weighted reconstruction ---- + t0 = time.perf_counter() + out = torch.zeros(vol_shape, dtype=torch.float32) + cnt = torch.zeros(vol_shape, dtype=torch.float32) + for pred, (zs, ys, xs) in zip(predictions, patch_coords): + az = zs.stop - zs.start + ay = ys.stop - ys.start + ax = xs.stop - xs.start + # Crop Gaussian if patch was at a boundary (< full patch_size) + if (az, ay, ax) == (ps, ps, ps): + w = gmap + else: + zo = max(0, (ps - az) // 2) + yo = max(0, (ps - ay) // 2) + xo = max(0, (ps - ax) // 2) + w = gmap[zo : zo + az, yo : yo + ay, xo : xo + ax] + pc = pred[0, :az, :ay, :ax] + out[zs, ys, xs] += pc * w + cnt[zs, ys, xs] += w + out /= cnt.clamp(min=1e-8) + t_recon = time.perf_counter() - t0 + + # ---- Peak → world ---- + idx = np.unravel_index(np.argmax(out.numpy()), vol_shape) + pred_vox = tuple(int(c) for c in idx) + pred_world = fid_voxel2world(np.array(pred_vox, dtype=float), affine) + + 
t_total = t_load + t_patch + t_infer + t_recon + print( + f" [noprior] AFID {fid:02d} patches={num_patches} " + f"load={t_load:.2f}s patch={t_patch:.2f}s " + f"infer={t_infer:.2f}s recon={t_recon:.2f}s total={t_total:.2f}s | " + f"pred_vox={pred_vox} " + f"world=[{pred_world[0]:.1f}, {pred_world[1]:.1f}, {pred_world[2]:.1f}]", + flush=True, + ) + return pred_world, out # out = full-volume probability map [D,H,W] + + +# =================================================================== +# Snakemake entry point (All 32 AFIDs) +# =================================================================== + +import os +from pathlib import Path +import csv + +AFIDS_FIELDNAMES = [ + "id", + "x", + "y", + "z", + "ow", + "ox", + "oy", + "oz", + "vis", + "sel", + "lock", + "label", + "desc", + "associatedNodeID", +] + +AFID_DESCRIPTIONS = [ + "AC", + "PC", + "Infracollicular Sulcus", + "PMJ", + "Superior IPF", + "Right Superior LMS", + "Left Superior LMS", + "Right Inferior LMS", + "Left Inferior LMS", + "Culmen", + "Intermammillary Sulcus", + "Right Mammilary Body", + "Left Mammilary Body", + "Pineal Gland", + "Right LV at AC", + "Left LV at AC", + "Right LV at PC", + "Left LV at PC", + "Genu of CC", + "Splenium of CC", + "Right AL Temporal Horn", + "Left AL Tempral Horn", + "R. Sup. AM Temporal Horn", + "L. Sup. AM Temporal Horn", + "R Inf. AM Temp Horn", + "L Inf. 
# -------------------------------------------------------------------
# Snakemake entry point: run whole-volume inference for AFIDs 1..32,
# write one coord file and one probability map per AFID, then the
# combined FCSV.
# -------------------------------------------------------------------
cfg_block = snakemake.config.get("afids_inference", {})
if not cfg_block:
    raise ValueError("Missing 'afids_inference' block in snakebids.yml.")

checkpoints = snakemake.params.ckpts
img = nib.nifti1.load(snakemake.input.t1w)

# CLI flags override the afids_inference block, which overrides the defaults.
overlap = (
    snakemake.config["inference_overlap"]
    if snakemake.config.get("inference_overlap") is not None
    else cfg_block.get("overlap", 0.5)
)
batch_size = (
    snakemake.config.get("inference_batch_size")
    or cfg_block.get("batch_size")
    or 7
)

# Store probabilities as float32 regardless of the input image's on-disk
# dtype; reusing the input header unchanged would keep that dtype.
prob_header = img.header.copy()
prob_header.set_data_dtype(np.float32)


def _write_afid_outputs(coord_path, prob_path, world_xyz, prob_volume):
    """Write the 'x y z' text file and the probability NIfTI for one AFID."""
    Path(coord_path).parent.mkdir(parents=True, exist_ok=True)
    with open(coord_path, "w") as f:
        f.write(
            f"{float(world_xyz[0])} {float(world_xyz[1])} {float(world_xyz[2])}\n"
        )
    Path(prob_path).parent.mkdir(parents=True, exist_ok=True)
    nib.save(nib.Nifti1Image(prob_volume, img.affine, prob_header), prob_path)


combined_coords = {}

for fid in range(1, 33):
    key = f"afid_{fid:02d}"
    ckpt_path = checkpoints.get(key)
    out_coord = snakemake.output.coords[fid - 1]
    out_prob = snakemake.output.probs[fid - 1]

    if not ckpt_path or not Path(ckpt_path).exists():
        print(f"  [SKIP] AFID {fid:02d} checkpoint not found: {ckpt_path}")
        # Placeholder results. Both declared outputs must still be created,
        # otherwise Snakemake reports the probability map as missing.
        pred_world = np.array([0.0, 0.0, 0.0])
        prob_volume = np.zeros(img.shape, dtype=np.float32)
    else:
        pred_world, prob_map = infer_noprior_single_afid(
            fid=fid,
            ckpt_path=ckpt_path,
            img=img,
            overlap=overlap,
            patch_size=cfg_block.get("patch_size", 64),
            batch_size=batch_size,
            device_str=cfg_block.get("device", "cpu"),
        )
        prob_volume = prob_map.numpy()

    combined_coords[fid] = pred_world
    _write_afid_outputs(out_coord, out_prob, pred_world, prob_volume)

write_combined_fcsv(combined_coords, snakemake.output.fcsv)
class ConvBlock(nn.Module):
    """Two consecutive Conv3d -> InstanceNorm3d -> LeakyReLU stages.

    The layers live in ``self.conv_block`` (an ``nn.Sequential``) in the order
    conv, norm, activation, conv, norm, activation, so parameter names are
    ``conv_block.0.*``, ``conv_block.1.*``, ``conv_block.3.*`` and
    ``conv_block.4.*``.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        layers = []
        # The first stage maps in -> out channels, the second out -> out.
        for stage_in in (in_channels, out_channels):
            layers.append(
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1, bias=True)
            )
            layers.append(nn.InstanceNorm3d(out_channels, affine=True))
            layers.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        activated = self.conv_block(x)
        return activated
class nnUNet(nn.Module):
    """Encoder / bottleneck / decoder 3D U-Net with one 1x1x1 head per level.

    Attribute names (``encoder_blocks``, ``bottleneck``, ``decoder_blocks``,
    ``deep_supervision_heads``) determine the state-dict keys and must not be
    renamed.
    """

    def __init__(
        self, in_channels: int, out_channels: int, features: Optional[List[int]] = None
    ) -> None:
        super().__init__()
        widths = [32, 64, 128, 256, 320] if features is None else features
        encoder_widths = widths[:-1]

        # Each encoder stage consumes the previous stage's width.
        self.encoder_blocks = nn.ModuleList()
        stage_in = in_channels
        for stage_out in encoder_widths:
            self.encoder_blocks.append(EncoderBlock(stage_in, stage_out))
            stage_in = stage_out

        self.bottleneck = ConvBlock(widths[-2], widths[-1])

        # Decoder stages walk the width list from deepest to shallowest.
        descending = list(reversed(widths))
        self.decoder_blocks = nn.ModuleList(
            [
                DecoderBlock(wide, narrow)
                for wide, narrow in zip(descending, descending[1:])
            ]
        )

        self.deep_supervision_heads = nn.ModuleList(
            [
                nn.Conv3d(width, out_channels, kernel_size=1)
                for width in reversed(encoder_widths)
            ]
        )

    def forward(self, x):
        """Return the output of the last (full-resolution) head only."""
        skip_stack = []
        for encoder in self.encoder_blocks:
            skip, x = encoder(x)
            skip_stack.append(skip)
        x = self.bottleneck(x)
        # The deepest skip connection pairs with the first decoder stage.
        for decoder, skip in zip(self.decoder_blocks, reversed(skip_stack)):
            x = decoder(x, skip)
        final_head = self.deep_supervision_heads[-1]
        return final_head(x)
def infer_noprior_single_afid(
    fid: int,
    ckpt_path: str,
    img: "nib.nifti1.Nifti1Image",
    patch_size: int = 64,
    batch_size: int = 7,
    overlap: float = 0.5,
    device_str: str = "cpu",
    features: Optional[List[int]] = None,
) -> Tuple[NDArray, torch.Tensor]:
    """Locate one AFID by sliding-window inference over the whole volume.

    Cubic patches of side ``patch_size`` are taken on a regular grid, each is
    min-max normalised and passed through the model, and the predictions are
    blended into a full-volume map with Gaussian weights. The voxel with the
    highest blended value is converted to world coordinates with the image
    affine.

    Parameters
    ----------
    fid
        AFID number; used only in the progress line that is printed.
    ckpt_path
        Checkpoint handed to ``_load_model``.
    img
        Image providing ``get_fdata()`` (3D) and ``affine``.
    patch_size
        Side length of the cubic patches.
    batch_size
        Number of patches per forward pass.
    overlap
        Fraction of a patch shared by neighbouring windows, ``0 <= overlap < 1``.
    device_str
        Requested torch device; any non-"cpu" value falls back to CPU when
        CUDA is unavailable.
    features
        Channel widths of the network; defaults to ``[16, 32, 64]``.

    Returns
    -------
    pred_world : NDArray
        World coordinates of the peak.
    prob_map : torch.Tensor
        Blended full-volume map with the shape of the image.

    Raises
    ------
    ValueError
        If ``patch_size`` or ``batch_size`` is below 1, ``overlap`` is outside
        ``[0, 1)``, or the image is not 3D.
    """
    if features is None:
        features = [16, 32, 64]
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # overlap >= 1 would give a 1-voxel stride (one patch per voxel);
    # overlap < 0 would leave gaps between windows.
    if not 0.0 <= overlap < 1.0:
        raise ValueError(f"overlap must satisfy 0 <= overlap < 1, got {overlap}")

    device = torch.device(
        device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu"
    )

    # ---- Image ----
    image_np = img.get_fdata().astype(np.float32)
    if image_np.ndim != 3:
        raise ValueError(f"Expected a 3D image, got shape {image_np.shape}")
    finite = np.isfinite(image_np)
    if not finite.all():
        # Replace NaN with 0 and +/-inf with the finite intensity range.
        lo = image_np[finite].min() if finite.any() else 0.0
        hi = image_np[finite].max() if finite.any() else 1.0
        image_np = np.nan_to_num(image_np, nan=0.0, posinf=hi, neginf=lo)
    image_tensor = torch.from_numpy(image_np)
    affine = img.affine
    vol_shape = image_tensor.shape
    D, H, W = vol_shape

    # ---- Load model ----
    t0 = time.perf_counter()
    model = _load_model(ckpt_path, features, device)
    t_load = time.perf_counter() - t0

    # ---- Gaussian map ----
    ps = patch_size

    def _g1d(n):
        s = n * 0.125
        c = n // 2
        x = torch.arange(n, dtype=torch.float32)
        return torch.exp(-((x - c) ** 2) / (2 * s**2))

    gmap = torch.einsum("z,y,x->zyx", _g1d(ps), _g1d(ps), _g1d(ps))
    gmap = gmap / gmap.max()

    # ---- Sliding-window patch coordinates ----
    t0 = time.perf_counter()
    step = max(1, int(ps * (1.0 - overlap)))

    def _axis_windows(length):
        # Regular grid plus one final window flush with the far edge. Each
        # start appears exactly once, so no patch is inferred (or weighted)
        # more than once.
        last = max(0, length - ps)
        starts = list(range(0, last, step))
        starts.append(last)
        return [slice(s, min(s + ps, length)) for s in starts]

    patch_coords = [
        (zs, ys, xs)
        for zs in _axis_windows(D)
        for ys in _axis_windows(H)
        for xs in _axis_windows(W)
    ]
    num_patches = len(patch_coords)

    # ---- Extract & normalise all patches ----
    model_inputs = []
    for zs, ys, xs in patch_coords:
        patch = image_tensor[zs, ys, xs]
        pz = ps - patch.shape[0]
        py = ps - patch.shape[1]
        px_p = ps - patch.shape[2]
        if pz or py or px_p:
            # Volume smaller than a patch along some axis: zero-pad the end.
            patch = F.pad(patch, (0, px_p, 0, py, 0, pz))
        lo, hi = patch.min(), patch.max()
        if hi > lo:
            patch = (patch - lo) / (hi - lo + 1e-8)
        model_inputs.append(patch.unsqueeze(0))  # [1,D,H,W]
    t_patch = time.perf_counter() - t0

    # ---- Inference ----
    t0 = time.perf_counter()
    predictions = []
    for i in range(0, len(model_inputs), batch_size):
        chunk = torch.stack(model_inputs[i : i + batch_size]).to(device)
        with torch.inference_mode():
            preds = model(chunk)
        predictions.extend(p.cpu() for p in preds)
        del chunk, preds
    t_infer = time.perf_counter() - t0

    # ---- Gaussian-weighted reconstruction ----
    t0 = time.perf_counter()
    out = torch.zeros(vol_shape, dtype=torch.float32)
    cnt = torch.zeros(vol_shape, dtype=torch.float32)
    for pred, (zs, ys, xs) in zip(predictions, patch_coords):
        az = zs.stop - zs.start
        ay = ys.stop - ys.start
        ax = xs.stop - xs.start
        # Centre-crop the Gaussian when the window is smaller than a patch
        # (offsets are 0 for full-size windows).
        zo = (ps - az) // 2
        yo = (ps - ay) // 2
        xo = (ps - ax) // 2
        w = gmap[zo : zo + az, yo : yo + ay, xo : xo + ax]
        out[zs, ys, xs] += pred[0, :az, :ay, :ax] * w
        cnt[zs, ys, xs] += w
    out /= cnt.clamp(min=1e-8)
    t_recon = time.perf_counter() - t0

    # ---- Peak -> world ----
    idx = np.unravel_index(np.argmax(out.numpy()), vol_shape)
    pred_vox = tuple(int(c) for c in idx)
    pred_world = fid_voxel2world(np.array(pred_vox, dtype=float), affine)

    t_total = t_load + t_patch + t_infer + t_recon
    print(
        f"  [noprior] AFID {fid:02d}  patches={num_patches}  "
        f"load={t_load:.2f}s  patch={t_patch:.2f}s  "
        f"infer={t_infer:.2f}s  recon={t_recon:.2f}s  total={t_total:.2f}s  |  "
        f"pred_vox={pred_vox}  "
        f"world=[{pred_world[0]:.1f}, {pred_world[1]:.1f}, {pred_world[2]:.1f}]",
        flush=True,
    )
    return pred_world, out  # out = full-volume probability map [D,H,W]
f.write(f"{pred_world[0]} {pred_world[1]} {pred_world[2]}\n") + +# Save probability map as NIfTI (same affine/header as input image) +Path(snakemake.output.prob).parent.mkdir(parents=True, exist_ok=True) # noqa: F821 +nib.save( # noqa: F821 + nib.Nifti1Image(prob_map.numpy(), img.affine, img.header), + snakemake.output.prob, # noqa: F821 +) diff --git a/autoafids/workflow/scripts/apply_with_prior_all.py b/autoafids/workflow/scripts/apply_with_prior_all.py new file mode 100644 index 0000000..c48dc9d --- /dev/null +++ b/autoafids/workflow/scripts/apply_with_prior_all.py @@ -0,0 +1,617 @@ +# ruff: noqa +""" +apply_with_prior_single.py +========================== + +Snakemake script: PyTorch inference for ONE AFID. + +Called by the ``applyfidmodel_single`` rule which has wildcard ``{afid}`` +(zero-padded two digits: 01–32). Snakemake runs ``--cores N`` instances of +this rule concurrently — that is the parallelism mechanism. No Python +multiprocessing is used here. + +Snakemake I/O +------------- + input: + t1w – preprocessed NIfTI (.nii.gz) + prior – per-subject MNI-registered FCSV with 32 prior locations + output: + coord – plain-text file: "x y z\\n" for this AFID + wildcards: + afid – zero-padded AFID number, e.g. 
def fid_world2voxel(fid_world: NDArray, nii_affine: NDArray) -> NDArray:
    """Convert a world-space point to the nearest integer voxel index.

    The affine is inverted, its linear part and translation are applied to
    ``fid_world``, and the result is rounded with ``np.rint`` before being
    cast to ``int``.
    """
    world_to_voxel = np.linalg.inv(nii_affine)
    linear_part = world_to_voxel[:3, :3]
    offset = world_to_voxel[:3, 3]
    continuous_index = linear_part.dot(fid_world) + offset
    return np.rint(continuous_index).astype(int)
class ConvBlock(nn.Module):
    """Two consecutive Conv3d -> InstanceNorm3d -> LeakyReLU stages."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=True),
            nn.InstanceNorm3d(out_channels, affine=True),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=True),
            nn.InstanceNorm3d(out_channels, affine=True),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
        )

    def forward(self, x):
        return self.conv_block(x)


class EncoderBlock(nn.Module):
    """ConvBlock followed by a stride-2 convolution.

    ``forward`` returns ``(skip, downsampled)``: the ConvBlock output kept for
    the skip connection and its stride-2 downsampled version.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv_block = ConvBlock(in_channels, out_channels)
        self.downsample = nn.Conv3d(
            out_channels, out_channels, kernel_size=3, stride=2, padding=1
        )

    def forward(self, x):
        skip = self.conv_block(x)
        return skip, self.downsample(skip)


class DecoderBlock(nn.Module):
    """Upsample, concatenate the skip connection, then apply a ConvBlock.

    The transposed convolution maps ``in_channels`` to ``out_channels`` and the
    skip tensor is expected to carry ``out_channels`` channels, so the
    ConvBlock receives ``2 * out_channels``. When ``in_channels`` equals
    ``2 * out_channels`` the layer shapes are exactly those of the previous
    ``in_channels // 2`` formulation, so existing checkpoints still load.
    Other width pairs (such as 320 -> 256 in the default feature list)
    previously failed in ``forward`` with a channel mismatch.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.upsample = nn.ConvTranspose3d(
            in_channels, out_channels, kernel_size=2, stride=2
        )
        self.conv_block = ConvBlock(2 * out_channels, out_channels)

    def forward(self, x, skip):
        # Channel axis is 1: [upsampled | skip].
        return self.conv_block(torch.cat([self.upsample(x), skip], 1))
def infer_single_afid(
    fid: int,
    ckpt_path: str,
    img: "nib.nifti1.Nifti1Image",
    prior_fcsv_path,
    patch_size: int = 64,
    batch_size: int = 5,
    device_str: str = "cpu",
    features: Optional[List[int]] = None,
) -> Tuple[NDArray, torch.Tensor]:
    """Run prior-guided 7-patch inference for a single AFID.

    Seven cubic patches are placed around the prior location read from the
    FCSV: one centred on it and one shifted by half a patch in each direction
    along every axis. Each patch is min-max normalised, passed through the
    model, and the predictions are blended with Gaussian weights. The voxel
    with the highest blended value is converted to world coordinates.

    Parameters
    ----------
    fid
        1-based AFID number; selects the row of the prior FCSV.
    ckpt_path
        Checkpoint handed to ``_load_model``.
    img
        Image providing ``get_fdata()`` (3D) and ``affine``.
    prior_fcsv_path
        FCSV file with the prior locations, read by ``load_fcsv``.
    patch_size
        Side length of the cubic patches.
    batch_size
        Patches per forward pass (with the default of 5 the seven patches are
        processed as two batches).
    device_str
        Requested torch device; any non-"cpu" value falls back to CPU when
        CUDA is unavailable.
    features
        Channel widths of the network; defaults to ``[16, 32, 64]``.

    Returns
    -------
    pred_world : NDArray
        Predicted world coords [x, y, z].
    prob_map : Tensor
        Full-volume Gaussian-weighted probability map [D,H,W].

    Raises
    ------
    ValueError
        If ``patch_size`` or ``batch_size`` is below 1 or the image is not 3D.
    """
    if features is None:
        features = [16, 32, 64]
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    device = torch.device(
        device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu"
    )

    # ---- Image ----
    image_np = img.get_fdata().astype(np.float32)
    if image_np.ndim != 3:
        raise ValueError(f"Expected a 3D image, got shape {image_np.shape}")
    finite = np.isfinite(image_np)
    if not finite.all():
        # Replace NaN with 0 and +/-inf with the finite intensity range.
        lo = image_np[finite].min() if finite.any() else 0.0
        hi = image_np[finite].max() if finite.any() else 1.0
        image_np = np.nan_to_num(image_np, nan=0.0, posinf=hi, neginf=lo)
    image_tensor = torch.from_numpy(image_np)
    affine = img.affine
    vol_shape = image_tensor.shape
    D, H, W = vol_shape

    # ---- Prior centre ----
    prior_df = load_fcsv(prior_fcsv_path)
    prior_world = get_fid(prior_df, fid)
    prior_vox = fid_world2voxel(prior_world, affine)
    cz = int(np.clip(prior_vox[0], 0, D - 1))
    cy = int(np.clip(prior_vox[1], 0, H - 1))
    cx = int(np.clip(prior_vox[2], 0, W - 1))

    # ---- Load model ----
    t0 = time.perf_counter()
    model = _load_model(ckpt_path, features, device)
    t_load = time.perf_counter() - t0

    # ---- Gaussian map ----
    ps = patch_size

    def _g1d(n):
        s = n * 0.125
        c = n // 2
        x = torch.arange(n, dtype=torch.float32)
        return torch.exp(-((x - c) ** 2) / (2 * s**2))

    gmap = torch.einsum("z,y,x->zyx", _g1d(ps), _g1d(ps), _g1d(ps))
    gmap = gmap / gmap.max()

    # ---- 7 patches: centre + ±x + ±y + ±z ----
    half = ps // 2
    offsets = [
        (0, 0, 0),  # centre
        (0, 0, -half),  # axis 2 −
        (0, 0, +half),  # axis 2 +
        (0, -half, 0),  # axis 1 −
        (0, +half, 0),  # axis 1 +
        (-half, 0, 0),  # axis 0 −
        (+half, 0, 0),  # axis 0 +
    ]

    def _window(centre, delta, length):
        # Keep the window inside the volume; the stop is clamped as well so a
        # volume thinner than one patch yields a shorter (later padded) window.
        start = max(0, min(centre + delta - half, length - ps))
        return slice(start, min(start + ps, length))

    patch_coords = [
        (_window(cz, dz, D), _window(cy, dy, H), _window(cx, dx, W))
        for dz, dy, dx in offsets
    ]

    # ---- Extract & normalise ----
    t0 = time.perf_counter()
    model_inputs = []
    for zs, ys, xs in patch_coords:
        patch = image_tensor[zs, ys, xs]
        pz = ps - patch.shape[0]
        py = ps - patch.shape[1]
        px_p = ps - patch.shape[2]
        if pz or py or px_p:
            patch = F.pad(patch, (0, px_p, 0, py, 0, pz))
        lo, hi = patch.min(), patch.max()
        if hi > lo:
            patch = (patch - lo) / (hi - lo + 1e-8)
        model_inputs.append(patch.unsqueeze(0))
    t_patch = time.perf_counter() - t0

    # ---- Inference ----
    t0 = time.perf_counter()
    cpu0 = time.process_time()
    predictions = []
    for i in range(0, len(model_inputs), batch_size):
        chunk = torch.stack(model_inputs[i : i + batch_size]).to(device)
        with torch.inference_mode():
            preds = model(chunk)
        predictions.extend(p.cpu() for p in preds)
        del chunk, preds
    t_infer = time.perf_counter() - t0
    cpu_infer = time.process_time() - cpu0

    # ---- Reconstruct ----
    t0 = time.perf_counter()
    out = torch.zeros(vol_shape, dtype=torch.float32)
    cnt = torch.zeros(vol_shape, dtype=torch.float32)
    for pred, (zs, ys, xs) in zip(predictions, patch_coords):
        az = zs.stop - zs.start
        ay = ys.stop - ys.start
        ax = xs.stop - xs.start
        # Crop prediction and weights to the window actually covered by the
        # volume. For full-size windows the offsets are 0 and w is gmap; for a
        # padded patch this avoids a shape mismatch on accumulation.
        zo = (ps - az) // 2
        yo = (ps - ay) // 2
        xo = (ps - ax) // 2
        w = gmap[zo : zo + az, yo : yo + ay, xo : xo + ax]
        out[zs, ys, xs] += pred[0, :az, :ay, :ax] * w
        cnt[zs, ys, xs] += w
    out /= cnt.clamp(min=1e-8)
    t_recon = time.perf_counter() - t0

    # ---- Peak -> world ----
    idx = np.unravel_index(np.argmax(out.numpy()), vol_shape)
    pred_vox = tuple(int(c) for c in idx)
    pred_world = fid_voxel2world(np.array(pred_vox, dtype=float), affine)

    t_total = t_load + t_patch + t_infer + t_recon
    print(
        f"  [OK] AFID {fid:02d}  "
        f"load={t_load:.2f}s  patch={t_patch:.2f}s  "
        f"infer={t_infer:.2f}s (cpu={cpu_infer:.2f}s)  "
        f"recon={t_recon:.2f}s  total={t_total:.2f}s  |  "
        f"pred_vox={pred_vox}  "
        f"world=[{pred_world[0]:.1f}, {pred_world[1]:.1f}, {pred_world[2]:.1f}]  "
        f"[dev={device}, threads={torch.get_num_threads()}]"
    )
    return pred_world, out  # out = full-volume probability map [D,H,W]
_infer_single_from_state( + fid: int, + image_tensor: torch.Tensor, + affine: NDArray, + prior_df: pd.DataFrame, + model: nn.Module, + gmap: torch.Tensor, + patch_size: int, + batch_size: int, + device: torch.device, +) -> Tuple[NDArray, torch.Tensor]: + """Core inference that reuses precomputed image tensor, prior, model, and gmap.""" + vol_shape = tuple(int(d) for d in image_tensor.shape) + D, H, W = vol_shape + + # ---- Prior centre ---- + prior_world = get_fid(prior_df, fid) + prior_vox = fid_world2voxel(prior_world, affine) + cz = int(np.clip(prior_vox[0], 0, D - 1)) + cy = int(np.clip(prior_vox[1], 0, H - 1)) + cx = int(np.clip(prior_vox[2], 0, W - 1)) + + # ---- 7 patches: centre + ±x + ±y + ±z ---- + ps = patch_size + half = ps // 2 + offsets = [ + (0, 0, 0), + (0, 0, -half), + (0, 0, +half), + (0, -half, 0), + (0, +half, 0), + (-half, 0, 0), + (+half, 0, 0), + ] + patch_coords = [] + for dz, dy, dx in offsets: + zs = max(0, min(cz + dz - half, D - ps)) + ys = max(0, min(cy + dy - half, H - ps)) + xs = max(0, min(cx + dx - half, W - ps)) + patch_coords.append( + (slice(zs, zs + ps), slice(ys, ys + ps), slice(xs, xs + ps)) + ) + + # ---- Extract & normalise ---- + t0 = time.perf_counter() + model_inputs = [] + for slcs in patch_coords: + patch = image_tensor[slcs[0], slcs[1], slcs[2]] + pz = max(0, ps - patch.shape[0]) + py = max(0, ps - patch.shape[1]) + px_p = max(0, ps - patch.shape[2]) + if pz or py or px_p: + patch = F.pad(patch, (0, px_p, 0, py, 0, pz)) + lo, hi = patch.min(), patch.max() + patch = (patch - lo) / (hi - lo + 1e-8) if hi > lo else patch + model_inputs.append(patch.unsqueeze(0)) + t_patch = time.perf_counter() - t0 + + # ---- Inference ---- + t0 = time.perf_counter() + cpu0 = time.process_time() + predictions = [] + for i in range(0, len(model_inputs), batch_size): + chunk = torch.stack(model_inputs[i : i + batch_size]).to(device) + with torch.inference_mode(): + preds = model(chunk) + predictions.extend(p.cpu() for p in preds) + del 
chunk, preds + t_infer = time.perf_counter() - t0 + cpu_infer = time.process_time() - cpu0 + + # ---- Reconstruct ---- + t0 = time.perf_counter() + out = torch.zeros(vol_shape, dtype=torch.float32) + cnt = torch.zeros(vol_shape, dtype=torch.float32) + for pred, (zs, ys, xs) in zip(predictions, patch_coords): # crop to the volume: a zero-padded patch is larger than its slice of out/cnt + out[zs, ys, xs] += (pred[0] * gmap)[: D - zs.start, : H - ys.start, : W - xs.start] + cnt[zs, ys, xs] += gmap[: D - zs.start, : H - ys.start, : W - xs.start] + out /= cnt.clamp(min=1e-8) + t_recon = time.perf_counter() - t0 + + # ---- Peak → world ---- + idx = np.unravel_index(np.argmax(out.numpy()), vol_shape) + pred_vox = tuple(int(c) for c in idx) + pred_world = fid_voxel2world(np.array(pred_vox, dtype=float), affine) + + t_total = t_patch + t_infer + t_recon + print( + f" [OK] AFID {fid:02d} " + f"patch={t_patch:.2f}s " + f"infer={t_infer:.2f}s (cpu={cpu_infer:.2f}s) " + f"recon={t_recon:.2f}s total={t_total:.2f}s | " + f"pred_vox={pred_vox} " + f"world=[{pred_world[0]:.1f}, {pred_world[1]:.1f}, {pred_world[2]:.1f}] " + f"[dev={device}, threads={torch.get_num_threads()}]" + ) + return pred_world, out + + +# =================================================================== +# Snakemake entry point (All 32 AFIDs) +# =================================================================== + +import os +from pathlib import Path +import csv + +AFIDS_FIELDNAMES = [ + "id", + "x", + "y", + "z", + "ow", + "ox", + "oy", + "oz", + "vis", + "sel", + "lock", + "label", + "desc", + "associatedNodeID", +] + +AFID_DESCRIPTIONS = [ + "AC", + "PC", + "Infracollicular Sulcus", + "PMJ", + "Superior IPF", + "Right Superior LMS", + "Left Superior LMS", + "Right Inferior LMS", + "Left Inferior LMS", + "Culmen", + "Intermammillary Sulcus", + "Right Mammilary Body", + "Left Mammilary Body", + "Pineal Gland", + "Right LV at AC", + "Left LV at AC", + "Right LV at PC", + "Left LV at PC", + "Genu of CC", + "Splenium of CC", + "Right AL Temporal Horn", + "Left AL Tempral Horn", + "R. Sup. AM Temporal Horn", + "L. Sup. AM Temporal Horn", + "R Inf. AM Temp Horn", + "L Inf.
AM Temp Horn", + "Right IG Origin", + "Left IG Origin", + "R Ventral Occipital Horn", + "L Ventral Occipital Horn", + "R Olfactory Fundus", + "L Olfactory Fundus", +] + + +def write_combined_fcsv(afid_coords, fcsv_output): + header_lines = [ + "# Markups fiducial file version = 4.10\n", + "# CoordinateSystem = 0\n", + "# columns = id,x,y,z,ow,ox,oy,oz,vis,sel,lock,label,desc,associatedNodeID\n", + ] + rows = [] + for lbl in range(1, 33): + c = afid_coords.get(lbl, np.array([0.0, 0.0, 0.0])) + desc = AFID_DESCRIPTIONS[lbl - 1] if lbl <= len(AFID_DESCRIPTIONS) else "" + rows.append( + { + "id": lbl, + "x": c[0], + "y": c[1], + "z": c[2], + "ow": "0.000", + "ox": "0.000", + "oy": "0.000", + "oz": "1.000", + "vis": 1, + "sel": 1, + "lock": 1, + "label": lbl, + "desc": desc, + "associatedNodeID": "", + } + ) + with Path(fcsv_output).open("w", encoding="utf-8", newline="") as f: + for line in header_lines: + f.write(line) + writer = csv.DictWriter(f, fieldnames=AFIDS_FIELDNAMES) + for row in rows: + writer.writerow(row) + + +cfg_block = snakemake.config.get("afids_inference", {}) +if not cfg_block: + raise ValueError("Missing 'afids_inference' block in snakebids.yml.") + +checkpoints = snakemake.params.ckpts +img = nib.nifti1.load(snakemake.input.t1w) +prior = snakemake.input.prior + +# Precompute image tensor and clean NaNs/Infs once per subject +image_np = img.get_fdata().astype(np.float32) +if not np.isfinite(image_np).all(): + finite = np.isfinite(image_np) + lo = image_np[finite].min() if finite.any() else 0.0 + hi = image_np[finite].max() if finite.any() else 1.0 + image_np = np.nan_to_num(image_np, nan=0.0, posinf=hi, neginf=lo) +image_tensor = torch.from_numpy(image_np) +affine = img.affine + +# Load prior FCSV once +prior_df = load_fcsv(prior) + +# Precompute Gaussian map once +ps = int(cfg_block.get("patch_size", 64)) + + +def _g1d_all(n: int) -> torch.Tensor: + s = n * 0.125 + c = n // 2 + x = torch.arange(n, dtype=torch.float32) + return torch.exp(-((x - c) 
** 2) / (2 * s**2)) + + +gmap = torch.einsum("z,y,x->zyx", _g1d_all(ps), _g1d_all(ps), _g1d_all(ps)) +gmap = gmap / gmap.max() + +batch_size = snakemake.config.get("inference_batch_size") or cfg_block.get("batch_size", 7) # CLI override first, then afids_inference.batch_size, then 7 +device_str = cfg_block.get("device", "cpu") +device = torch.device( + device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu" +) + +# Optional: cache models by checkpoint path +features = cfg_block.get("features", [16, 32, 64]) +model_cache = {} + +combined_coords = {} + +for fid in range(1, 33): + key = f"afid_{fid:02d}" + ckpt_path = checkpoints.get(key) + out_coord = snakemake.output.coords[fid - 1] + out_prob = snakemake.output.probs[fid - 1] + + if not ckpt_path or not Path(ckpt_path).exists(): + print(f" [SKIP] AFID {fid:02d} checkpoint not found: {ckpt_path}") + combined_coords[fid] = np.array([0.0, 0.0, 0.0]) + Path(out_coord).parent.mkdir(parents=True, exist_ok=True) + with open(out_coord, "w") as f: + f.write("0.0 0.0 0.0\n") + continue + + # Load or reuse model for this checkpoint + model = model_cache.get(ckpt_path) + if model is None: + model = _load_model(ckpt_path, features, device) + model_cache[ckpt_path] = model + + pred_world, prob_map = _infer_single_from_state( + fid=fid, + image_tensor=image_tensor, + affine=affine, + prior_df=prior_df, + model=model, + gmap=gmap, + patch_size=ps, + batch_size=batch_size, + device=device, + ) + combined_coords[fid] = pred_world + + Path(out_coord).parent.mkdir(parents=True, exist_ok=True) + with open(out_coord, "w") as f: + f.write(f"{pred_world[0]} {pred_world[1]} {pred_world[2]}\n") + + Path(out_prob).parent.mkdir(parents=True, exist_ok=True) + nib.save( + nib.Nifti1Image(prob_map.numpy(), affine, img.header), + out_prob, + ) + +write_combined_fcsv(combined_coords, snakemake.output.fcsv) diff --git a/autoafids/workflow/scripts/apply_with_prior_single.py b/autoafids/workflow/scripts/apply_with_prior_single.py new file mode 100644 index 0000000..363066e --- /dev/null +++
b/autoafids/workflow/scripts/apply_with_prior_single.py @@ -0,0 +1,369 @@ +# ruff: noqa +""" +apply_with_prior_single.py +========================== + +Snakemake script: PyTorch inference for ONE AFID. + +Called by the ``applyfidmodel_single`` rule which has wildcard ``{afid}`` +(zero-padded two digits: 01–32). Snakemake runs ``--cores N`` instances of +this rule concurrently — that is the parallelism mechanism. No Python +multiprocessing is used here. + +Snakemake I/O +------------- + input: + t1w – preprocessed NIfTI (.nii.gz) + prior – per-subject MNI-registered FCSV with 32 prior locations + output: + coord – plain-text file: "x y z\\n" for this AFID + wildcards: + afid – zero-padded AFID number, e.g. "01" +""" + +import os +import sys + +# Set environment variables BEFORE any other imports to ensure they take effect +if "snakemake" in globals() and hasattr(snakemake, "threads"): + n_threads = str(snakemake.threads) + os.environ["OMP_NUM_THREADS"] = n_threads + os.environ["MKL_NUM_THREADS"] = n_threads + os.environ["OPENBLAS_NUM_THREADS"] = n_threads + os.environ["VECLIB_MAXIMUM_THREADS"] = n_threads + os.environ["NUMEXPR_NUM_THREADS"] = n_threads + +import time +import warnings +from pathlib import Path +from typing import List, Optional, Tuple + +import nibabel as nib +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from numpy.typing import NDArray + +if "snakemake" in globals() and hasattr(snakemake, "threads"): + print(f"DEBUG: Snakemake threads = {snakemake.threads}") + print(f"DEBUG: torch.get_num_threads() (before) = {torch.get_num_threads()}") + torch.set_num_threads(snakemake.threads) + print(f"DEBUG: torch.get_num_threads() (after) = {torch.get_num_threads()}") +else: + print("DEBUG: Snakemake object or threads attribute not found.") + +warnings.filterwarnings("ignore") + + +# =================================================================== +# FCSV helpers +# 
=================================================================== + + +def load_fcsv(fcsv_path) -> pd.DataFrame: + return pd.read_csv(fcsv_path, sep=",", header=2) + + +def get_fid(fcsv_df: pd.DataFrame, fid_label: int) -> NDArray: + return fcsv_df.loc[fid_label - 1, ["x", "y", "z"]].to_numpy( + dtype="single", copy=True + ) + + +def fid_world2voxel(fid_world: NDArray, nii_affine: NDArray) -> NDArray: + inv = np.linalg.inv(nii_affine) + return np.rint(inv[:3, :3].dot(fid_world) + inv[:3, 3]).astype(int) + + +def fid_voxel2world(fid_voxel: NDArray, nii_affine: NDArray) -> NDArray: + return (nii_affine[:3, :3].dot(fid_voxel) + nii_affine[:3, 3]).astype(float) + + +# =================================================================== +# nnUNet architecture — names MUST match checkpoint keys exactly +# (encoder_blocks, decoder_blocks, deep_supervision_heads) +# =================================================================== + + +class ConvBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int) -> None: + super().__init__() + self.conv_block = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=True), + nn.InstanceNorm3d(out_channels, affine=True), + nn.LeakyReLU(negative_slope=0.01, inplace=True), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=True), + nn.InstanceNorm3d(out_channels, affine=True), + nn.LeakyReLU(negative_slope=0.01, inplace=True), + ) + + def forward(self, x): + return self.conv_block(x) + + +class EncoderBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int) -> None: + super().__init__() + self.conv_block = ConvBlock(in_channels, out_channels) + self.downsample = nn.Conv3d( + out_channels, out_channels, kernel_size=3, stride=2, padding=1 + ) + + def forward(self, x): + skip = self.conv_block(x) + return skip, self.downsample(skip) + + +class DecoderBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int) -> None: + 
super().__init__() + self.upsample = nn.ConvTranspose3d( + in_channels, in_channels // 2, kernel_size=2, stride=2 + ) + self.conv_block = ConvBlock(in_channels, out_channels) + + def forward(self, x, skip): + return self.conv_block(torch.cat([self.upsample(x), skip], 1)) + + +class nnUNet(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, features: Optional[List[int]] = None + ) -> None: + super().__init__() + if features is None: + features = [32, 64, 128, 256, 320] + self.encoder_blocks = nn.ModuleList() + prev_ch = in_channels + for feat in features[:-1]: + self.encoder_blocks.append(EncoderBlock(prev_ch, feat)) + prev_ch = feat + self.bottleneck = ConvBlock(features[-2], features[-1]) + self.decoder_blocks = nn.ModuleList() + rev = list(reversed(features)) + for i in range(len(rev) - 1): + self.decoder_blocks.append(DecoderBlock(rev[i], rev[i + 1])) + self.deep_supervision_heads = nn.ModuleList() + for feat in reversed(features[:-1]): + self.deep_supervision_heads.append( + nn.Conv3d(feat, out_channels, kernel_size=1) + ) + + def forward(self, x): + skips = [] + for enc in self.encoder_blocks: + skip, x = enc(x) + skips.append(skip) + x = self.bottleneck(x) + skips = list(reversed(skips)) + for i, dec in enumerate(self.decoder_blocks): + x = dec(x, skips[i]) + return self.deep_supervision_heads[-1](x) + + +class nnUNet_VanillaUNet(nnUNet): + """Thin wrapper matching the Lightning checkpoint's class.""" + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + features: Optional[List[int]] = None, + ) -> None: + super().__init__(in_channels, out_channels, features) + + +def _load_model( + ckpt_path: str, features: List[int], device: torch.device +) -> nnUNet_VanillaUNet: + model = nnUNet_VanillaUNet(in_channels=1, out_channels=1, features=features) + ckpt = torch.load(ckpt_path, map_location="cpu") + raw_sd = ckpt.get("state_dict", ckpt) + cleaned_sd = { + (k.replace("model.", "", 1) if k.startswith("model.") else k): v + 
for k, v in raw_sd.items() + } + model.load_state_dict(cleaned_sd, strict=False) + return model.to(device).eval() + + +# =================================================================== +# Single-AFID inference +# =================================================================== + + +def infer_single_afid( + fid: int, + ckpt_path: str, + img: nib.nifti1.Nifti1Image, + prior_fcsv_path, + patch_size: int = 64, + batch_size: int = 7, + device_str: str = "cpu", + features: Optional[List[int]] = None, +) -> Tuple[NDArray, torch.Tensor]: + """Run 7-patch inference for a single AFID (centre patch plus a ±half-patch shift along each axis). + + Returns + ------- + pred_world : NDArray – predicted world coords [x, y, z] + prob_map : Tensor – full-volume Gaussian-weighted probability map [D,H,W] + """ + if features is None: + features = [16, 32, 64] + + device = torch.device( + device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu" + ) + + # ---- Image ---- + image_np = img.get_fdata().astype(np.float32) + if not np.isfinite(image_np).all(): + finite = np.isfinite(image_np) + lo = image_np[finite].min() if finite.any() else 0.0 + hi = image_np[finite].max() if finite.any() else 1.0 + image_np = np.nan_to_num(image_np, nan=0.0, posinf=hi, neginf=lo) + image_tensor = torch.from_numpy(image_np) + affine = img.affine + vol_shape = image_tensor.shape + D, H, W = vol_shape + + # ---- Prior centre ---- + prior_df = load_fcsv(prior_fcsv_path) + prior_world = get_fid(prior_df, fid) + prior_vox = fid_world2voxel(prior_world, affine) + cz = int(np.clip(prior_vox[0], 0, D - 1)) + cy = int(np.clip(prior_vox[1], 0, H - 1)) + cx = int(np.clip(prior_vox[2], 0, W - 1)) + + # ---- Load model ---- + t0 = time.perf_counter() + model = _load_model(ckpt_path, features, device) + t_load = time.perf_counter() - t0 + + # ---- Gaussian map ---- + ps = patch_size + + def _g1d(n): + s = n * 0.125 + c = n // 2 + x = torch.arange(n, dtype=torch.float32) + return torch.exp(-((x - c) ** 2) / (2 * s**2)) + + gmap =
torch.einsum("z,y,x->zyx", _g1d(ps), _g1d(ps), _g1d(ps)) + gmap = gmap / gmap.max() + + # ---- 7 patches: centre + ±x + ±y + ±z ---- + half = ps // 2 + offsets = [ + (0, 0, 0), # centre + (0, 0, -half), # left (x−) + (0, 0, +half), # right (x+) + (0, -half, 0), # up (y−) + (0, +half, 0), # down (y+) + (-half, 0, 0), # in (z−) + (+half, 0, 0), # out (z+) + ] + patch_coords = [] + for dz, dy, dx in offsets: + zs = max(0, min(cz + dz - half, D - ps)) + ys = max(0, min(cy + dy - half, H - ps)) + xs = max(0, min(cx + dx - half, W - ps)) + patch_coords.append( + (slice(zs, zs + ps), slice(ys, ys + ps), slice(xs, xs + ps)) + ) + + # ---- Extract & normalise ---- + t0 = time.perf_counter() + model_inputs = [] + for slcs in patch_coords: + patch = image_tensor[slcs[0], slcs[1], slcs[2]] + pz = max(0, ps - patch.shape[0]) + py = max(0, ps - patch.shape[1]) + px_p = max(0, ps - patch.shape[2]) + if pz or py or px_p: + patch = F.pad(patch, (0, px_p, 0, py, 0, pz)) + lo, hi = patch.min(), patch.max() + patch = (patch - lo) / (hi - lo + 1e-8) if hi > lo else patch + model_inputs.append(patch.unsqueeze(0)) + t_patch = time.perf_counter() - t0 + + # ---- Inference ---- + t0 = time.perf_counter() + p0 = time.process_time() + predictions = [] + for i in range(0, len(model_inputs), batch_size): + chunk = torch.stack(model_inputs[i : i + batch_size]).to(device) + with torch.inference_mode(): + preds = model(chunk) + predictions.extend(p.cpu() for p in preds) + del chunk, preds + t_infer = time.perf_counter() - t0 + p_infer = time.process_time() - p0 + + # ---- Reconstruct ---- + t0 = time.perf_counter() + out = torch.zeros(vol_shape, dtype=torch.float32) + cnt = torch.zeros(vol_shape, dtype=torch.float32) + for pred, (zs, ys, xs) in zip(predictions, patch_coords): # crop to the volume: a zero-padded patch is larger than its slice of out/cnt + out[zs, ys, xs] += (pred[0] * gmap)[: D - zs.start, : H - ys.start, : W - xs.start] + cnt[zs, ys, xs] += gmap[: D - zs.start, : H - ys.start, : W - xs.start] + out /= cnt.clamp(min=1e-8) + t_recon = time.perf_counter() - t0 + + # ---- Peak → world ---- + idx = np.unravel_index(np.argmax(out.numpy()), vol_shape) +
pred_vox = tuple(int(c) for c in idx) + pred_world = fid_voxel2world(np.array(pred_vox, dtype=float), affine) + + t_total = t_load + t_patch + t_infer + t_recon + print( + f" [OK] AFID {fid:02d} " + f"load={t_load:.2f}s patch={t_patch:.2f}s " + f"infer={t_infer:.2f}s (cpu={p_infer:.2f}s) recon={t_recon:.2f}s total={t_total:.2f}s | " + f"pred_vox={pred_vox} " + f"world=[{pred_world[0]:.1f}, {pred_world[1]:.1f}, {pred_world[2]:.1f}]" + ) + return pred_world, out # out = full-volume probability map [D,H,W] + + +# =================================================================== +# Snakemake entry point +# =================================================================== + +fid = int(snakemake.wildcards.afid) # noqa: F821 + +cfg_block = snakemake.config.get("afids_inference", {}) # noqa: F821 +if not cfg_block: + raise ValueError("Missing 'afids_inference' block in snakebids.yml.") + +# Checkpoint path resolved by the rule via workflow.basedir +ckpt_path = snakemake.params.ckpt_path # noqa: F821 +if not Path(ckpt_path).exists(): + raise FileNotFoundError(f"Checkpoint for AFID {fid:02d} not found: {ckpt_path}") + +img = nib.nifti1.load(snakemake.input.t1w) # noqa: F821 + +pred_world, prob_map = infer_single_afid( # noqa: F821 + fid=fid, + ckpt_path=ckpt_path, + img=img, + prior_fcsv_path=snakemake.input.prior, # noqa: F821 + patch_size=cfg_block.get("patch_size", 64), + batch_size=snakemake.config.get("inference_batch_size") or cfg_block.get("batch_size", 7), # noqa: F821 + device_str=cfg_block.get("device", "cpu"), features=cfg_block.get("features"), # features: None falls back to the function default [16, 32, 64] +) + +# Write x y z to a plain-text coord file +Path(snakemake.output.coord).parent.mkdir(parents=True, exist_ok=True) # noqa: F821 +with open(snakemake.output.coord, "w") as f: # noqa: F821 + f.write(f"{pred_world[0]} {pred_world[1]} {pred_world[2]}\n") + +# Save probability map as NIfTI (same affine/header as input image) +Path(snakemake.output.prob).parent.mkdir(parents=True, exist_ok=True) # noqa: F821 +nib.save( # noqa: F821 + nib.Nifti1Image(prob_map.numpy(), img.affine,
img.header), + snakemake.output.prob, # noqa: F821 +) diff --git a/autoafids/workflow/scripts/gather_afids.py b/autoafids/workflow/scripts/gather_afids.py new file mode 100644 index 0000000..cbe72a8 --- /dev/null +++ b/autoafids/workflow/scripts/gather_afids.py @@ -0,0 +1,107 @@ +# ruff: noqa +""" +gather_afids.py +=============== + +Snakemake script: collect all 32 per-AFID coord files and write the +combined FCSV. + +Called by the ``applyfidmodel_gather`` rule after all 32 +``applyfidmodel_single`` jobs have completed. + +Snakemake I/O +------------- + input: + coords – list of 32 plain-text coord files (x y z), one per AFID + output: + fcsv – combined Slicer-compatible FCSV with all 32 AFIDs +""" + +import csv +import warnings +from pathlib import Path +from typing import Dict + +import numpy as np + +warnings.filterwarnings("ignore") + + +# =================================================================== +# FCSV writer (identical to apply_with_prior.py) +# =================================================================== + +AFIDS_FIELDNAMES = [ + "id", + "x", + "y", + "z", + "ow", + "ox", + "oy", + "oz", + "vis", + "sel", + "lock", + "label", + "desc", + "associatedNodeID", +] + +FCSV_TEMPLATE = ( + Path(__file__).parent + / ".." + / ".." 
+ / "resources" + / "tpl-MNI152NLin2009cAsym_res-01_T1w.fcsv" +) + + +def afids_to_fcsv(afid_coords: Dict[int, np.ndarray], fcsv_output) -> None: + with FCSV_TEMPLATE.open(encoding="utf-8", newline="") as fcsv_file: + header = [fcsv_file.readline() for _ in range(3)] + reader = csv.DictReader(fcsv_file, fieldnames=AFIDS_FIELDNAMES) + fcsv = list(reader) + + for idx, row in enumerate(fcsv): + label = idx + 1 + c = afid_coords.get(label, np.array([0.0, 0.0, 0.0])) + row["x"] = c[0] + row["y"] = c[1] + row["z"] = c[2] + + with Path(fcsv_output).open("w", encoding="utf-8", newline="") as out: + for line in header: + out.write(line) + writer = csv.DictWriter(out, fieldnames=AFIDS_FIELDNAMES) + for row in fcsv: + writer.writerow(row) + + +# =================================================================== +# Snakemake entry point +# =================================================================== + +# input.coords is a list of coord files sorted by afid wildcard order. +# We infer the AFID number from the filename (afid-XX in the path). +afid_coords: Dict[int, np.ndarray] = {} + +for coord_path in snakemake.input.coords: # noqa: F821 + path = Path(coord_path) + # Extract afid number from filename, e.g.
"sub-001_afid-03_coord.txt" -> 3 + for part in path.stem.split("_"): + if part.startswith("afid-"): + fid = int(part.split("-")[1]) + break + else: + raise ValueError(f"Cannot parse AFID number from filename: {path.name}") + + with open(coord_path) as f: + x, y, z = map(float, f.read().strip().split()) + afid_coords[fid] = np.array([x, y, z]) + print(f" [gather] AFID {fid:02d} world=[{x:.1f}, {y:.1f}, {z:.1f}]") + +print(f"\n Writing combined FCSV with {len(afid_coords)} AFIDs...") +Path(snakemake.output.fcsv).parent.mkdir(parents=True, exist_ok=True) # noqa: F821 +afids_to_fcsv(afid_coords, snakemake.output.fcsv) # noqa: F821 +print(f" Done → {snakemake.output.fcsv}") # noqa: F821 diff --git a/autoafids/workflow/scripts/nnlm_to_fcsv.py b/autoafids/workflow/scripts/nnlm_to_fcsv.py new file mode 100644 index 0000000..8519b4c --- /dev/null +++ b/autoafids/workflow/scripts/nnlm_to_fcsv.py @@ -0,0 +1,85 @@ +"""nnlm_to_fcsv.py +Snakemake script: convert nnLM voxel-space JSON output to a Slicer +FCSV file with RAS world coordinates (mm). + +nnLM JSON convention +-------------------- + {"1": {"coordinates": [x, y, z], "likelihood": 0.87}, ...} + where x = column (SimpleITK axis 0), y = row, z = slice. + +NiBabel affine convention +-------------------------- + affine @ [i, j, k, 1] with (i, j, k) = the JSON (x, y, z), passed through unchanged + → this script applies no axis reordering before mapping to RAS world coords.
+""" + +import csv +import json +from pathlib import Path + +import nibabel as nib +import numpy as np + +# ── Load nnLM voxel-space coordinates ──────────────────────────────────────── +with open(snakemake.input.coords_json) as fh: + nnlm_raw = json.load(fh) # {"1": {"coordinates": [x,y,z], ...}, ...} + +# ── Load image affine (voxel i,j,k → RAS; indexed below by the nnLM x,y,z) ─── +img = nib.load(snakemake.input.t1w) +affine = img.affine # 4×4 voxel→RAS matrix + +# ── Convert voxel → RAS world coords ───────────────────────────────────────── +# The nnLM (x, y, z) index is used directly as the nibabel (i, j, k) index: no axis swap +afid_world: dict[int, np.ndarray] = {} +for afid_str, entry in nnlm_raw.items(): + sx, sy, sz = entry["coordinates"] # SimpleITK x,y,z + # NiBabel expects (x, y, z) index mapping identical to SimpleITK ordering + nib_ijk = np.array([sx, sy, sz, 1.0]) + # 1) Get RAS world coordinates from the NIfTI affine (this matches true FCSV space) + world = (affine @ nib_ijk)[:3] + afid_world[int(afid_str)] = world + +# ── Read FCSV template (header + 32 landmark rows) ─────────────────────────── +fcsv_template = Path(snakemake.params.fcsv_template) + +FIELDNAMES = [ + "id", + "x", + "y", + "z", + "ow", + "ox", + "oy", + "oz", + "vis", + "sel", + "lock", + "label", + "desc", + "associatedNodeID", +] + +with fcsv_template.open(encoding="utf-8", newline="") as fh: + # Preserve the 3-line Slicer header verbatim + header_lines = [fh.readline() for _ in range(3)] + reader = csv.DictReader(fh, fieldnames=FIELDNAMES) + rows = list(reader) + +# ── Fill in predicted coords ────────────────────────────────────────────────── +for idx, row in enumerate(rows): + afid_label = idx + 1 + coords = afid_world.get(afid_label, np.zeros(3)) + row["x"] = f"{coords[0]:.6f}" + row["y"] = f"{coords[1]:.6f}" + row["z"] = f"{coords[2]:.6f}" + +# ── Write output FCSV ───────────────────────────────────────────────────────── +out_path = Path(snakemake.output.fcsv)
+out_path.parent.mkdir(parents=True, exist_ok=True) + +with out_path.open("w", encoding="utf-8", newline="") as fh: + for line in header_lines: + fh.write(line) + writer = csv.DictWriter(fh, fieldnames=FIELDNAMES, extrasaction="ignore") + for row in rows: + writer.writerow(row)