diff --git a/autoafids/config/snakebids.yml b/autoafids/config/snakebids.yml index c10c814..01eefbc 100644 --- a/autoafids/config/snakebids.yml +++ b/autoafids/config/snakebids.yml @@ -287,35 +287,35 @@ afids_inference: afid_01: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' afid_02: 'resources/afids_cnn_ckpts/afid-02_epoch-964_mae-0.0007.ckpt' afid_03: 'resources/afids_cnn_ckpts/afid-03_epoch-778_mae-0.0008.ckpt' - afid_04: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_05: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_06: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_07: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_08: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_09: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_10: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_11: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_12: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_13: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_14: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_15: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_16: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_17: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_18: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_19: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_20: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_21: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_22: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_23: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_24: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_25: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_26: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_27: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_28: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_29: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_30: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_31: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder - afid_32: 'resources/afids_cnn_ckpts/afid-01_epoch-820_mae-0.0002.ckpt' # placeholder + afid_04: 'resources/afids_cnn_ckpts/afid-04_epoch-529_mae-0.0007.ckpt' + afid_05: 'resources/afids_cnn_ckpts/afid-05_epoch-628_mae-0.0007.ckpt' + afid_06: 'resources/afids_cnn_ckpts/afid-06_epoch-485_mae-0.0009.ckpt' + afid_07: 'resources/afids_cnn_ckpts/afid-07_epoch-599_mae-0.0008.ckpt' + afid_08: 'resources/afids_cnn_ckpts/afid-08_epoch-455_mae-0.0013.ckpt' + afid_09: 'resources/afids_cnn_ckpts/afid-09_epoch-535_mae-0.0010.ckpt' + afid_10: 'resources/afids_cnn_ckpts/afid-10_epoch-431_mae-0.0009.ckpt' + afid_11: 'resources/afids_cnn_ckpts/afid-11_epoch-735_mae-0.0005.ckpt' + afid_12: 'resources/afids_cnn_ckpts/afid-12_epoch-694_mae-0.0007.ckpt' + afid_13: 'resources/afids_cnn_ckpts/afid-13_epoch-763_mae-0.0006.ckpt' + afid_14: 'resources/afids_cnn_ckpts/afid-14_epoch-405_mae-0.0011.ckpt' + afid_15: 'resources/afids_cnn_ckpts/afid-15_epoch-371_mae-0.0013.ckpt' + afid_16: 'resources/afids_cnn_ckpts/afid-16_epoch-454_mae-0.0012.ckpt' + afid_17: 'resources/afids_cnn_ckpts/afid-17_epoch-291_mae-0.0023.ckpt' + afid_18: 'resources/afids_cnn_ckpts/afid-18_epoch-233_mae-0.0016.ckpt' + afid_19: 'resources/afids_cnn_ckpts/afid-19_epoch-626_mae-0.0011.ckpt' + afid_20: 'resources/afids_cnn_ckpts/afid-20_epoch-720_mae-0.0011.ckpt' + afid_21: 'resources/afids_cnn_ckpts/afid-21_epoch-672_mae-0.0009.ckpt' + afid_22: 'resources/afids_cnn_ckpts/afid-22_epoch-647_mae-0.0008.ckpt' + afid_23: 'resources/afids_cnn_ckpts/afid-23_epoch-656_mae-0.0010.ckpt' + afid_24: 'resources/afids_cnn_ckpts/afid-24_epoch-689_mae-0.0007.ckpt' + afid_25: 'resources/afids_cnn_ckpts/afid-25_epoch-513_mae-0.0013.ckpt' + afid_26: 'resources/afids_cnn_ckpts/afid-26_epoch-784_mae-0.0009.ckpt' + afid_27: 'resources/afids_cnn_ckpts/afid-27_epoch-300_mae-0.0014.ckpt' + afid_28: 'resources/afids_cnn_ckpts/afid-28_epoch-351_mae-0.0009.ckpt' + afid_29: 'resources/afids_cnn_ckpts/afid-29_epoch-394_mae-0.0014.ckpt' + afid_30: 'resources/afids_cnn_ckpts/afid-30_epoch-618_mae-0.0012.ckpt' + afid_31: 'resources/afids_cnn_ckpts/afid-31_epoch-573_mae-0.0009.ckpt' + afid_32: 'resources/afids_cnn_ckpts/afid-32_epoch-561_mae-0.0007.ckpt' patch_size: 64 # cubic patch size in voxels device: cuda:0 # cpu recommended for parallel mode; cuda:0 for sequential overlap: 0.5 # sliding-window overlap for --detect_without_prior (0.0–0.75) diff --git a/autoafids/resources/afids_cnn_ckpts/afid-04_epoch-529_mae-0.0007.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-04_epoch-529_mae-0.0007.ckpt new file mode 100755 index 0000000..8fde4fc Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-04_epoch-529_mae-0.0007.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-05_epoch-628_mae-0.0007.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-05_epoch-628_mae-0.0007.ckpt new file mode 100755 index 0000000..f23264d Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-05_epoch-628_mae-0.0007.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-06_epoch-485_mae-0.0009.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-06_epoch-485_mae-0.0009.ckpt new file mode 100755 index 0000000..7df451b Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-06_epoch-485_mae-0.0009.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-07_epoch-599_mae-0.0008.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-07_epoch-599_mae-0.0008.ckpt new file mode 100755 index 0000000..50345d7 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-07_epoch-599_mae-0.0008.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-08_epoch-455_mae-0.0013.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-08_epoch-455_mae-0.0013.ckpt new file mode 100755 index 0000000..99e75c1 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-08_epoch-455_mae-0.0013.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-09_epoch-535_mae-0.0010.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-09_epoch-535_mae-0.0010.ckpt new file mode 100755 index 0000000..116f2a8 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-09_epoch-535_mae-0.0010.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-10_epoch-431_mae-0.0009.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-10_epoch-431_mae-0.0009.ckpt new file mode 100755 index 0000000..6df3231 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-10_epoch-431_mae-0.0009.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-11_epoch-735_mae-0.0005.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-11_epoch-735_mae-0.0005.ckpt new file mode 100755 index 0000000..83d3e75 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-11_epoch-735_mae-0.0005.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-12_epoch-694_mae-0.0007.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-12_epoch-694_mae-0.0007.ckpt new file mode 100755 index 0000000..6cf7e0e Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-12_epoch-694_mae-0.0007.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-13_epoch-763_mae-0.0006.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-13_epoch-763_mae-0.0006.ckpt new file mode 100755 index 0000000..1a04256 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-13_epoch-763_mae-0.0006.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-14_epoch-405_mae-0.0011.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-14_epoch-405_mae-0.0011.ckpt new file mode 100755 index 0000000..6da8a90 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-14_epoch-405_mae-0.0011.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-15_epoch-371_mae-0.0013.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-15_epoch-371_mae-0.0013.ckpt new file mode 100755 index 0000000..5ea0d3c Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-15_epoch-371_mae-0.0013.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-16_epoch-454_mae-0.0012.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-16_epoch-454_mae-0.0012.ckpt new file mode 100755 index 0000000..9383ad6 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-16_epoch-454_mae-0.0012.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-17_epoch-291_mae-0.0023.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-17_epoch-291_mae-0.0023.ckpt new file mode 100755 index 0000000..b30b6c6 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-17_epoch-291_mae-0.0023.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-18_epoch-233_mae-0.0016.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-18_epoch-233_mae-0.0016.ckpt new file mode 100755 index 0000000..d4e586a Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-18_epoch-233_mae-0.0016.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-19_epoch-626_mae-0.0011.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-19_epoch-626_mae-0.0011.ckpt new file mode 100755 index 0000000..8f612b4 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-19_epoch-626_mae-0.0011.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-20_epoch-720_mae-0.0011.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-20_epoch-720_mae-0.0011.ckpt new file mode 100755 index 0000000..99c20d5 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-20_epoch-720_mae-0.0011.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-21_epoch-672_mae-0.0009.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-21_epoch-672_mae-0.0009.ckpt new file mode 100755 index 0000000..8106543 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-21_epoch-672_mae-0.0009.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-22_epoch-647_mae-0.0008.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-22_epoch-647_mae-0.0008.ckpt new file mode 100755 index 0000000..094835b Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-22_epoch-647_mae-0.0008.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-23_epoch-656_mae-0.0010.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-23_epoch-656_mae-0.0010.ckpt new file mode 100755 index 0000000..6a639f2 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-23_epoch-656_mae-0.0010.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-24_epoch-689_mae-0.0007.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-24_epoch-689_mae-0.0007.ckpt new file mode 100755 index 0000000..88b5491 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-24_epoch-689_mae-0.0007.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-25_epoch-513_mae-0.0013.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-25_epoch-513_mae-0.0013.ckpt new file mode 100755 index 0000000..6d47c82 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-25_epoch-513_mae-0.0013.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-26_epoch-784_mae-0.0009.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-26_epoch-784_mae-0.0009.ckpt new file mode 100755 index 0000000..417ca96 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-26_epoch-784_mae-0.0009.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-27_epoch-300_mae-0.0014.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-27_epoch-300_mae-0.0014.ckpt new file mode 100755 index 0000000..a0e2cf7 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-27_epoch-300_mae-0.0014.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-28_epoch-351_mae-0.0009.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-28_epoch-351_mae-0.0009.ckpt new file mode 100755 index 0000000..b10247c Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-28_epoch-351_mae-0.0009.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-29_epoch-394_mae-0.0014.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-29_epoch-394_mae-0.0014.ckpt new file mode 100755 index 0000000..4d147fd Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-29_epoch-394_mae-0.0014.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-30_epoch-618_mae-0.0012.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-30_epoch-618_mae-0.0012.ckpt new file mode 100755 index 0000000..4c7124a Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-30_epoch-618_mae-0.0012.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-31_epoch-573_mae-0.0009.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-31_epoch-573_mae-0.0009.ckpt new file mode 100755 index 0000000..531d1be Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-31_epoch-573_mae-0.0009.ckpt differ diff --git a/autoafids/resources/afids_cnn_ckpts/afid-32_epoch-561_mae-0.0007.ckpt b/autoafids/resources/afids_cnn_ckpts/afid-32_epoch-561_mae-0.0007.ckpt new file mode 100755 index 0000000..2e7b577 Binary files /dev/null and b/autoafids/resources/afids_cnn_ckpts/afid-32_epoch-561_mae-0.0007.ckpt differ diff --git a/autoafids/workflow/Snakefile b/autoafids/workflow/Snakefile index 800f98b..164bf1d 100644 --- a/autoafids/workflow/Snakefile +++ b/autoafids/workflow/Snakefile @@ -355,13 +355,16 @@ rule mni2subfids: # ── Detection mode: choose rules to include ────────────────────────────── if config.get("detect_mode", "prior") == "nnlm": + # nnLandmark mode: single-pass whole-volume detection include: "rules/nnlm.smk" + else: # Legacy CNN modes (prior / noprior) # If running on GPU, automatically enable sequential inference to avoid spawning 32 parallel jobs if config.get("afids_inference", {}).get("device") == "cuda:0": config["enable_sequential_inference"] = True + include: "rules/cnn.smk" @@ -377,7 +380,6 @@ if config["LEAD_DBS_DIR"] or config["FMRIPREP_DIR"]: include: "rules/regqc.smk" - # Select the correct FCSV output descriptor based on inference mode. # prior → applyfidmodel_gather → desc="afidscnn" # noprior → applyfidmodel_noprior_gather → desc="afidscnn-noprior" @@ -445,4 +447,3 @@ rule all: else [] ), default_target: True - diff --git a/autoafids/workflow/rules/cnn.smk b/autoafids/workflow/rules/cnn.smk index 7112914..635e73e 100644 --- a/autoafids/workflow/rules/cnn.smk +++ b/autoafids/workflow/rules/cnn.smk @@ -1,5 +1,6 @@ # populate the AUTOAFIDS_CACHE_DIR folder as needed + rule download_cnn_model: params: url=config["resource_urls"][config["model"]], @@ -15,9 +16,11 @@ rule download_cnn_model: " unzip -q -d {output.unzip_dir} model.zip && " " rm model.zip" + _AFIDS = [f"{i:02d}" for i in range(1, 33)] # ["01", "02", ..., "32"] if config.get("enable_sequential_inference", False): + rule applyfidmodel_all: """Run 5-patch inference for ALL 32 AFIDs using MNI-registered prior location""" input: @@ -27,7 +30,7 @@ if config.get("enable_sequential_inference", False): datatype="normalize", desc=chosen_norm_method, suffix="T1w.nii.gz", - **inputs[config["modality"]].wildcards + **inputs[config["modality"]].wildcards, ) if config["modality"] != "T1w" else bids( @@ -36,7 +39,7 @@ if config.get("enable_sequential_inference", False): desc=chosen_norm_method, res=config["res"], suffix="T1w.nii.gz", - **inputs[config["modality"]].wildcards + **inputs[config["modality"]].wildcards, ) ), prior=bids( @@ -55,7 +58,8 @@ if config.get("enable_sequential_inference", False): afid=afid, suffix="coord.txt", **inputs[config["modality"]].wildcards, - ) for afid in _AFIDS + ) + for afid in _AFIDS ], probs=[ bids( @@ -64,7 +68,8 @@ if config.get("enable_sequential_inference", False): afid=afid, suffix="probmap.nii.gz", **inputs[config["modality"]].wildcards, - ) for afid in _AFIDS + ) + for afid in _AFIDS ], fcsv=bids( root=root, @@ -83,14 +88,15 @@ if config.get("enable_sequential_inference", False): ckpts=lambda wildcards: { key: str(Path(workflow.basedir).parent / path) for key, path in config["afids_inference"]["checkpoints"].items() - } + }, threads: 1 - conda: "../envs/pytorch.yaml" script: "../scripts/apply_with_prior_all.py" + else: + # ── WITH PRIOR (SINGLE) ────────────────────────────────────────────────────── rule applyfidmodel_single: """Run 5-patch inference for ONE AFID using MNI-registered prior location.""" @@ -101,7 +107,7 @@ else: datatype="normalize", desc=chosen_norm_method, suffix="T1w.nii.gz", - **inputs[config["modality"]].wildcards + **inputs[config["modality"]].wildcards, ) if config["modality"] != "T1w" else bids( @@ -110,7 +116,7 @@ else: desc=chosen_norm_method, res=config["res"], suffix="T1w.nii.gz", - **inputs[config["modality"]].wildcards + **inputs[config["modality"]].wildcards, ) ), prior=bids( @@ -148,10 +154,11 @@ else: params: ckpt_path=lambda wildcards: str( Path(workflow.basedir).parent - / config["afids_inference"]["checkpoints"][f"afid_{int(wildcards.afid):02d}"] + / config["afids_inference"]["checkpoints"][ + f"afid_{int(wildcards.afid):02d}" + ] ), threads: 1 - conda: "../envs/pytorch.yaml" script: @@ -166,7 +173,10 @@ else: datatype="afids-cnn", afid="{afid}", suffix="coord.txt", - **{k: getattr(wildcards, k) for k in inputs[config["modality"]].wildcards}, + **{ + k: getattr(wildcards, k) + for k in inputs[config["modality"]].wildcards + }, ), afid=_AFIDS, ), @@ -190,9 +200,11 @@ else: script: "../scripts/gather_afids.py" + # --------------------------------------------------------------------------- if config.get("enable_sequential_inference", False): + rule applyfidmodel_noprior_all: """Whole-volume sliding-window inference for ALL 32 AFIDs (no prior needed)""" input: @@ -202,7 +214,7 @@ if config.get("enable_sequential_inference", False): datatype="normalize", desc=chosen_norm_method, suffix="T1w.nii.gz", - **inputs[config["modality"]].wildcards + **inputs[config["modality"]].wildcards, ) if config["modality"] != "T1w" else bids( @@ -211,7 +223,7 @@ if config.get("enable_sequential_inference", False): desc=chosen_norm_method, res=config["res"], suffix="T1w.nii.gz", - **inputs[config["modality"]].wildcards + **inputs[config["modality"]].wildcards, ) ), output: @@ -222,7 +234,8 @@ if config.get("enable_sequential_inference", False): afid=afid, suffix="coord.txt", **inputs[config["modality"]].wildcards, - ) for afid in _AFIDS + ) + for afid in _AFIDS ], probs=[ bids( @@ -231,7 +244,8 @@ if config.get("enable_sequential_inference", False): afid=afid, suffix="probmap.nii.gz", **inputs[config["modality"]].wildcards, - ) for afid in _AFIDS + ) + for afid in _AFIDS ], fcsv=bids( root=root, @@ -250,13 +264,15 @@ if config.get("enable_sequential_inference", False): ckpts=lambda wildcards: { key: str(Path(workflow.basedir).parent / path) for key, path in config["afids_inference"]["checkpoints"].items() - } + }, threads: 1 conda: "../envs/pytorch.yaml" script: "../scripts/apply_noprior_all.py" + else: + # ── WITHOUT PRIOR (SINGLE) ────────────────────────────────────────────────── rule applyfidmodel_noprior_single: """Whole-volume sliding-window inference for ONE AFID — no prior needed.""" @@ -267,7 +283,7 @@ else: datatype="normalize", desc=chosen_norm_method, suffix="T1w.nii.gz", - **inputs[config["modality"]].wildcards + **inputs[config["modality"]].wildcards, ) if config["modality"] != "T1w" else bids( @@ -276,7 +292,7 @@ else: desc=chosen_norm_method, res=config["res"], suffix="T1w.nii.gz", - **inputs[config["modality"]].wildcards + **inputs[config["modality"]].wildcards, ) ), output: @@ -306,7 +322,9 @@ else: params: ckpt_path=lambda wildcards: str( Path(workflow.basedir).parent - / config["afids_inference"]["checkpoints"][f"afid_{int(wildcards.afid):02d}"] + / config["afids_inference"]["checkpoints"][ + f"afid_{int(wildcards.afid):02d}" + ] ), threads: 1 conda: @@ -323,7 +341,10 @@ else: datatype="afids-cnn-noprior", afid="{afid}", suffix="coord.txt", - **{k: getattr(wildcards, k) for k in inputs[config["modality"]].wildcards}, + **{ + k: getattr(wildcards, k) + for k in inputs[config["modality"]].wildcards + }, ), afid=_AFIDS, ), @@ -345,4 +366,4 @@ else: conda: "../envs/pytorch.yaml" script: - "../scripts/gather_afids.py" \ No newline at end of file + "../scripts/gather_afids.py" diff --git a/autoafids/workflow/rules/nnlm.smk b/autoafids/workflow/rules/nnlm.smk index 7dc9121..fdbddff 100644 --- a/autoafids/workflow/rules/nnlm.smk +++ b/autoafids/workflow/rules/nnlm.smk @@ -147,7 +147,7 @@ rule nnlm_to_fcsv: **inputs[config["modality"]].wildcards, ), params: - fcsv_template=str(Path(workflow.basedir).parent / "resources" / "dummy.fcsv") + fcsv_template=str(Path(workflow.basedir).parent / "resources" / "dummy.fcsv"), conda: "../envs/nibabel.yaml" script: diff --git a/autoafids/workflow/scripts/apply_noprior_all.py b/autoafids/workflow/scripts/apply_noprior_all.py index 3887848..185ea17 100644 --- a/autoafids/workflow/scripts/apply_noprior_all.py +++ b/autoafids/workflow/scripts/apply_noprior_all.py @@ -53,6 +53,7 @@ # FCSV helpers # =================================================================== + def fid_voxel2world(fid_voxel: NDArray, nii_affine: NDArray) -> NDArray: return (nii_affine[:3, :3].dot(fid_voxel) + nii_affine[:3, 3]).astype(float) @@ -62,6 +63,7 @@ def fid_voxel2world(fid_voxel: NDArray, nii_affine: NDArray) -> NDArray: # (encoder_blocks, decoder_blocks, deep_supervision_heads) # =================================================================== + class ConvBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() @@ -73,14 +75,19 @@ def __init__(self, in_channels: int, out_channels: int) -> None: nn.InstanceNorm3d(out_channels, affine=True), nn.LeakyReLU(negative_slope=0.01, inplace=True), ) - def forward(self, x): return self.conv_block(x) + + def forward(self, x): + return self.conv_block(x) class EncoderBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() self.conv_block = ConvBlock(in_channels, out_channels) - self.downsample = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=2, padding=1) + self.downsample = nn.Conv3d( + out_channels, out_channels, kernel_size=3, stride=2, padding=1 + ) + def forward(self, x): skip = self.conv_block(x) return skip, self.downsample(skip) @@ -89,14 +96,19 @@ def forward(self, x): class DecoderBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() - self.upsample = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2) + self.upsample = nn.ConvTranspose3d( + in_channels, in_channels // 2, kernel_size=2, stride=2 + ) self.conv_block = ConvBlock(in_channels, out_channels) - def forward(self, x, skip): return self.conv_block(torch.cat([self.upsample(x), skip], 1)) + + def forward(self, x, skip): + return self.conv_block(torch.cat([self.upsample(x), skip], 1)) class nnUNet(nn.Module): - def __init__(self, in_channels: int, out_channels: int, - features: Optional[List[int]] = None) -> None: + def __init__( + self, in_channels: int, out_channels: int, features: Optional[List[int]] = None + ) -> None: super().__init__() if features is None: features = [32, 64, 128, 256, 320] @@ -112,12 +124,15 @@ def __init__(self, in_channels: int, out_channels: int, self.decoder_blocks.append(DecoderBlock(rev[i], rev[i + 1])) self.deep_supervision_heads = nn.ModuleList() for feat in reversed(features[:-1]): - self.deep_supervision_heads.append(nn.Conv3d(feat, out_channels, kernel_size=1)) + self.deep_supervision_heads.append( + nn.Conv3d(feat, out_channels, kernel_size=1) + ) def forward(self, x): skips = [] for enc in self.encoder_blocks: - skip, x = enc(x); skips.append(skip) + skip, x = enc(x) + skips.append(skip) x = self.bottleneck(x) skips = list(reversed(skips)) for i, dec in enumerate(self.decoder_blocks): @@ -126,12 +141,18 @@ def forward(self, x): class nnUNet_VanillaUNet(nnUNet): - def __init__(self, in_channels: int = 1, out_channels: int = 1, - features: Optional[List[int]] = None) -> None: + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + features: Optional[List[int]] = None, + ) -> None: super().__init__(in_channels, out_channels, features) -def _load_model(ckpt_path: str, features: List[int], device: torch.device) -> nnUNet_VanillaUNet: +def _load_model( + ckpt_path: str, features: List[int], device: torch.device +) -> nnUNet_VanillaUNet: model = nnUNet_VanillaUNet(in_channels=1, out_channels=1, features=features) ckpt = torch.load(ckpt_path, map_location="cpu") raw_sd = ckpt.get("state_dict", ckpt) @@ -147,6 +168,7 @@ def _load_model(ckpt_path: str, features: List[int], device: torch.device) -> nn # Sliding-window inference (no prior) # =================================================================== + def infer_noprior_single_afid( fid: int, ckpt_path: str, @@ -160,7 +182,9 @@ def infer_noprior_single_afid( if features is None: features = [16, 32, 64] - device = torch.device(device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu") + device = torch.device( + device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu" + ) # ---- Image ---- image_np = img.get_fdata().astype(np.float32) @@ -181,10 +205,13 @@ def infer_noprior_single_afid( # ---- Gaussian map ---- ps = patch_size + def _g1d(n): - s = n * 0.125; c = n // 2 + s = n * 0.125 + c = n // 2 x = torch.arange(n, dtype=torch.float32) - return torch.exp(-((x - c) ** 2) / (2 * s ** 2)) + return torch.exp(-((x - c) ** 2) / (2 * s**2)) + gmap = torch.einsum("z,y,x->zyx", _g1d(ps), _g1d(ps), _g1d(ps)) gmap = gmap / gmap.max() @@ -195,8 +222,12 @@ def _g1d(n): for z in range(0, D, step): for y in range(0, H, step): for x in range(0, W, step): - ze = min(z + ps, D); ye = min(y + ps, H); xe = min(x + ps, W) - zs = max(0, ze - ps); ys = max(0, ye - ps); xs = max(0, xe - ps) + ze = min(z + ps, D) + ye = min(y + ps, H) + xe = min(x + ps, W) + zs = max(0, ze - ps) + ys = max(0, ye - ps) + xs = max(0, xe - ps) patch_coords.append((slice(zs, ze), slice(ys, ye), slice(xs, xe))) num_patches = len(patch_coords) @@ -219,7 +250,7 @@ def _g1d(n): t0 = time.perf_counter() predictions = [] for i in range(0, len(model_inputs), batch_size): - chunk = torch.stack(model_inputs[i: i + batch_size]).to(device) + chunk = torch.stack(model_inputs[i : i + batch_size]).to(device) with torch.inference_mode(): preds = model(chunk) predictions.extend(p.cpu() for p in preds) @@ -241,7 +272,7 @@ def _g1d(n): zo = max(0, (ps - az) // 2) yo = max(0, (ps - ay) // 2) xo = max(0, (ps - ax) // 2) - w = gmap[zo: zo + az, yo: yo + ay, xo: xo + ax] + w = gmap[zo : zo + az, yo : yo + ay, xo : xo + ax] pc = pred[0, :az, :ay, :ax] out[zs, ys, xs] += pc * w cnt[zs, ys, xs] += w @@ -265,7 +296,6 @@ def _g1d(n): return pred_world, out # out = full-volume probability map [D,H,W] - # =================================================================== # Snakemake entry point (All 32 AFIDs) # =================================================================== @@ -275,25 +305,58 @@ def _g1d(n): import csv AFIDS_FIELDNAMES = [ - "id", "x", "y", "z", "ow", "ox", "oy", "oz", "vis", "sel", "lock", - "label", "desc", "associatedNodeID", + "id", + "x", + "y", + "z", + "ow", + "ox", + "oy", + "oz", + "vis", + "sel", + "lock", + "label", + "desc", + "associatedNodeID", ] AFID_DESCRIPTIONS = [ - "AC", "PC", "Infracollicular Sulcus", "PMJ", - "Superior IPF", "Right Superior LMS", "Left Superior LMS", - "Right Inferior LMS", "Left Inferior LMS", "Culmen", - "Intermammillary Sulcus", "Right Mammilary Body", "Left Mammilary Body", - "Pineal Gland", "Right LV at AC", "Left LV at AC", - "Right LV at PC", "Left LV at PC", "Genu of CC", "Splenium of CC", - "Right AL Temporal Horn", "Left AL Tempral Horn", - "R. Sup. AM Temporal Horn", "L. Sup. AM Temporal Horn", - "R Inf. AM Temp Horn", "L Inf. AM Temp Horn", - "Right IG Origin", "Left IG Origin", - "R Ventral Occipital Horn", "L Ventral Occipital Horn", - "R Olfactory Fundus", "L Olfactory Fundus", + "AC", + "PC", + "Infracollicular Sulcus", + "PMJ", + "Superior IPF", + "Right Superior LMS", + "Left Superior LMS", + "Right Inferior LMS", + "Left Inferior LMS", + "Culmen", + "Intermammillary Sulcus", + "Right Mammilary Body", + "Left Mammilary Body", + "Pineal Gland", + "Right LV at AC", + "Left LV at AC", + "Right LV at PC", + "Left LV at PC", + "Genu of CC", + "Splenium of CC", + "Right AL Temporal Horn", + "Left AL Tempral Horn", + "R. Sup. AM Temporal Horn", + "L. Sup. AM Temporal Horn", + "R Inf. AM Temp Horn", + "L Inf. AM Temp Horn", + "Right IG Origin", + "Left IG Origin", + "R Ventral Occipital Horn", + "L Ventral Occipital Horn", + "R Olfactory Fundus", + "L Olfactory Fundus", ] + def write_combined_fcsv(afid_coords, fcsv_output): header_lines = [ "# Markups fiducial file version = 4.10\n", @@ -304,12 +367,24 @@ def write_combined_fcsv(afid_coords, fcsv_output): for lbl in range(1, 33): c = afid_coords.get(lbl, np.array([0.0, 0.0, 0.0])) desc = AFID_DESCRIPTIONS[lbl - 1] if lbl <= len(AFID_DESCRIPTIONS) else "" - rows.append({ - "id": lbl, "x": c[0], "y": c[1], "z": c[2], - "ow": "0.000", "ox": "0.000", "oy": "0.000", "oz": "1.000", - "vis": 1, "sel": 1, "lock": 1, - "label": lbl, "desc": desc, "associatedNodeID": "", - }) + rows.append( + { + "id": lbl, + "x": c[0], + "y": c[1], + "z": c[2], + "ow": "0.000", + "ox": "0.000", + "oy": "0.000", + "oz": "1.000", + "vis": 1, + "sel": 1, + "lock": 1, + "label": lbl, + "desc": desc, + "associatedNodeID": "", + } + ) with Path(fcsv_output).open("w", encoding="utf-8", newline="") as f: for line in header_lines: f.write(line) @@ -317,6 +392,7 @@ def write_combined_fcsv(afid_coords, fcsv_output): for row in rows: writer.writerow(row) + cfg_block = snakemake.config.get("afids_inference", {}) if not cfg_block: raise ValueError("Missing 'afids_inference' block in snakebids.yml.") @@ -341,13 +417,17 @@ def write_combined_fcsv(afid_coords, fcsv_output): continue pred_world, prob_map = infer_noprior_single_afid( - fid = fid, - ckpt_path = ckpt_path, - img = img, - overlap = snakemake.config["inference_overlap"] if snakemake.config.get("inference_overlap") is not None else cfg_block.get("overlap", 0.5), - patch_size = cfg_block.get("patch_size", 64), - batch_size = snakemake.config.get("inference_batch_size") or 7, - device_str = cfg_block.get("device", "cpu"), + fid=fid, + ckpt_path=ckpt_path, + img=img, + overlap=( + snakemake.config["inference_overlap"] + if snakemake.config.get("inference_overlap") is not None + else cfg_block.get("overlap", 0.5) + ), + patch_size=cfg_block.get("patch_size", 64), + batch_size=snakemake.config.get("inference_batch_size") or 7, + device_str=cfg_block.get("device", "cpu"), ) combined_coords[fid] = pred_world diff --git a/autoafids/workflow/scripts/apply_noprior_single.py b/autoafids/workflow/scripts/apply_noprior_single.py index 7467774..ec9a0cf 100644 --- a/autoafids/workflow/scripts/apply_noprior_single.py +++ b/autoafids/workflow/scripts/apply_noprior_single.py @@ -54,6 +54,7 @@ # FCSV helpers # =================================================================== + def fid_voxel2world(fid_voxel: NDArray, nii_affine: NDArray) -> NDArray: return (nii_affine[:3, :3].dot(fid_voxel) + nii_affine[:3, 3]).astype(float) @@ -63,6 +64,7 @@ def fid_voxel2world(fid_voxel: NDArray, nii_affine: NDArray) -> NDArray: # (encoder_blocks, decoder_blocks, deep_supervision_heads) # =================================================================== + class ConvBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() @@ -74,14 +76,19 @@ def __init__(self, in_channels: int, out_channels: int) -> None: nn.InstanceNorm3d(out_channels, affine=True), nn.LeakyReLU(negative_slope=0.01, inplace=True), ) - def forward(self, x): return self.conv_block(x) + + def forward(self, x): + return self.conv_block(x) class EncoderBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() self.conv_block = ConvBlock(in_channels, out_channels) - self.downsample = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=2, padding=1) + self.downsample = nn.Conv3d( + out_channels, out_channels, kernel_size=3, stride=2, padding=1 + ) + def forward(self, x): skip = self.conv_block(x) return skip, self.downsample(skip) @@ -90,14 +97,19 @@ def forward(self, x): class DecoderBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() - self.upsample = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2) + self.upsample = nn.ConvTranspose3d( + in_channels, in_channels // 2, kernel_size=2, stride=2 + ) self.conv_block = ConvBlock(in_channels, out_channels) - def forward(self, x, skip): return self.conv_block(torch.cat([self.upsample(x), skip], 1)) + + def forward(self, x, skip): + return self.conv_block(torch.cat([self.upsample(x), skip], 1)) class nnUNet(nn.Module): - def __init__(self, in_channels: int, out_channels: int, - features: Optional[List[int]] = None) -> None: + def __init__( + self, in_channels: int, out_channels: int, features: Optional[List[int]] = None + ) -> None: super().__init__() if features is None: features = [32, 64, 128, 256, 320] @@ -113,12 +125,15 @@ def __init__(self, in_channels: int, out_channels: int, self.decoder_blocks.append(DecoderBlock(rev[i], rev[i + 1])) self.deep_supervision_heads = nn.ModuleList() for feat in reversed(features[:-1]): - self.deep_supervision_heads.append(nn.Conv3d(feat, out_channels, kernel_size=1)) + self.deep_supervision_heads.append( + nn.Conv3d(feat, out_channels, kernel_size=1) + ) def forward(self, x): skips = [] for enc in self.encoder_blocks: - skip, x = enc(x); skips.append(skip) + skip, x = enc(x) + skips.append(skip) x = self.bottleneck(x) skips = list(reversed(skips)) for i, dec in enumerate(self.decoder_blocks): @@ -127,12 +142,18 @@ def forward(self, x): class nnUNet_VanillaUNet(nnUNet): - def __init__(self, in_channels: int = 1, out_channels: int = 1, - features: Optional[List[int]] = None) -> None: + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + features: Optional[List[int]] = None, + ) -> None: super().__init__(in_channels, out_channels, features) -def _load_model(ckpt_path: str, features: List[int], device: torch.device) -> nnUNet_VanillaUNet: +def _load_model( + ckpt_path: str, features: List[int], device: torch.device +) -> nnUNet_VanillaUNet: model = nnUNet_VanillaUNet(in_channels=1, out_channels=1, features=features) ckpt = torch.load(ckpt_path, map_location="cpu") raw_sd = ckpt.get("state_dict", ckpt) @@ -148,6 +169,7 @@ def _load_model(ckpt_path: str, features: List[int], device: torch.device) -> nn # Sliding-window inference (no prior) # =================================================================== + def infer_noprior_single_afid( fid: int, ckpt_path: str, @@ -161,7 +183,9 @@ def infer_noprior_single_afid( if features is None: features = [16, 32, 64] - device = torch.device(device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu") + device = torch.device( + device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu" + ) # ---- Image ---- image_np = img.get_fdata().astype(np.float32) @@ -182,10 +206,13 @@ def infer_noprior_single_afid( # ---- Gaussian map ---- ps = patch_size + def _g1d(n): - s = n * 0.125; c = n // 2 + s = n * 0.125 + c = n // 2 x = torch.arange(n, dtype=torch.float32) - return torch.exp(-((x - c) ** 2) / (2 * s ** 2)) + return torch.exp(-((x - c) ** 2) / (2 * s**2)) + gmap = torch.einsum("z,y,x->zyx", _g1d(ps), _g1d(ps), _g1d(ps)) gmap = gmap / gmap.max() @@ -196,8 +223,12 @@ def _g1d(n): for z in range(0, D, step): for y in range(0, H, step): for x in range(0, W, step): - ze = min(z + ps, D); ye = min(y + ps, H); xe = min(x + ps, W) - zs = max(0, ze - ps); ys = max(0, ye - ps); xs = max(0, xe - ps) + ze = min(z + ps, D) + ye = min(y + ps, H) + xe = min(x + ps, W) + zs = max(0, ze - ps) + ys = max(0, ye - ps) + xs = max(0, xe - ps) patch_coords.append((slice(zs, ze), slice(ys, ye), slice(xs, xe))) num_patches = len(patch_coords) @@ -220,7 +251,7 @@ def _g1d(n): t0 = time.perf_counter() predictions = [] for i in range(0, len(model_inputs), batch_size): - chunk = torch.stack(model_inputs[i: i + batch_size]).to(device) + chunk = torch.stack(model_inputs[i : i + batch_size]).to(device) with torch.inference_mode(): preds = model(chunk) predictions.extend(p.cpu() for p in preds) @@ -242,7 +273,7 @@ def _g1d(n): zo = max(0, (ps - az) // 2) yo = max(0, (ps - ay) // 2) xo = max(0, (ps - ax) // 2) - w = gmap[zo: zo + az, yo: yo + ay, xo: xo + ax] + w = gmap[zo : zo + az, yo : yo + ay, xo : xo + ax] pc = pred[0, :az, :ay, :ax] out[zs, ys, xs] += pc * w cnt[zs, ys, xs] += w @@ -270,37 +301,41 @@ def _g1d(n): # Snakemake entry point # =================================================================== -fid = int(snakemake.wildcards.afid) # noqa: F821 +fid = int(snakemake.wildcards.afid) # noqa: F821 cfg_block = snakemake.config.get("afids_inference", {}) # noqa: F821 if not cfg_block: raise ValueError("Missing 'afids_inference' block in snakebids.yml.") # Checkpoint path resolved by the rule via workflow.basedir -ckpt_path = snakemake.params.ckpt_path # noqa: F821 +ckpt_path = snakemake.params.ckpt_path # noqa: F821 if not Path(ckpt_path).exists(): raise FileNotFoundError(f"Checkpoint for AFID {fid:02d} not found: {ckpt_path}") -img = nib.nifti1.load(snakemake.input.t1w) # noqa: F821 +img = nib.nifti1.load(snakemake.input.t1w) # noqa: F821 pred_world, prob_map = infer_noprior_single_afid( - fid = fid, - ckpt_path = ckpt_path, - img = img, - patch_size = cfg_block.get("patch_size", 64), - batch_size = snakemake.config.get("inference_batch_size") or 7, # noqa: F821 - overlap = snakemake.config["inference_overlap"] if snakemake.config.get("inference_overlap") is not None else cfg_block.get("overlap", 0.5), # noqa: F821 - device_str = cfg_block.get("device", "cpu"), + fid=fid, + ckpt_path=ckpt_path, + img=img, + patch_size=cfg_block.get("patch_size", 64), + batch_size=snakemake.config.get("inference_batch_size") or 7, # noqa: F821 + overlap=( + snakemake.config["inference_overlap"] + if snakemake.config.get("inference_overlap") is not None + else cfg_block.get("overlap", 0.5) + ), # noqa: F821 + device_str=cfg_block.get("device", "cpu"), ) # Write x y z to a plain-text coord file Path(snakemake.output.coord).parent.mkdir(parents=True, exist_ok=True) # noqa: F821 -with open(snakemake.output.coord, "w") as f: # noqa: F821 +with open(snakemake.output.coord, "w") as f: # noqa: F821 f.write(f"{pred_world[0]} {pred_world[1]} {pred_world[2]}\n") # Save probability map as NIfTI (same affine/header as input image) -Path(snakemake.output.prob).parent.mkdir(parents=True, exist_ok=True) # noqa: F821 -nib.save( # noqa: F821 +Path(snakemake.output.prob).parent.mkdir(parents=True, exist_ok=True) # noqa: F821 +nib.save( # noqa: F821 nib.Nifti1Image(prob_map.numpy(), img.affine, img.header), - snakemake.output.prob, # noqa: F821 + snakemake.output.prob, # noqa: F821 ) diff --git a/autoafids/workflow/scripts/apply_with_prior_all.py b/autoafids/workflow/scripts/apply_with_prior_all.py index d7f7c70..c48dc9d 100644 --- a/autoafids/workflow/scripts/apply_with_prior_all.py +++ b/autoafids/workflow/scripts/apply_with_prior_all.py @@ -58,12 +58,15 @@ # FCSV helpers # =================================================================== + def load_fcsv(fcsv_path) -> pd.DataFrame: return pd.read_csv(fcsv_path, sep=",", header=2) def get_fid(fcsv_df: pd.DataFrame, fid_label: int) -> NDArray: - return fcsv_df.loc[fid_label - 1, ["x", "y", "z"]].to_numpy(dtype="single", copy=True) + return fcsv_df.loc[fid_label - 1, ["x", "y", "z"]].to_numpy( + dtype="single", copy=True + ) def fid_world2voxel(fid_world: NDArray, nii_affine: NDArray) -> NDArray: @@ -80,6 +83,7 @@ def fid_voxel2world(fid_voxel: NDArray, nii_affine: NDArray) -> NDArray: # (encoder_blocks, decoder_blocks, deep_supervision_heads) # =================================================================== + class ConvBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() @@ -91,14 +95,19 @@ def __init__(self, in_channels: int, out_channels: int) -> None: nn.InstanceNorm3d(out_channels, affine=True), nn.LeakyReLU(negative_slope=0.01, inplace=True), ) - def forward(self, x): return self.conv_block(x) + + def forward(self, x): + return self.conv_block(x) class EncoderBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() self.conv_block = ConvBlock(in_channels, out_channels) - self.downsample = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=2, padding=1) + self.downsample = nn.Conv3d( + out_channels, out_channels, kernel_size=3, stride=2, padding=1 + ) + def forward(self, x): skip = self.conv_block(x) return skip, self.downsample(skip) @@ -107,14 +116,19 @@ def forward(self, x): class DecoderBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() - self.upsample = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2) + self.upsample = nn.ConvTranspose3d( + in_channels, in_channels // 2, kernel_size=2, stride=2 + ) self.conv_block = ConvBlock(in_channels, out_channels) - def forward(self, x, skip): return self.conv_block(torch.cat([self.upsample(x), skip], 1)) + + def forward(self, x, skip): + return self.conv_block(torch.cat([self.upsample(x), skip], 1)) class nnUNet(nn.Module): - def __init__(self, in_channels: int, out_channels: int, - features: Optional[List[int]] = None) -> None: + def __init__( + self, in_channels: int, out_channels: int, features: Optional[List[int]] = None + ) -> None: super().__init__() if features is None: features = [32, 64, 128, 256, 320] @@ -130,12 +144,15 @@ def __init__(self, in_channels: int, out_channels: int, self.decoder_blocks.append(DecoderBlock(rev[i], rev[i + 1])) self.deep_supervision_heads = nn.ModuleList() for feat in reversed(features[:-1]): - self.deep_supervision_heads.append(nn.Conv3d(feat, out_channels, kernel_size=1)) + self.deep_supervision_heads.append( + nn.Conv3d(feat, out_channels, kernel_size=1) + ) def forward(self, x): skips = [] for enc in self.encoder_blocks: - skip, x = enc(x); skips.append(skip) + skip, x = enc(x) + skips.append(skip) x = self.bottleneck(x) skips = list(reversed(skips)) for i, dec in enumerate(self.decoder_blocks): @@ -145,12 +162,19 @@ def forward(self, x): class nnUNet_VanillaUNet(nnUNet): """Thin wrapper matching the Lightning checkpoint's class.""" - def __init__(self, in_channels: int = 1, out_channels: int = 1, - features: Optional[List[int]] = None) -> None: + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + features: Optional[List[int]] = None, + ) -> None: super().__init__(in_channels, out_channels, features) -def _load_model(ckpt_path: str, features: List[int], device: torch.device) -> nnUNet_VanillaUNet: +def _load_model( + ckpt_path: str, features: List[int], device: torch.device +) -> nnUNet_VanillaUNet: model = nnUNet_VanillaUNet(in_channels=1, out_channels=1, features=features) ckpt = torch.load(ckpt_path, map_location="cpu") raw_sd = ckpt.get("state_dict", ckpt) @@ -166,6 +190,7 @@ def _load_model(ckpt_path: str, features: List[int], device: torch.device) -> nn # Single-AFID inference # =================================================================== + def infer_single_afid( fid: int, ckpt_path: str, @@ -186,7 +211,9 @@ def infer_single_afid( if features is None: features = [16, 32, 64] - device = torch.device(device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu") + device = torch.device( + device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu" + ) # ---- Image ---- image_np = img.get_fdata().astype(np.float32) @@ -215,30 +242,35 @@ def infer_single_afid( # ---- Gaussian map ---- ps = patch_size + def _g1d(n): - s = n * 0.125; c = n // 2 + s = n * 0.125 + c = n // 2 x = torch.arange(n, dtype=torch.float32) - return torch.exp(-((x - c) ** 2) / (2 * s ** 2)) + return torch.exp(-((x - c) ** 2) / (2 * s**2)) + gmap = torch.einsum("z,y,x->zyx", _g1d(ps), _g1d(ps), _g1d(ps)) gmap = gmap / gmap.max() # ---- 7 patches: centre + ±x + ±y + ±z ---- half = ps // 2 offsets = [ - ( 0, 0, 0), # centre - ( 0, 0, -half), # left (x−) - ( 0, 0, +half), # right (x+) - ( 0, -half, 0), # up (y−) - ( 0, +half, 0), # down (y+) - (-half, 0, 0), # in (z−) - (+half, 0, 0), # out (z+) + (0, 0, 0), # centre + (0, 0, -half), # left (x−) + (0, 0, +half), # right (x+) + (0, -half, 0), # up (y−) + (0, +half, 0), # down (y+) + (-half, 0, 0), # in (z−) + (+half, 0, 0), # out (z+) ] patch_coords = [] for dz, dy, dx in offsets: - zs = max(0, min(cz+dz - half, D - ps)) - ys = max(0, min(cy+dy - half, H - ps)) - xs = max(0, min(cx+dx - half, W - ps)) - patch_coords.append((slice(zs, zs+ps), slice(ys, ys+ps), slice(xs, xs+ps))) + zs = max(0, min(cz + dz - half, D - ps)) + ys = max(0, min(cy + dy - half, H - ps)) + xs = max(0, min(cx + dx - half, W - ps)) + patch_coords.append( + (slice(zs, zs + ps), slice(ys, ys + ps), slice(xs, xs + ps)) + ) # ---- Extract & normalise ---- t0 = time.perf_counter() @@ -260,7 +292,7 @@ def _g1d(n): cpu0 = time.process_time() predictions = [] for i in range(0, len(model_inputs), batch_size): - chunk = torch.stack(model_inputs[i: i + batch_size]).to(device) + chunk = torch.stack(model_inputs[i : i + batch_size]).to(device) with torch.inference_mode(): preds = model(chunk) predictions.extend(p.cpu() for p in preds) @@ -335,7 +367,9 @@ def _infer_single_from_state( zs = max(0, min(cz + dz - half, D - ps)) ys = max(0, min(cy + dy - half, H - ps)) xs = max(0, min(cx + dx - half, W - ps)) - patch_coords.append((slice(zs, zs + ps), slice(ys, ys + ps), slice(xs, xs + ps))) + patch_coords.append( + (slice(zs, zs + ps), slice(ys, ys + ps), slice(xs, xs + ps)) + ) # ---- Extract & normalise ---- t0 = time.perf_counter() @@ -393,7 +427,6 @@ def _infer_single_from_state( return pred_world, out - # =================================================================== # Snakemake entry point (All 32 AFIDs) # =================================================================== @@ -403,25 +436,58 @@ def _infer_single_from_state( import csv AFIDS_FIELDNAMES = [ - "id", "x", "y", "z", "ow", "ox", "oy", "oz", "vis", "sel", "lock", - "label", "desc", "associatedNodeID", + "id", + "x", + "y", + "z", + "ow", + "ox", + "oy", + "oz", + "vis", + "sel", + "lock", + "label", + "desc", + "associatedNodeID", ] AFID_DESCRIPTIONS = [ - "AC", "PC", "Infracollicular Sulcus", "PMJ", - "Superior IPF", "Right Superior LMS", "Left Superior LMS", - "Right Inferior LMS", "Left Inferior LMS", "Culmen", - "Intermammillary Sulcus", "Right Mammilary Body", "Left Mammilary Body", - "Pineal Gland", "Right LV at AC", "Left LV at AC", - "Right LV at PC", "Left LV at PC", "Genu of CC", "Splenium of CC", - "Right AL Temporal Horn", "Left AL Tempral Horn", - "R. Sup. AM Temporal Horn", "L. Sup. AM Temporal Horn", - "R Inf. AM Temp Horn", "L Inf. AM Temp Horn", - "Right IG Origin", "Left IG Origin", - "R Ventral Occipital Horn", "L Ventral Occipital Horn", - "R Olfactory Fundus", "L Olfactory Fundus", + "AC", + "PC", + "Infracollicular Sulcus", + "PMJ", + "Superior IPF", + "Right Superior LMS", + "Left Superior LMS", + "Right Inferior LMS", + "Left Inferior LMS", + "Culmen", + "Intermammillary Sulcus", + "Right Mammilary Body", + "Left Mammilary Body", + "Pineal Gland", + "Right LV at AC", + "Left LV at AC", + "Right LV at PC", + "Left LV at PC", + "Genu of CC", + "Splenium of CC", + "Right AL Temporal Horn", + "Left AL Tempral Horn", + "R. Sup. AM Temporal Horn", + "L. Sup. AM Temporal Horn", + "R Inf. AM Temp Horn", + "L Inf. AM Temp Horn", + "Right IG Origin", + "Left IG Origin", + "R Ventral Occipital Horn", + "L Ventral Occipital Horn", + "R Olfactory Fundus", + "L Olfactory Fundus", ] + def write_combined_fcsv(afid_coords, fcsv_output): header_lines = [ "# Markups fiducial file version = 4.10\n", @@ -432,12 +498,24 @@ def write_combined_fcsv(afid_coords, fcsv_output): for lbl in range(1, 33): c = afid_coords.get(lbl, np.array([0.0, 0.0, 0.0])) desc = AFID_DESCRIPTIONS[lbl - 1] if lbl <= len(AFID_DESCRIPTIONS) else "" - rows.append({ - "id": lbl, "x": c[0], "y": c[1], "z": c[2], - "ow": "0.000", "ox": "0.000", "oy": "0.000", "oz": "1.000", - "vis": 1, "sel": 1, "lock": 1, - "label": lbl, "desc": desc, "associatedNodeID": "", - }) + rows.append( + { + "id": lbl, + "x": c[0], + "y": c[1], + "z": c[2], + "ow": "0.000", + "ox": "0.000", + "oy": "0.000", + "oz": "1.000", + "vis": 1, + "sel": 1, + "lock": 1, + "label": lbl, + "desc": desc, + "associatedNodeID": "", + } + ) with Path(fcsv_output).open("w", encoding="utf-8", newline="") as f: for line in header_lines: f.write(line) @@ -445,6 +523,7 @@ def write_combined_fcsv(afid_coords, fcsv_output): for row in rows: writer.writerow(row) + cfg_block = snakemake.config.get("afids_inference", {}) if not cfg_block: raise ValueError("Missing 'afids_inference' block in snakebids.yml.") @@ -468,17 +547,23 @@ def write_combined_fcsv(afid_coords, fcsv_output): # Precompute Gaussian map once ps = int(cfg_block.get("patch_size", 64)) + + def _g1d_all(n: int) -> torch.Tensor: s = n * 0.125 c = n // 2 x = torch.arange(n, dtype=torch.float32) - return torch.exp(-((x - c) ** 2) / (2 * s ** 2)) + return torch.exp(-((x - c) ** 2) / (2 * s**2)) + + gmap = torch.einsum("z,y,x->zyx", _g1d_all(ps), _g1d_all(ps), _g1d_all(ps)) gmap = gmap / gmap.max() batch_size = snakemake.config.get("inference_batch_size") or 7 device_str = cfg_block.get("device", "cpu") -device = torch.device(device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu") +device = torch.device( + device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu" +) # Optional: cache models by checkpoint path features = cfg_block.get("features", [16, 32, 64]) diff --git a/autoafids/workflow/scripts/apply_with_prior_single.py b/autoafids/workflow/scripts/apply_with_prior_single.py index b260c27..363066e 100644 --- a/autoafids/workflow/scripts/apply_with_prior_single.py +++ b/autoafids/workflow/scripts/apply_with_prior_single.py @@ -61,12 +61,15 @@ # FCSV helpers # =================================================================== + def load_fcsv(fcsv_path) -> pd.DataFrame: return pd.read_csv(fcsv_path, sep=",", header=2) def get_fid(fcsv_df: pd.DataFrame, fid_label: int) -> NDArray: - return fcsv_df.loc[fid_label - 1, ["x", "y", "z"]].to_numpy(dtype="single", copy=True) + return fcsv_df.loc[fid_label - 1, ["x", "y", "z"]].to_numpy( + dtype="single", copy=True + ) def fid_world2voxel(fid_world: NDArray, nii_affine: NDArray) -> NDArray: @@ -83,6 +86,7 @@ def fid_voxel2world(fid_voxel: NDArray, nii_affine: NDArray) -> NDArray: # (encoder_blocks, decoder_blocks, deep_supervision_heads) # =================================================================== + class ConvBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() @@ -94,14 +98,19 @@ def __init__(self, in_channels: int, out_channels: int) -> None: nn.InstanceNorm3d(out_channels, affine=True), nn.LeakyReLU(negative_slope=0.01, inplace=True), ) - def forward(self, x): return self.conv_block(x) + + def forward(self, x): + return self.conv_block(x) class EncoderBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() self.conv_block = ConvBlock(in_channels, out_channels) - self.downsample = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=2, padding=1) + self.downsample = nn.Conv3d( + out_channels, out_channels, kernel_size=3, stride=2, padding=1 + ) + def forward(self, x): skip = self.conv_block(x) return skip, self.downsample(skip) @@ -110,14 +119,19 @@ def forward(self, x): class DecoderBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() - self.upsample = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2) + self.upsample = nn.ConvTranspose3d( + in_channels, in_channels // 2, kernel_size=2, stride=2 + ) self.conv_block = ConvBlock(in_channels, out_channels) - def forward(self, x, skip): return self.conv_block(torch.cat([self.upsample(x), skip], 1)) + + def forward(self, x, skip): + return self.conv_block(torch.cat([self.upsample(x), skip], 1)) class nnUNet(nn.Module): - def __init__(self, in_channels: int, out_channels: int, - features: Optional[List[int]] = None) -> None: + def __init__( + self, in_channels: int, out_channels: int, features: Optional[List[int]] = None + ) -> None: super().__init__() if features is None: features = [32, 64, 128, 256, 320] @@ -133,12 +147,15 @@ def __init__(self, in_channels: int, out_channels: int, self.decoder_blocks.append(DecoderBlock(rev[i], rev[i + 1])) self.deep_supervision_heads = nn.ModuleList() for feat in reversed(features[:-1]): - self.deep_supervision_heads.append(nn.Conv3d(feat, out_channels, kernel_size=1)) + self.deep_supervision_heads.append( + nn.Conv3d(feat, out_channels, kernel_size=1) + ) def forward(self, x): skips = [] for enc in self.encoder_blocks: - skip, x = enc(x); skips.append(skip) + skip, x = enc(x) + skips.append(skip) x = self.bottleneck(x) skips = list(reversed(skips)) for i, dec in enumerate(self.decoder_blocks): @@ -148,12 +165,19 @@ def forward(self, x): class nnUNet_VanillaUNet(nnUNet): """Thin wrapper matching the Lightning checkpoint's class.""" - def __init__(self, in_channels: int = 1, out_channels: int = 1, - features: Optional[List[int]] = None) -> None: + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + features: Optional[List[int]] = None, + ) -> None: super().__init__(in_channels, out_channels, features) -def _load_model(ckpt_path: str, features: List[int], device: torch.device) -> nnUNet_VanillaUNet: +def _load_model( + ckpt_path: str, features: List[int], device: torch.device +) -> nnUNet_VanillaUNet: model = nnUNet_VanillaUNet(in_channels=1, out_channels=1, features=features) ckpt = torch.load(ckpt_path, map_location="cpu") raw_sd = ckpt.get("state_dict", ckpt) @@ -169,6 +193,7 @@ def _load_model(ckpt_path: str, features: List[int], device: torch.device) -> nn # Single-AFID inference # =================================================================== + def infer_single_afid( fid: int, ckpt_path: str, @@ -189,7 +214,9 @@ def infer_single_afid( if features is None: features = [16, 32, 64] - device = torch.device(device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu") + device = torch.device( + device_str if (device_str == "cpu" or torch.cuda.is_available()) else "cpu" + ) # ---- Image ---- image_np = img.get_fdata().astype(np.float32) @@ -218,30 +245,35 @@ def infer_single_afid( # ---- Gaussian map ---- ps = patch_size + def _g1d(n): - s = n * 0.125; c = n // 2 + s = n * 0.125 + c = n // 2 x = torch.arange(n, dtype=torch.float32) - return torch.exp(-((x - c) ** 2) / (2 * s ** 2)) + return torch.exp(-((x - c) ** 2) / (2 * s**2)) + gmap = torch.einsum("z,y,x->zyx", _g1d(ps), _g1d(ps), _g1d(ps)) gmap = gmap / gmap.max() # ---- 7 patches: centre + ±x + ±y + ±z ---- half = ps // 2 offsets = [ - ( 0, 0, 0), # centre - ( 0, 0, -half), # left (x−) - ( 0, 0, +half), # right (x+) - ( 0, -half, 0), # up (y−) - ( 0, +half, 0), # down (y+) - (-half, 0, 0), # in (z−) - (+half, 0, 0), # out (z+) + (0, 0, 0), # centre + (0, 0, -half), # left (x−) + (0, 0, +half), # right (x+) + (0, -half, 0), # up (y−) + (0, +half, 0), # down (y+) + (-half, 0, 0), # in (z−) + (+half, 0, 0), # out (z+) ] patch_coords = [] for dz, dy, dx in offsets: - zs = max(0, min(cz+dz - half, D - ps)) - ys = max(0, min(cy+dy - half, H - ps)) - xs = max(0, min(cx+dx - half, W - ps)) - patch_coords.append((slice(zs, zs+ps), slice(ys, ys+ps), slice(xs, xs+ps))) + zs = max(0, min(cz + dz - half, D - ps)) + ys = max(0, min(cy + dy - half, H - ps)) + xs = max(0, min(cx + dx - half, W - ps)) + patch_coords.append( + (slice(zs, zs + ps), slice(ys, ys + ps), slice(xs, xs + ps)) + ) # ---- Extract & normalise ---- t0 = time.perf_counter() @@ -263,7 +295,7 @@ def _g1d(n): p0 = time.process_time() predictions = [] for i in range(0, len(model_inputs), batch_size): - chunk = torch.stack(model_inputs[i: i + batch_size]).to(device) + chunk = torch.stack(model_inputs[i : i + batch_size]).to(device) with torch.inference_mode(): preds = model(chunk) predictions.extend(p.cpu() for p in preds) @@ -301,37 +333,37 @@ def _g1d(n): # Snakemake entry point # =================================================================== -fid = int(snakemake.wildcards.afid) # noqa: F821 +fid = int(snakemake.wildcards.afid) # noqa: F821 cfg_block = snakemake.config.get("afids_inference", {}) # noqa: F821 if not cfg_block: raise ValueError("Missing 'afids_inference' block in snakebids.yml.") # Checkpoint path resolved by the rule via workflow.basedir -ckpt_path = snakemake.params.ckpt_path # noqa: F821 +ckpt_path = snakemake.params.ckpt_path # noqa: F821 if not Path(ckpt_path).exists(): raise FileNotFoundError(f"Checkpoint for AFID {fid:02d} not found: {ckpt_path}") -img = nib.nifti1.load(snakemake.input.t1w) # noqa: F821 +img = nib.nifti1.load(snakemake.input.t1w) # noqa: F821 -pred_world, prob_map = infer_single_afid( # noqa: F821 - fid = fid, - ckpt_path = ckpt_path, - img = img, - prior_fcsv_path = snakemake.input.prior, # noqa: F821 - patch_size = cfg_block.get("patch_size", 64), - batch_size = snakemake.config.get("inference_batch_size") or 7, # noqa: F821 - device_str = cfg_block.get("device", "cpu"), +pred_world, prob_map = infer_single_afid( # noqa: F821 + fid=fid, + ckpt_path=ckpt_path, + img=img, + prior_fcsv_path=snakemake.input.prior, # noqa: F821 + patch_size=cfg_block.get("patch_size", 64), + batch_size=snakemake.config.get("inference_batch_size") or 7, # noqa: F821 + device_str=cfg_block.get("device", "cpu"), ) # Write x y z to a plain-text coord file Path(snakemake.output.coord).parent.mkdir(parents=True, exist_ok=True) # noqa: F821 -with open(snakemake.output.coord, "w") as f: # noqa: F821 +with open(snakemake.output.coord, "w") as f: # noqa: F821 f.write(f"{pred_world[0]} {pred_world[1]} {pred_world[2]}\n") # Save probability map as NIfTI (same affine/header as input image) -Path(snakemake.output.prob).parent.mkdir(parents=True, exist_ok=True) # noqa: F821 -nib.save( # noqa: F821 +Path(snakemake.output.prob).parent.mkdir(parents=True, exist_ok=True) # noqa: F821 +nib.save( # noqa: F821 nib.Nifti1Image(prob_map.numpy(), img.affine, img.header), - snakemake.output.prob, # noqa: F821 + snakemake.output.prob, # noqa: F821 ) diff --git a/autoafids/workflow/scripts/gather_afids.py b/autoafids/workflow/scripts/gather_afids.py index cace4c0..cbe72a8 100644 --- a/autoafids/workflow/scripts/gather_afids.py +++ b/autoafids/workflow/scripts/gather_afids.py @@ -32,10 +32,20 @@ # =================================================================== AFIDS_FIELDNAMES = [ - "id", "x", "y", "z", - "ow", "ox", "oy", "oz", - "vis", "sel", "lock", - "label", "desc", "associatedNodeID", + "id", + "x", + "y", + "z", + "ow", + "ox", + "oy", + "oz", + "vis", + "sel", + "lock", + "label", + "desc", + "associatedNodeID", ] FCSV_TEMPLATE = ( @@ -93,5 +103,5 @@ def afids_to_fcsv(afid_coords: Dict[int, np.ndarray], fcsv_output) -> None: print(f"\n Writing combined FCSV with {len(afid_coords)} AFIDs...") Path(snakemake.output.fcsv).parent.mkdir(parents=True, exist_ok=True) # noqa: F821 -afids_to_fcsv(afid_coords, snakemake.output.fcsv) # noqa: F821 -print(f" Done → {snakemake.output.fcsv}") # noqa: F821 +afids_to_fcsv(afid_coords, snakemake.output.fcsv) # noqa: F821 +print(f" Done → {snakemake.output.fcsv}") # noqa: F821 diff --git a/autoafids/workflow/scripts/nnlm_to_fcsv.py b/autoafids/workflow/scripts/nnlm_to_fcsv.py index ff076e6..8519b4c 100644 --- a/autoafids/workflow/scripts/nnlm_to_fcsv.py +++ b/autoafids/workflow/scripts/nnlm_to_fcsv.py @@ -22,17 +22,17 @@ # ── Load nnLM voxel-space coordinates ──────────────────────────────────────── with open(snakemake.input.coords_json) as fh: - nnlm_raw = json.load(fh) # {"1": {"coordinates": [x,y,z], ...}, ...} + nnlm_raw = json.load(fh) # {"1": {"coordinates": [x,y,z], ...}, ...} # ── Load image affine (nibabel i,j,k = slice,row,col) ──────────────────────── img = nib.load(snakemake.input.t1w) -affine = img.affine # 4×4 voxel→RAS matrix +affine = img.affine # 4×4 voxel→RAS matrix # ── Convert voxel → RAS world coords ───────────────────────────────────────── # SimpleITK (x,y,z) == (col, row, slice) → nibabel (i,j,k) == (slice, row, col) afid_world: dict[int, np.ndarray] = {} for afid_str, entry in nnlm_raw.items(): - sx, sy, sz = entry["coordinates"] # SimpleITK x,y,z + sx, sy, sz = entry["coordinates"] # SimpleITK x,y,z # NiBabel expects (x, y, z) index mapping identical to SimpleITK ordering nib_ijk = np.array([sx, sy, sz, 1.0]) # 1) Get RAS world coordinates from the NIfTI affine (this matches true FCSV space) @@ -43,10 +43,20 @@ fcsv_template = Path(snakemake.params.fcsv_template) FIELDNAMES = [ - "id", "x", "y", "z", - "ow", "ox", "oy", "oz", - "vis", "sel", "lock", - "label", "desc", "associatedNodeID", + "id", + "x", + "y", + "z", + "ow", + "ox", + "oy", + "oz", + "vis", + "sel", + "lock", + "label", + "desc", + "associatedNodeID", ] with fcsv_template.open(encoding="utf-8", newline="") as fh: