diff --git a/.gitignore b/.gitignore index 6f86969..15abfc6 100644 --- a/.gitignore +++ b/.gitignore @@ -91,6 +91,9 @@ ipython_config.py # install all needed dependencies. #Pipfile.lock +# uv +uv.lock + # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ diff --git a/Rhapso/data_prep/xml_to_dataframe.py b/Rhapso/data_prep/xml_to_dataframe.py index ec105c7..8d5b505 100644 --- a/Rhapso/data_prep/xml_to_dataframe.py +++ b/Rhapso/data_prep/xml_to_dataframe.py @@ -1,7 +1,8 @@ import pandas as pd import xml.etree.ElementTree as ET +import re -# This component recieves an XML file containing Tiff or Zarr image metadata and converts +# This component receives an XML file containing Tiff or Zarr image metadata and converts # it into several Dataframes class XMLToDataFrame: @@ -17,9 +18,17 @@ def parse_image_loader_zarr(self, root): for il in root.findall(".//ImageLoader/zgroups/zgroup"): view_setup = il.get("setup") - timepoint = il.get("timepoint") - file_path = il.find("path").text if il.find("path") is not None else None - channel = file_path.split("_ch_", 1)[1].split(".ome.zarr", 1)[0] + timepoint = il.get("tp") or il.get("timepoint") + file_path = il.get("path") + if file_path is None: + element_string = ET.tostring(il, encoding='unicode') + raise ValueError(f"zgroup element missing 'path' attribute: {element_string}") + + # default to channel 0 if not parseable from the path name + try: + channel = file_path.split("_ch_", 1)[1].split(".ome.zarr", 1)[0] + except (IndexError, AttributeError): + channel = 0 image_loader_data.append( { @@ -75,17 +84,99 @@ def parse_image_loader_tiff(self, root): # Convert the list to a DataFrame and return return pd.DataFrame(image_loader_data) - def parse_image_loader_split_zarr(self): - pass + def parse_image_loader_split_zarr(self, root): + """ + Parses a split.viewerimgloader XML structure where a single source image is virtually + subdivided into overlapping tiles via SetupIdDefinitions. 
+ + Parameters + ---------- + root : xml.etree.ElementTree.Element + Root element of the parsed XML. + + Returns + ------- + pd.DataFrame + One row per split tile with columns: view_setup, timepoint, series, channel, + file_path, crop_min, crop_max, zarr_base_path. + """ + outer_loader = root.find(".//ImageLoader[@format='split.viewerimgloader']") + if outer_loader is None: + raise ValueError( + "split.viewerimgloader ImageLoader node not found in XML; " + "ensure the XML contains an ImageLoader with format='split.viewerimgloader'." + ) + + inner_loader = outer_loader.find("ImageLoader") + if inner_loader is None: + raise ValueError( + "Nested ImageLoader node not found inside split.viewerimgloader configuration." + ) + + zarr_elem = inner_loader.find("zarr") + if zarr_elem is None or zarr_elem.text is None: + raise ValueError( + " node with base path is missing from split.viewerimgloader configuration." + ) + + zarr_base_path = zarr_elem.text.strip() + # Build lookup from source setup id to (timepoint, zgroup_path) + zgroup_lookup = {} + for zg in inner_loader.findall(".//zgroups/zgroup"): + setup = zg.get("setup") + tp = zg.get("tp") or zg.get("timepoint") + path = zg.get("path") + zgroup_lookup[setup] = (tp, path) + + image_loader_data = [] + for sid in outer_loader.findall(".//SetupIds/SetupIdDefinition"): + new_id = sid.find("NewId").text.strip() + old_id = sid.find("OldId").text.strip() + crop_min = sid.find("min").text.strip() + crop_max = sid.find("max").text.strip() + + if old_id not in zgroup_lookup: + raise ValueError( + f"SetupIdDefinition refers to OldId {old_id!r} that is not present in the " + f"inner loader's zgroups. Available setup ids: {sorted(zgroup_lookup.keys())}" + ) + tp, zgroup_path = zgroup_lookup[old_id] + + # Attempt to extract the channel from the path, assuming filenames include '_ch_' + # (e.g. both '.zarr' and '.ome.zarr' variants). 
If this pattern is not present or is + # formatted differently, we deliberately fall back to channel 0 as a default. + channel_match = re.search(r'_ch_(\d+)', zgroup_path) + if channel_match: + channel = channel_match.group(1) + else: + # Default to channel 0 when channel information cannot be parsed from the path. + channel = 0 + + image_loader_data.append({ + "view_setup": new_id, + "timepoint": tp, + "series": 1, + "channel": channel, + "file_path": zgroup_path, + "crop_min": crop_min, + "crop_max": crop_max, + "zarr_base_path": zarr_base_path, + }) + + return pd.DataFrame(image_loader_data) def route_image_loader(self, root): """ Directs the XML parsing process based on the image loader format specified in the XML. """ format_node = root.find(".//ImageLoader") - format_type = format_node.get("format") + if format_node is None: + raise ValueError("No element found in XML; cannot determine image loader format.") - if "filemap" in format_type: + format_type = (format_node.get("format") or "").lower() + if "split" in format_type: + return self.parse_image_loader_split_zarr(root) + elif "filemap" in format_type: return self.parse_image_loader_tiff(root) else: return self.parse_image_loader_zarr(root) @@ -96,7 +187,7 @@ def parse_view_setups(self, root): """ viewsetups_data = [] - for vs in root.findall(".//ViewSetup"): + for vs in root.findall("./SequenceDescription/ViewSetups/ViewSetup"): id_ = vs.find("id").text # name = vs.find("name").text name = vs.findtext("name") diff --git a/Rhapso/detection/advanced_refinement.py b/Rhapso/detection/advanced_refinement.py index 1c9db5c..500c0d3 100644 --- a/Rhapso/detection/advanced_refinement.py +++ b/Rhapso/detection/advanced_refinement.py @@ -163,7 +163,9 @@ def filter(self): lb = row['lower_bound'] ub = row['upper_bound'] if vid == view_id: - to_process_interval = (lb, ub) + + ub_inclusive = (ub[0]+1, ub[1]+1, ub[2]+1) + to_process_interval = (lb, ub_inclusive) ips_block = [] intensities_block = [] diff --git 
a/Rhapso/detection/difference_of_gaussian.py b/Rhapso/detection/difference_of_gaussian.py index 1380f1c..c70584a 100644 --- a/Rhapso/detection/difference_of_gaussian.py +++ b/Rhapso/detection/difference_of_gaussian.py @@ -10,13 +10,19 @@ """ class DifferenceOfGaussian: - def __init__(self, min_intensity, max_intensity, sigma, threshold, median_filter, mip_map_downsample): + def __init__(self, min_intensity, max_intensity, sigma, threshold, median_filter, mip_map_downsample, + min_peak_intensity=None): self.min_intensity = min_intensity self.max_intensity = max_intensity self.sigma = sigma self.threshold = threshold self.median_filter = median_filter self.mip_map_downsample = mip_map_downsample + # Optional post-detection filter: drop peaks whose RAW (pre-background- + # subtraction) image intensity at the localized peak is below this + # threshold. None disables. Targets zero-padded borders in stitched + # / fused volumes that DoG latches onto as edge features. + self.min_peak_intensity = min_peak_intensity def apply_offset(self, peaks, offset_z): """ @@ -305,8 +311,13 @@ def run(self, image_chunk, offset, lb): """ Executes the entry point of the script. """ + # Keep a reference to the RAW chunk before background subtraction so + # the optional min_peak_intensity filter can sample true voxel + # values (zero-padded borders read 0, not the post-subtract residual). 
+ raw_image_chunk = image_chunk image_chunk = self.background_subtract_xy(image_chunk) peaks = self.compute_difference_of_gaussian(image_chunk) + print(f"[DoG] image_chunk.shape={image_chunk.shape}, detected_peaks={len(peaks)}, offset={offset}, lb={lb}") if peaks.size == 0: intensities = np.empty((0,), dtype=image_chunk.dtype) @@ -314,10 +325,42 @@ def run(self, image_chunk, offset, lb): else: intensities = map_coordinates(image_chunk, peaks.T, order=1, mode='reflect') - final_peaks = self.apply_lower_bounds(peaks, lb) + if self.min_peak_intensity is not None: + # Reject peaks generated by a nearby data→0 step (the + # zero-border artifact in fused/stitched volumes). The DoG + # response peak from an edge step lands ~sigma INSIDE the + # bright side, where the raw value at the peak itself is + # positive — sampling only at the peak coord is therefore + # not enough. Check a (2r+1)^3 neighborhood and reject if + # the local minimum is below threshold (i.e. a zero voxel + # is in the support window). r = ceil(sigma) is enough to + # reach the step that generated the response. 
+ r = max(1, int(np.ceil(float(self.sigma)))) + peaks_int = np.rint(peaks).astype(np.int32) + shape = np.asarray(raw_image_chunk.shape, dtype=np.int32) + lo = np.clip(peaks_int - r, 0, shape - 1) + hi = np.clip(peaks_int + r + 1, 1, shape) + keep = np.ones(len(peaks_int), dtype=bool) + thr = float(self.min_peak_intensity) + for i in range(len(peaks_int)): + z0, y0, x0 = lo[i] + z1, y1, x1 = hi[i] + if raw_image_chunk[z0:z1, y0:y1, x0:x1].min() < thr: + keep[i] = False + n_dropped = int((~keep).sum()) + if n_dropped: + print( + f"[DoG] min_peak_intensity={self.min_peak_intensity} " + f"(neighborhood r={r}): dropped {n_dropped}/{len(peaks)} " + f"peaks adjacent to sub-threshold voxels" + ) + peaks = peaks[keep] + intensities = intensities[keep] + final_peaks = self.upsample_coordinates(peaks) + final_peaks = self.apply_lower_bounds(final_peaks, lb) final_peaks = self.apply_offset(final_peaks, offset) - final_peaks = self.upsample_coordinates(final_peaks) + print(f"[DoG] final_peaks after transforms={len(final_peaks)}") return { 'interest_points': final_peaks, 'intensities': intensities diff --git a/Rhapso/detection/image_reader.py b/Rhapso/detection/image_reader.py index cbe7076..6dc6b98 100644 --- a/Rhapso/detection/image_reader.py +++ b/Rhapso/detection/image_reader.py @@ -9,6 +9,73 @@ Image Reader loads and downsamples Zarr and TIFF OME data """ + +def _per_axis_pyramid_ds_xyz(file_path: str, level: int): + """Return (ds_x, ds_y, ds_z) downsample factors for ``level`` vs L0. + + Source of truth: OME-zarr v0.4 ``coordinateTransformations.scale`` + metadata in the parent group's ``.zattrs`` — i.e. the pyramid + writer's declared per-axis sampling-density ratio. This is the + correct primitive (the metadata explicitly encodes whatever + anisotropy the pyramid has) and avoids integer-flooring slack from + array-shape ratios (e.g. dataset A L0_z=220 / L4_z=13 = 16.92, + while the metadata correctly says scale_z(L4)/scale_z(L0) = 16.0). 
+ + Returns ``(None, None, None)`` when the metadata cannot be read — + caller is expected to fall back to legacy ``2 ** level`` behavior. + + ``file_path`` is the full path to the level-N array + (e.g. ``s3://…/channel_488.zarr/2``); the parent group is the + OME-zarr root carrying the multiscales metadata. + """ + try: + root_path = file_path.rstrip('/').rsplit('/', 1)[0] + root = zarr.open(root_path, mode='r') + scale_l0 = _ome_zarr_scale_zyx(root, "0") + scale_ln = _ome_zarr_scale_zyx(root, str(level)) + except Exception: + return None, None, None + if scale_l0 is None or scale_ln is None: + return None, None, None + # Per-axis ds = scale(L) / scale(L0). Round to int (≥ 1) since the + # caller ultimately uses these as integer divisors for voxel bounds. + sz0, sy0, sx0 = scale_l0 + szn, syn, sxn = scale_ln + ds_z = max(1, int(round(szn / max(sz0, 1e-12)))) + ds_y = max(1, int(round(syn / max(sy0, 1e-12)))) + ds_x = max(1, int(round(sxn / max(sx0, 1e-12)))) + return ds_x, ds_y, ds_z + + +def _ome_zarr_scale_zyx(root_group, level_name: str): + """Return ``(scale_z, scale_y, scale_x)`` from OME-zarr multiscales. + + Reads ``coordinateTransformations[type==scale]`` for the given + level path. Slices the trailing ZYX entries from a 3- or 5-axis + declaration. Returns ``None`` if the metadata is missing or + malformed — caller should treat that as "metadata unreadable" and + fall back to a legacy heuristic. 
+ """ + try: + attrs = root_group.attrs.asdict() + multiscales = attrs.get("multiscales", []) + if not multiscales: + return None + for d in multiscales[0].get("datasets", []): + if str(d.get("path")) != str(level_name): + continue + for ct in d.get("coordinateTransformations", []): + if ct.get("type") == "scale": + s = ct.get("scale", []) + if len(s) == 5: + return float(s[2]), float(s[3]), float(s[4]) + if len(s) == 3: + return float(s[0]), float(s[1]), float(s[2]) + return None + except Exception: + return None + return None + class CustomBioImage(BioImage): def standard_metadata(self): pass @@ -82,15 +149,45 @@ def fetch_image_data(self, record, dsxy, dsz): dask_array = img.get_dask_stack()[0, 0, 0, :, :, :] elif self.file_type == "zarr": - s3 = s3fs.S3FileSystem(anon=False) full_path = f"{file_path}" - store = s3fs.S3Map(root=full_path, s3=s3) - zarr_array = zarr.open(store, mode='r') - dask_array = da.from_zarr(zarr_array)[0, 0, :, :, :] + is_local = not full_path.startswith("s3://") + try: + if is_local: + zarr_array = zarr.open(full_path, mode='r') + else: + s3 = s3fs.S3FileSystem(anon=False) + store = s3fs.S3Map(root=full_path, s3=s3) + zarr_array = zarr.open(store, mode='r') + if zarr_array.ndim == 5: + dask_array = da.from_zarr(zarr_array)[0, 0, :, :, :] + elif zarr_array.ndim == 3: + dask_array = da.from_zarr(zarr_array) + else: + raise ValueError(f"Expected 3D or 5D zarr, got {zarr_array.ndim}D with shape {zarr_array.shape}") + except Exception as e: + print(f"[ImageReader] ERROR opening zarr at {full_path}: {e}") + # Try to inspect root to show available multiscales + try: + root_path = full_path.rsplit('/', 1)[0] + print(f"[ImageReader] Attempting to inspect root zarr at: {root_path}") + if is_local: + root_zarr = zarr.open(root_path, mode='r') + else: + root_store = s3fs.S3Map(root=root_path, s3=s3) + root_zarr = zarr.open(root_store, mode='r') + available_levels = list(root_zarr.keys()) if hasattr(root_zarr, 'keys') else 'unknown' + 
print(f"[ImageReader] Available multiscale levels at root: {available_levels}") + except Exception as e2: + print(f"[ImageReader] Could not inspect root zarr: {e2}") + raise dask_array = dask_array.astype(np.float32) dask_array = dask_array.transpose() + # Store original crop bounds (in level-0 coordinates) for later application + crop_min = record.get('crop_min') + crop_max = record.get('crop_max') + # Downsample Dask array downsampled_stack = self.interface_downsampling(dask_array, dsxy, dsz) @@ -98,13 +195,144 @@ def fetch_image_data(self, record, dsxy, dsz): lb = list(interval_key[0]) ub = list(interval_key[1]) + # Bounds are in full-resolution (level 0) coordinates. + # We loaded from a potentially downsampled multiscale level, + # so we need to scale the bounds down to that level's voxel + # space. Anisotropic-pyramid-safe: compute per-axis ds from + # actual ``shape(L0) / shape(level)`` rather than the legacy + # isotropic ``2 ** level`` (which broke on pyramids that + # preserve Z at coarse levels — e.g. HCR_823476_s5 keeps Z + # full-res at L1/L2 while halving XY, causing the legacy code + # to read only the top quarter of Z; see + # new_reports/11_ANISOTROPIC_PYRAMID_BUG.md). + # + # ``lb``/``ub`` are XYZ-ordered, but zarr ``shape`` is ZYX — + # the indexing below is explicit on that axis swap. + try: + level_str = file_path.rstrip('/').split('/')[-1] + level = int(level_str) + print(f"[ImageReader] file_path={file_path}, extracted level={level}") + print(f"[ImageReader] Before scaling: lb={lb}, ub={ub}, downsampled_stack.shape={downsampled_stack.shape}") + if level > 0: + ds_x, ds_y, ds_z = _per_axis_pyramid_ds_xyz(file_path, level) + if ds_x is not None: + lb = [lb[0] // ds_x, lb[1] // ds_y, lb[2] // ds_z] + ub = [ub[0] // ds_x, ub[1] // ds_y, ub[2] // ds_z] + print( + f"[ImageReader] After per-axis scaling " + f"(ds_xyz=({ds_x},{ds_y},{ds_z})): lb={lb}, ub={ub}" + ) + else: + # Fallback: legacy isotropic behavior (e.g. 
when the + # parent pyramid metadata isn't accessible). + scale = 2 ** level + lb = [x // scale for x in lb] + ub = [x // scale for x in ub] + print( + f"[ImageReader] After scaling by 2^{level}={scale} " + f"(fallback, parent pyramid not readable): " + f"lb={lb}, ub={ub}" + ) + except (ValueError, IndexError) as e: + print(f"[ImageReader] Level extraction failed ({e}); using bounds as-is") + pass # Level extraction failed; use bounds as-is + + # Now apply split tile crop if present (using scaled crop bounds) + if crop_min is not None and crop_max is not None: + if len(crop_min) != 3 or len(crop_max) != 3: + raise ValueError( + f"crop_min and crop_max must both be length 3 for 3D cropping; " + f"got crop_min={crop_min}, crop_max={crop_max}" + ) + + # Scale crop bounds from level-0 coordinates to downsampled array + # coordinates. The array has been downsampled by the pyramid + # level (anisotropic-safe per-axis ds — see comment above) + # AND by dsxy/dsz (interface_downsampling), so crop bounds + # must be divided by the total per-axis factor. + try: + level_str = file_path.rstrip('/').split('/')[-1] + level = int(level_str) + if level > 0: + ds_x_p, ds_y_p, ds_z_p = _per_axis_pyramid_ds_xyz( + file_path, level + ) + if ds_x_p is None: + # Same legacy fallback as the lb/ub block above. + ds_x_p = ds_y_p = ds_z_p = 2 ** level + else: + ds_x_p = ds_y_p = ds_z_p = 1 + total_scale_x = ds_x_p * dsxy + total_scale_y = ds_y_p * dsxy + total_scale_z = ds_z_p * dsz + except (ValueError, IndexError): + total_scale_x = total_scale_y = dsxy + total_scale_z = dsz + + # crop bounds are in XYZ order: [0]=X, [1]=Y, [2]=Z. 
+ scales = [total_scale_x, total_scale_y, total_scale_z] + crop_min_scaled = [int(x // s) for x, s in zip(crop_min, scales)] + crop_max_scaled = [int(np.ceil((x + 1) / s) - 1) for x, s in zip(crop_max, scales)] + + # Validate and clamp crop bounds to downsampled array dimensions + array_shape = downsampled_stack.shape + for i in range(3): + if crop_min_scaled[i] < 0: + raise ValueError( + f"crop_min_scaled[{i}]={crop_min_scaled[i]} is negative" + ) + # Clamp crop_max to valid range + crop_max_scaled[i] = min(crop_max_scaled[i], array_shape[i] - 1) + if crop_min_scaled[i] > crop_max_scaled[i]: + raise ValueError( + f"crop_min_scaled[{i}]={crop_min_scaled[i]} > crop_max_scaled[{i}]={crop_max_scaled[i]}" + ) + + print(f"[ImageReader] Applying crop: crop_min_scaled={crop_min_scaled}, crop_max_scaled={crop_max_scaled}") + downsampled_stack = downsampled_stack[ + crop_min_scaled[0]:crop_max_scaled[0] + 1, + crop_min_scaled[1]:crop_max_scaled[1] + 1, + crop_min_scaled[2]:crop_max_scaled[2] + 1 + ] + # After cropping, detected peaks land at coords relative to + # the CROPPED chunk's origin (0..chunk_size). The subsequent + # ``DoG.apply_lower_bounds(peaks, lb)`` step adds ``lb`` + # elementwise — that addition happens BEFORE + # ``upsample_coordinates`` scales up to L0, so ``lb`` must + # be in the same array-voxel unit system as the peaks. + # Without this correction, ``lower_bound`` is (0,0,0) in + # split-tile mode (it comes from the tile-local + # ``_split_tile_shape`` bounds which always start at zero), + # so every tile's peaks re-origin to (0,0,0) and all tiles' + # IPs collapse into the global frame's top-left corner — + # visible in ``03-tile_edge_filter/moving.png`` as IPs + # clustered in the 0..tile_size region regardless of which + # grid cell the tile is supposed to cover. Add + # ``crop_min_scaled`` so the crop offset propagates into + # the peak coordinate transform. 
+ # + # COORDINATE-FRAME CONTRACT (split-tile IPs): + # The stored N5 IPs produced downstream of this adjustment + # are in L0 WORLD voxel coords (tile grid position baked + # in). Matching MUST skip the "Image Splitting" + # ViewTransform when composing per-view transforms — see + # ``Rhapso/matching/load_and_transform_points.py`` + # (``SPLIT_TILE_TRANSFORM_NAME``). Applying it again would + # double-translate each split tile's IPs and produce a + # residual gradient of k × tile_step across the grid. + lower_bound = [ + int(lower_bound[i]) + int(crop_min[i]) + for i in range(3) + ] + print(f"[ImageReader] crop offset applied → lower_bound={lower_bound}") + # Load image chunk into mem downsampled_image_chunk = downsampled_stack[lb[0]:ub[0]+1, lb[1]:ub[1]+1, lb[2]:ub[2]+1].compute() - + interval_key = ( tuple(lb), tuple(ub), - tuple((ub[0] - lb[0]+1, ub[1] - lb[1]+1, ub[2] - lb[2]+1)) + tuple((ub[0] - lb[0]+1, ub[1] - lb[1]+1, ub[2] - lb[2]+1)) ) return view_id, interval_key, downsampled_image_chunk, offset, lower_bound diff --git a/Rhapso/detection/metadata_builder.py b/Rhapso/detection/metadata_builder.py index 203a5be..85d635d 100644 --- a/Rhapso/detection/metadata_builder.py +++ b/Rhapso/detection/metadata_builder.py @@ -1,3 +1,4 @@ +import os import numpy as np """ @@ -17,11 +18,24 @@ def __init__(self, dataframes, overlapping_area, image_file_prefix, file_type, d self.chunks_per_bound = chunks_per_bound self.run_type = run_type self.level = level - self.overlap = int(np.ceil(3 * sigma)) + # Overlap must cover the full scipy gaussian_filter kernel radius. + # scipy uses truncate=4.0 by default: kernel radius = int(4*sigma + 0.5). + # The larger Gaussian in the DoG is sigma_2 = sigma * k where k = 2^(1/4), + # after accounting for image_sigma=0.5: sigma_2_eff = sqrt((sigma*k)^2 - 0.5^2). + # Using ceil(4 * sigma_2_eff) ensures no boundary artifact from truncation. 
+ _k = 2 ** (1.0 / 4.0) + _sigma_2_eff = float(np.sqrt(max((sigma * _k) ** 2 - 0.5 ** 2, sigma ** 2))) + self.overlap = int(np.ceil(4.0 * _sigma_2_eff)) self.sub_region_chunking = not chunks_per_bound == 0 self.metadata = [] + print( + f"[MetadataBuilder] sigma={sigma} sigma_2_eff={_sigma_2_eff:.4f} " + f"overlap={self.overlap} chunks_per_bound={chunks_per_bound} " + f"sub_region_chunking={self.sub_region_chunking} " + f"num_cpus={os.cpu_count()}" + ) - def build_image_metadata(self, process_intervals, file_path, view_id): + def build_image_metadata(self, process_intervals, file_path, view_id, crop_min=None, crop_max=None): """ Builds list of metadata with optional sub-chunking """ @@ -41,7 +55,9 @@ def build_image_metadata(self, process_intervals, file_path, view_id): 'file_path': file_path, 'interval_key': interval_key, 'offset': 0, - 'lb': lb_fixed + 'lb': lb_fixed, + 'crop_min': crop_min, + 'crop_max': crop_max }) # Apply sub-region chunking @@ -68,13 +84,22 @@ def build_image_metadata(self, process_intervals, file_path, view_id): span = tuple(actual_ub[i] - actual_lb[i] for i in range(3)) interval_key = (actual_lb, actual_ub, span) + # `lb` is the chunk's parent-frame starting voxel + # (used by DoG.apply_lower_bounds to map peaks back + # to the parent's coord frame). `offset` is 0 + # because chunk-shift is already encoded in `lb`; + # leaving it as `z` would double-add the shift and + # produce Z-banded IPs (one band per chunk after + # the first). Image read uses interval_key bounds. 
self.metadata.append({ 'view_id': view_id, 'file_path': file_path, 'interval_key': interval_key, - 'offset': z, - 'lb' : lb - }) + 'offset': 0, + 'lb' : actual_lb, + 'crop_min': crop_min, + 'crop_max': crop_max + }) elif self.file_type == "zarr": @@ -91,37 +116,65 @@ def build_image_metadata(self, process_intervals, file_path, view_id): z = max(0, chunk[0] - self.overlap) z_end = min(chunk[-1] + 1 + self.overlap, z_stop - z_start) - actual_lb = (lb[0], lb[1], z_start + z) + actual_lb = (lb[0], lb[1], z_start + z) actual_ub = (ub[0], ub[1], z_start + z_end) span = tuple(actual_ub[i] - actual_lb[i] for i in range(3)) interval_key = (actual_lb, actual_ub, span) + # See tiff branch above: `lb` must be the chunk's + # parent-frame starting voxel and `offset` zero, + # else apply_lower_bounds + apply_offset double-add + # the chunk-shift and produce Z-banded IPs. self.metadata.append({ 'view_id': view_id, 'file_path': file_path, 'interval_key': interval_key, - 'offset': z, - 'lb' : lb - }) - + 'offset': 0, + 'lb' : actual_lb, + 'crop_min': crop_min, + 'crop_max': crop_max + }) + def build_paths(self): """ Iterates through views to interface metadata building """ + is_split = 'crop_min' in self.image_loader_df.columns + for _, row in self.image_loader_df.iterrows(): view_id = f"timepoint: {row['timepoint']}, setup: {row['view_setup']}" process_intervals = self.overlapping_area[view_id] - + if self.file_type == 'zarr': - file_path = self.image_file_prefix + row['file_path'] + f'/{self.level}' + if is_split: + # zarr_base_path is the root (e.g., SPIM.ome.zarr/), + # file_path has the per-tile name (e.g., Tile_X_..._ch_405.zarr). + # Multiscale levels live inside each tile zarr. 
+ file_path = os.path.join(row['zarr_base_path'], row['file_path']) + print(f"[MetadataBuilder] split=True, zarr_base_path={row['zarr_base_path']}, per_tile={row['file_path']}, joined={file_path}") + else: + file_path = self.image_file_prefix + print(f"[MetadataBuilder] split=False, using image_file_prefix={file_path}") + # Append multiscale level if not already present + if self.level is not None and not str(file_path).rstrip('/').endswith(str(self.level)): + file_path = os.path.join(file_path, str(self.level)) + print(f"[MetadataBuilder] Appended level={self.level}, final path={file_path}") elif self.file_type == 'tiff': - file_path = self.image_file_prefix + row['file_path'] + file_path = os.path.join(self.image_file_prefix, row['file_path']) else: raise ValueError(f"Unsupported file_type: {self.file_type!r}") - + + # Extract crop bounds for split tiles (keep in level-0 coordinates; + # ImageReader will scale them by the total downsampling factor). + crop_min = None + crop_max = None + if is_split: + crop_min = [int(v) for v in row['crop_min'].split()] + crop_max = [int(v) for v in row['crop_max'].split()] + if self.run_type == 'ray': - self.build_image_metadata(process_intervals, file_path, view_id) + self.build_image_metadata(process_intervals, file_path, view_id, crop_min, crop_max) else: raise ValueError(f"Unsupported run type: {self.run_type!r}") diff --git a/Rhapso/detection/overlap_detection.py b/Rhapso/detection/overlap_detection.py index 2dca34c..d9feaa5 100644 --- a/Rhapso/detection/overlap_detection.py +++ b/Rhapso/detection/overlap_detection.py @@ -5,6 +5,7 @@ import s3fs import dask.array as da import math +import os """ Overlap Detection figures out where image tile overlap. 
@@ -22,12 +23,13 @@ def time_interval(self): pass class OverlapDetection(): - def __init__(self, transform_models, dataframes, dsxy, dsz, prefix, file_type): + def __init__(self, transform_models, dataframes, dsxy, dsz, prefix, file_type, overlapping_only=True): self.transform_models = transform_models self.image_loader_df = dataframes['image_loader'] self.dsxy, self.dsz = dsxy, dsz self.prefix = prefix self.file_type = file_type + self.overlapping_only = overlapping_only self.to_process = {} self.image_shape_cache = {} self.max_interval_size = 0 @@ -50,13 +52,25 @@ def load_image_metadata(self, file_path): return self.image_shape_cache[file_path] if self.file_type == 'zarr': - s3 = s3fs.S3FileSystem(anon=False) - store = s3fs.S3Map(root=file_path, s3=s3) - zarr_array = zarr.open(store, mode='r') - dask_array = da.from_zarr(zarr_array) - dask_array = da.expand_dims(dask_array, axis=2) - shape = dask_array.shape - self.image_shape_cache[file_path] = shape + s3 = s3fs.S3FileSystem(anon=True) + print(f"[OverlapDetection] Opening root zarr: {file_path}") + try: + store = s3fs.S3Map(root=file_path, s3=s3) + zarr_obj = zarr.open(store, mode='r') + if isinstance(zarr_obj, zarr.hierarchy.Group): + print(f"[OverlapDetection] Opened zarr Group. 
Available levels: {list(zarr_obj.keys())}") + zarr_arr = zarr_obj['0'] + else: + print(f"[OverlapDetection] Opened zarr Array directly.") + zarr_arr = zarr_obj + dask_array = da.from_zarr(zarr_arr) + dask_array = da.expand_dims(dask_array, axis=2) + shape = dask_array.shape + self.image_shape_cache[file_path] = shape + print(f"[OverlapDetection] Shape (after expand): {shape}") + except Exception as e: + print(f"[OverlapDetection] ERROR opening root zarr: {e}") + raise elif self.file_type == 'tiff': img = CustomBioImage(file_path, reader=bioio_tifffile.Reader) @@ -65,7 +79,28 @@ def load_image_metadata(self, file_path): self.image_shape_cache[file_path] = shape return shape - + + def _split_tile_shape(self, row): + """Derive 6D shape tuple from split tile crop bounds. + + Parameters + ---------- + row : pd.Series + Row from image_loader_df with 'crop_min' and 'crop_max' columns. + Values are space-separated "X Y Z" strings. + + Returns + ------- + tuple + 6D shape tuple (1, 1, 1, Z, Y, X) matching load_image_metadata format. + """ + cmin = list(map(int, row['crop_min'].split())) + cmax = list(map(int, row['crop_max'].split())) + x_size = cmax[0] - cmin[0] + 1 + y_size = cmax[1] - cmin[1] + 1 + z_size = cmax[2] - cmin[2] + 1 + return (1, 1, 1, z_size, y_size, x_size) + # def open_and_downsample(self, shape): # X = int(shape[5]) # Y = int(shape[4]) @@ -199,17 +234,192 @@ def floor_log2(self, n): return max(0, int(math.floor(math.log2(max(1, n))))) def choose_zarr_level(self): - """ - pick the highest power-of-two pyramid level ( ≤ 7) compatible with dsxy/dsz + """Pick the actual pyramid level closest to the requested dsxy/dsz. + + Reads per-axis downsample factors from the parent zarr group's + OME-zarr ``coordinateTransformations.scale`` metadata (the + pyramid writer's declared per-axis sampling-density ratio) — + NOT from an isotropic ``2 ** level`` assumption, NOT from + array-shape ratios. 
Metadata is the right primitive: it + explicitly encodes whatever anisotropy the pyramid has, and is + immune to integer-flooring slack at odd L0 extents (e.g. + dataset A L0_z=220, L4_z=13 → shape-ratio 16.92 vs metadata + 16.0; metadata is the writer's intent). + + Picks the level whose per-axis ds is the largest possible while + still ``≤`` the request on every axis (so the remaining + downsampling can be done in software without ever upsampling), + preferring the smallest leftover product to minimize redundant + downsampling work. + + Anisotropic-pyramid example (HCR_823476_s5, request dsxy=16, dsz=4): + L0: ds=(1,1,1) + L1: ds=(2,2,1) + L2: ds=(4,4,1) + L3: ds=(8,8,2) + L4: ds=(16,16,4) ← exact match, leftover=(1,1,1) — picked + L5: ds=(32,32,8) ← rejected (over-downsamples on every axis) + Legacy ``min(log2_xy, log2_z)`` would have picked L2 with + leftover ``(4, 4, 1)`` — pulling 64× more bytes from S3 and + re-doing the antialiasing in software. + + Falls back to legacy isotropic ``min(log2(dsxy), log2(dsz))`` when + the parent zarr's metadata can't be parsed (preserves prior + behavior on non-OME-zarr inputs / tests). Tuple convention for + ``leftovers`` is preserved as ``(_, dsxy_leftover, dsz_leftover)`` + so the call site at ``__call__`` is untouched — only the first + slot is unused-but-kept-for-shape. """ max_level = 7 - lvl_xy = self.floor_log2(self.dsxy) - lvl_z = self.floor_log2(self.dsz) - best = min(lvl_xy, lvl_z, max_level) - factor = 1 << best - leftovers = (max(1, self.dsxy // factor), max(1, self.dsxy // factor), max(1, self.dsz // factor)) - return best, leftovers + try: + root = zarr.open(self.prefix, mode='r') + scale_l0 = self._ome_scale_zyx(root, "0") + if scale_l0 is None: + raise ValueError("no scale metadata at L0") + + # Iterate every level declared in the multiscales metadata + # (NOT array_keys() — the metadata is what defines the + # pyramid's intent). 
+ attrs = root.attrs.asdict() + datasets = attrs.get("multiscales", [{}])[0].get("datasets", []) + level_records = [] + for d in datasets: + level_name = str(d.get("path")) + if not level_name.isdigit(): + continue + lvl_int = int(level_name) + if lvl_int > max_level: + continue + scale_ln = self._ome_scale_zyx(root, level_name) + if scale_ln is None: + continue + # ds_axis = scale(L)[axis] / scale(L0)[axis]. Rounded to + # int because the request and downstream downsamplers + # are integer-valued. + ds_z = max(1, int(round(scale_ln[0] / max(scale_l0[0], 1e-12)))) + ds_y = max(1, int(round(scale_ln[1] / max(scale_l0[1], 1e-12)))) + ds_x = max(1, int(round(scale_ln[2] / max(scale_l0[2], 1e-12)))) + level_records.append((lvl_int, ds_x, ds_y, ds_z)) + + if not level_records: + raise ValueError("no usable pyramid levels in metadata") + + req_xy = max(1, int(self.dsxy)) + req_z = max(1, int(self.dsz)) + + # Eligibility: every axis's ds must be ≤ the request, so the + # remaining downsampling can be done in software without + # ever needing to upsample. With metadata-declared ds (no + # rounding slack), this is the strict comparison we want. + eligible = [ + rec for rec in level_records + if rec[1] <= req_xy + and rec[2] <= req_xy + and rec[3] <= req_z + ] + if not eligible: + # Request is finer than even L0. Pick L0 and pass through. + eligible = [(0, 1, 1, 1)] + + # Score each candidate: (leftover_x * leftover_y * leftover_z). + # Exact match → score=1 (perfect). Lower is better; tiebreak + # by deeper level (cheaper S3 reads). + def _score(rec): + lvl, dsx, dsy, dsz = rec + lo_x = max(1, req_xy // max(dsx, 1)) + lo_y = max(1, req_xy // max(dsy, 1)) + lo_z = max(1, req_z // max(dsz, 1)) + return (lo_x * lo_y * lo_z, -lvl) + + best_lvl, dsx_lvl, dsy_lvl, dsz_lvl = min(eligible, key=_score) + + # ``leftovers`` tuple convention: (unused, leftover_xy, leftover_z). 
+ # leftover_xy = max(leftover_x, leftover_y) so subsequent code + # that applies a single dsxy-factor never under-downsamples + # either axis. In the typical case dsx==dsy so this is just + # ``req_xy // dsx``. + leftover_x = max(1, req_xy // max(dsx_lvl, 1)) + leftover_y = max(1, req_xy // max(dsy_lvl, 1)) + leftover_z = max(1, req_z // max(dsz_lvl, 1)) + leftover_xy = max(leftover_x, leftover_y) + leftovers = (leftover_xy, leftover_xy, leftover_z) + return best_lvl, leftovers + except Exception as e: + # Legacy fallback: assume isotropic 2**level pyramid. Safe + # for unit tests + any pyramid with that structure that + # lacks parseable OME-zarr metadata. + print( + f"[OverlapDetection] choose_zarr_level falling back to " + f"legacy isotropic picker (metadata not readable: {e!r})" + ) + lvl_xy = self.floor_log2(self.dsxy) + lvl_z = self.floor_log2(self.dsz) + best = min(lvl_xy, lvl_z, max_level) + factor = 1 << best + leftovers = ( + max(1, self.dsxy // factor), + max(1, self.dsxy // factor), + max(1, self.dsz // factor), + ) + return best, leftovers + def _per_axis_pyramid_scale(self, level): + """Return (sx, sy, sz) — the level→L0 voxel-grid scale per axis. + + Reads the OME-zarr ``coordinateTransformations.scale`` metadata + (the writer's declared per-axis ds) — same source of truth as + ``choose_zarr_level``. Falls back to isotropic ``2 ** level`` if + the parent zarr's metadata can't be parsed. 
+ """ + if level <= 0: + return 1.0, 1.0, 1.0 + try: + root = zarr.open(self.prefix, mode='r') + scale_l0 = self._ome_scale_zyx(root, "0") + scale_ln = self._ome_scale_zyx(root, str(level)) + if scale_l0 is None or scale_ln is None: + raise ValueError("scale metadata missing") + sz = max(1.0, scale_ln[0] / max(scale_l0[0], 1e-12)) + sy = max(1.0, scale_ln[1] / max(scale_l0[1], 1e-12)) + sx = max(1.0, scale_ln[2] / max(scale_l0[2], 1e-12)) + return float(sx), float(sy), float(sz) + except Exception as e: + print( + f"[OverlapDetection] _per_axis_pyramid_scale fallback to " + f"isotropic 2**{level}: metadata not readable ({e!r})" + ) + s = float(2 ** level) + return s, s, s + + @staticmethod + def _ome_scale_zyx(root_group, level_name: str): + """Return (scale_z, scale_y, scale_x) from OME-zarr multiscales metadata. + + Reads ``coordinateTransformations[type==scale]`` for the named + level. Slices the trailing ZYX entries from a 3- or 5-axis + scale declaration. Returns ``None`` when the metadata is missing + or malformed — caller falls back to legacy heuristic. 
+ """ + try: + attrs = root_group.attrs.asdict() + multiscales = attrs.get("multiscales", []) + if not multiscales: + return None + for d in multiscales[0].get("datasets", []): + if str(d.get("path")) != str(level_name): + continue + for ct in d.get("coordinateTransformations", []): + if ct.get("type") == "scale": + s = ct.get("scale", []) + if len(s) == 5: + return float(s[2]), float(s[3]), float(s[4]) + if len(s) == 3: + return float(s[0]), float(s[1]), float(s[2]) + return None + except Exception: + return None + return None + def affine_with_half_pixel_shift(self, sx, sy, sz): """ Build a 4x4 scaling affine that also shifts by 0.5·(scale-1) per axis so voxel centers stay aligned after @@ -239,82 +449,129 @@ def find_overlapping_area(self): """ Compute XY Z overlap intervals against every other view, accounting for mipmap/downsampling and per-view affine transforms """ + is_split = 'crop_min' in self.image_loader_df.columns + for i, row_i in self.image_loader_df.iterrows(): view_id = f"timepoint: {row_i['timepoint']}, setup: {row_i['view_setup']}" # get inverted matrice of downsampling - all_intervals = [] + all_intervals = [] if self.file_type == 'zarr': level, leftovers = self.choose_zarr_level() - dim_base = self.load_image_metadata(self.prefix + row_i['file_path'] + f'/{0}') - - # isotropic pyramid - s = float(2 ** level) - mipmap_of_downsample = self.affine_with_half_pixel_shift(s, s, s) - - # TODO - update mipmap with leftovers if other than 1 + print(f"[OverlapDetection] view={view_id}: chosen level={level}, leftovers={leftovers}, dsxy={self.dsxy}, dsz={self.dsz}") + + if is_split: + dim_base = self._split_tile_shape(row_i) + else: + dim_base = self.load_image_metadata(self.prefix) + + # Per-axis pyramid scale. The legacy ``s = 2 ** level`` + # is wrong for anisotropic pyramids that preserve one + # axis at coarse levels (e.g. HCR_823476_s5 keeps Z at + # full-res through L2 while halving XY). 
Read the actual + # shape ratio from the parent zarr; fall back to + # isotropic on lookup failure. + sx, sy, sz = self._per_axis_pyramid_scale(level) + + mipmap_of_downsample = self.affine_with_half_pixel_shift(sx, sy, sz) + + # leftovers are returned by ``choose_zarr_level`` as + # (ds_x, ds_y, ds_z) — what remains to be applied as + # software downsampling on top of the chosen level. _, dsxy, dsz = leftovers elif self.file_type == 'tiff': - dim_base = self.load_image_metadata(self.prefix + row_i['file_path']) + dim_base = self.load_image_metadata(os.path.join(self.prefix, row_i['file_path'])) mipmap_of_downsample = self.create_mipmap_transform() dsxy, dsz = self.dsxy, self.dsz level = None downsampled_dim_base = self.open_and_downsample(dim_base, dsxy, dsz) - t1 = self.get_inverse_mipmap_transform(mipmap_of_downsample) - - # compare with all view_ids - for j, row_j in self.image_loader_df.iterrows(): - if i == j: continue - - view_id_other = f"timepoint: {row_j['timepoint']}, setup: {row_j['view_setup']}" - - if self.file_type == 'zarr': - dim_other = self.load_image_metadata(self.prefix + row_j['file_path'] + f'/{0}') - elif self.file_type == 'tiff': - dim_other = self.load_image_metadata(self.prefix + row_j['file_path']) - - # get transforms matrix from both view_ids and downsampling matrices - matrix = self.transform_models.get(view_id) - matrix_other = self.transform_models.get(view_id_other) - - if self.file_type == 'zarr': - s = float(2 ** level) - mipmap_of_downsample_other = self.affine_with_half_pixel_shift(s, s, s) - elif self.file_type == 'tiff': - mipmap_of_downsample_other = self.create_mipmap_transform() - - inverse_mipmap_of_downsample_other = self.get_inverse_mipmap_transform(mipmap_of_downsample_other) - inverse_matrix = self.get_inverse_mipmap_transform(matrix) - - concatenated_matrix = np.dot(inverse_matrix, matrix_other) - t2 = np.dot(inverse_mipmap_of_downsample_other, concatenated_matrix) - - intervals = self.estimate_bounds(t1, dim_base) - 
intervals_other = self.estimate_bounds(t2, dim_other) - - bounding_boxes = tuple(map(lambda x: np.round(x).astype(int), intervals)) - bounding_boxes_other = tuple(map(lambda x: np.round(x).astype(int), intervals_other)) - - # find upper and lower bounds of intersection - if np.all((bounding_boxes[1] >= bounding_boxes_other[0]) & (bounding_boxes_other[1] >= bounding_boxes[0])): - intersected_boxes = self.calculate_intersection(bounding_boxes, bounding_boxes_other) - intersect = self.calculate_intersection(downsampled_dim_base, intersected_boxes) - intersect_dict = { - 'lower_bound': intersect[0], - 'upper_bound': intersect[1], - 'span': self.calculate_new_dims(intersect[0], intersect[1]) - } - - lb, ub = intersect[0], intersect[1] + t1 = self.get_inverse_mipmap_transform(mipmap_of_downsample) + + if self.overlapping_only: + # compare with all view_ids + for j, row_j in self.image_loader_df.iterrows(): + if i == j: continue + + view_id_other = f"timepoint: {row_j['timepoint']}, setup: {row_j['view_setup']}" + + if self.file_type == 'zarr': + if is_split: + dim_other = self._split_tile_shape(row_j) + else: + dim_other = self.load_image_metadata(self.prefix) + elif self.file_type == 'tiff': + dim_other = self.load_image_metadata(os.path.join(self.prefix, row_j['file_path'])) + + # get transforms matrix from both view_ids and downsampling matrices + matrix = self.transform_models.get(view_id) + matrix_other = self.transform_models.get(view_id_other) + + if self.file_type == 'zarr': + s = float(2 ** level) + mipmap_of_downsample_other = self.affine_with_half_pixel_shift(s, s, s) + elif self.file_type == 'tiff': + mipmap_of_downsample_other = self.create_mipmap_transform() + + inverse_mipmap_of_downsample_other = self.get_inverse_mipmap_transform(mipmap_of_downsample_other) + inverse_matrix = self.get_inverse_mipmap_transform(matrix) + + concatenated_matrix = np.dot(inverse_matrix, matrix_other) + t2 = np.dot(inverse_mipmap_of_downsample_other, concatenated_matrix) + + 
intervals = self.estimate_bounds(t1, dim_base) + intervals_other = self.estimate_bounds(t2, dim_other) + + bounding_boxes = tuple(map(lambda x: np.round(x).astype(int), intervals)) + bounding_boxes_other = tuple(map(lambda x: np.round(x).astype(int), intervals_other)) + + # find upper and lower bounds of intersection + if np.all((bounding_boxes[1] >= bounding_boxes_other[0]) & (bounding_boxes_other[1] >= bounding_boxes[0])): + intersected_boxes = self.calculate_intersection(bounding_boxes, bounding_boxes_other) + intersect = self.calculate_intersection(downsampled_dim_base, intersected_boxes) + intersect_dict = { + 'lower_bound': intersect[0], + 'upper_bound': intersect[1], + 'span': self.calculate_new_dims(intersect[0], intersect[1]) + } + + lb, ub = intersect[0], intersect[1] + sz = self.size_interval(lb, ub) + if sz > self.max_interval_size: + self.max_interval_size = sz + + # add max size + all_intervals.append(intersect_dict) + + # Single-view dataset: no pairwise overlaps exist, so use the + # full downsampled volume as the processing region. + if not all_intervals and len(self.image_loader_df) == 1: + lb = np.array(downsampled_dim_base[0]) + ub = np.array(downsampled_dim_base[1]) + all_intervals.append({ + 'lower_bound': lb, + 'upper_bound': ub, + 'span': self.calculate_new_dims(lb, ub), + }) sz = self.size_interval(lb, ub) if sz > self.max_interval_size: self.max_interval_size = sz - # add max size - all_intervals.append(intersect_dict) - + else: + # Full-volume mode: use the entire downsampled tile as the + # processing region (for registration, not stitching). 
+ lb = np.array(downsampled_dim_base[0]) + ub = np.array(downsampled_dim_base[1]) + all_intervals.append({ + 'lower_bound': lb, + 'upper_bound': ub, + 'span': self.calculate_new_dims(lb, ub), + }) + sz = self.size_interval(lb, ub) + if sz > self.max_interval_size: + self.max_interval_size = sz + self.to_process[view_id] = all_intervals return dsxy, dsz, level, mipmap_of_downsample diff --git a/Rhapso/detection/save_interest_points.py b/Rhapso/detection/save_interest_points.py index 4f0192a..6136c7f 100644 --- a/Rhapso/detection/save_interest_points.py +++ b/Rhapso/detection/save_interest_points.py @@ -6,6 +6,7 @@ from io import BytesIO import io import json +import time """ Save Interest Points saves interest points as N5 and updates the xml with pathways @@ -242,8 +243,19 @@ def save_points(self): for _, row in self.image_loader_df.iterrows(): view_id = f"timepoint: {row['timepoint']}, setup: {row['view_setup']}" n5_path = f"interestpoints.n5/tpId_{row['timepoint']}_viewSetupId_{row['view_setup']}/beads" - self.save_interest_points_to_n5(view_id, n5_path) - self.save_intensities_to_n5(view_id, n5_path) + # Retry on FileNotFoundError: zarr's write-then-rename pattern can fail + # transiently on network filesystems (e.g. /scratch/ on EFS/NFS) where + # the temp file's directory entry isn't immediately visible to the rename + # syscall. Both save functions delete before writing, so retries are safe. 
+ for attempt in range(3): + try: + self.save_interest_points_to_n5(view_id, n5_path) + self.save_intensities_to_n5(view_id, n5_path) + break + except FileNotFoundError: + if attempt == 2: + raise + time.sleep(0.5) path = self.n5_output_file_prefix + "interestpoints.n5" diff --git a/Rhapso/evaluation/__init__.py b/Rhapso/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Rhapso/evaluation/detection_qc/__init__.py b/Rhapso/evaluation/detection_qc/__init__.py new file mode 100644 index 0000000..583cd07 --- /dev/null +++ b/Rhapso/evaluation/detection_qc/__init__.py @@ -0,0 +1,23 @@ +""" +Detection QC — Quality control utilities for interest point detection. + +Provides per-view metrics computation, parameter sweep analysis, +and diagnostic plotting for IP detection outputs. +""" +from Rhapso.evaluation.detection_qc.view_metrics import ( + ViewIPMetrics, + compute_view_metrics, + compute_all_view_metrics, +) +from Rhapso.evaluation.detection_qc.sweep_analyzer import ( + SweepTrialResult, + SweepAnalyzer, +) + +__all__ = [ + "ViewIPMetrics", + "compute_view_metrics", + "compute_all_view_metrics", + "SweepTrialResult", + "SweepAnalyzer", +] diff --git a/Rhapso/evaluation/detection_qc/plotting.py b/Rhapso/evaluation/detection_qc/plotting.py new file mode 100644 index 0000000..daf167b --- /dev/null +++ b/Rhapso/evaluation/detection_qc/plotting.py @@ -0,0 +1,345 @@ +""" +Detection QC plotting — diagnostic figures for IP detection parameter sweeps. + +All plotting functions accept an output_dir and save PNGs directly. +The caller (R2R step or capsule) is responsible for setting output_dir +to the appropriate location (e.g. /results/ on Code Ocean). 
def plot_sweep_ip_counts(
    trials: List[SweepTrialResult],
    target_interest_points: int,
    output_dir: str,
    filename: str = "sweep_ip_counts.png",
) -> str:
    """Bar chart of mean IP count per trial, colored by pass/fail.

    Each bar is annotated with the trial's view success rate; the target
    IP count is drawn as a dashed horizontal reference line.

    Parameters
    ----------
    trials : list of SweepTrialResult
        All sweep trials.
    target_interest_points : int
        Target threshold (drawn as horizontal line).
    output_dir : str
        Directory to save the figure (created if missing).
    filename : str
        Output filename.

    Returns
    -------
    str
        Full path to the saved figure.
    """
    os.makedirs(output_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(max(6, len(trials) * 0.8), 5))

    labels = [f"s{t.multiscale}\n\u03c3={t.sigma}" for t in trials]
    means = [t.mean_ip_count for t in trials]
    colors = [COLORS_PASS if t.success else COLORS_FAIL for t in trials]

    x = np.arange(len(trials))
    bars = ax.bar(x, means, color=colors, edgecolor="black", linewidth=0.5)

    ax.axhline(
        y=target_interest_points,
        color="black",
        linestyle="--",
        linewidth=1.0,
        label=f"Target ({target_interest_points})",
    )

    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=8)
    ax.set_ylabel("Mean IP Count per View")
    ax.set_xlabel("Trial (scale / sigma)")
    ax.set_title("Parameter Sweep: Mean IP Count by Trial")
    ax.legend(loc="upper right")

    # Annotate each bar with its view success rate. The vertical offset is
    # loop-invariant (2% of the target), so hoist it; the original also
    # carried an unused enumerate index, dropped here.
    y_offset = target_interest_points * 0.02
    for bar, trial in zip(bars, trials):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + y_offset,
            f"{trial.success_rate * 100:.0f}%",
            ha="center",
            va="bottom",
            fontsize=7,
            fontweight="bold",
        )

    fig.tight_layout()
    output_path = os.path.join(output_dir, filename)
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    logger.info(f"Saved sweep IP counts plot to {output_path}")
    return output_path
def plot_sigma_vs_multiscale_heatmap(
    trials: List[SweepTrialResult],
    output_dir: str,
    filename: str = "sigma_multiscale_heatmap.png",
) -> str:
    """Heatmap of mean IP count across the sigma x multiscale grid.

    Skipped (returns "") when the sweep varied fewer than 2 values on
    both axes, since a single-cell heatmap carries no information.

    Parameters
    ----------
    trials : list of SweepTrialResult
        All sweep trials.
    output_dir : str
        Directory to save the figure (created if missing).
    filename : str
        Output filename.

    Returns
    -------
    str
        Full path to the saved figure, or "" when skipped.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Build unique sorted axes; sigmas descending so larger sigma sits on top.
    multiscales = sorted(set(t.multiscale for t in trials))
    sigmas = sorted(set(t.sigma for t in trials), reverse=True)

    if len(multiscales) < 2 and len(sigmas) < 2:
        logger.info("Skipping heatmap — fewer than 2 unique values on both axes")
        return ""

    grid = np.full((len(sigmas), len(multiscales)), np.nan)
    success_grid = np.full((len(sigmas), len(multiscales)), False)

    ms_idx = {ms: i for i, ms in enumerate(multiscales)}
    sig_idx = {s: i for i, s in enumerate(sigmas)}

    for t in trials:
        r = sig_idx[t.sigma]
        c = ms_idx[t.multiscale]
        grid[r, c] = t.mean_ip_count
        success_grid[r, c] = t.success

    fig, ax = plt.subplots(figsize=(max(5, len(multiscales) * 1.2), max(4, len(sigmas) * 0.8)))

    im = ax.imshow(grid, cmap="YlOrRd", aspect="auto")
    # Colorbar return value was previously bound to an unused local.
    fig.colorbar(im, ax=ax, label="Mean IP Count")

    ax.set_xticks(range(len(multiscales)))
    ax.set_xticklabels([f"s{ms}" for ms in multiscales])
    ax.set_yticks(range(len(sigmas)))
    ax.set_yticklabels([f"\u03c3={s}" for s in sigmas])
    ax.set_xlabel("Multiscale Level")
    ax.set_ylabel("Sigma")
    ax.set_title("Parameter Sweep: Mean IP Count Heatmap")

    # Annotate cells. The nanmax-based text-color threshold is
    # loop-invariant, so compute it once instead of per cell; at least
    # one cell is non-NaN here because every trial wrote into the grid.
    dark_threshold = np.nanmax(grid) * 0.6
    for r in range(len(sigmas)):
        for c in range(len(multiscales)):
            val = grid[r, c]
            if np.isnan(val):
                continue
            marker = "\u2713" if success_grid[r, c] else "\u2717"
            ax.text(
                c, r, f"{val:.0f}\n{marker}",
                ha="center", va="center", fontsize=8,
                color="white" if val > dark_threshold else "black",
            )

    fig.tight_layout()
    output_path = os.path.join(output_dir, filename)
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    logger.info(f"Saved sigma vs multiscale heatmap to {output_path}")
    return output_path
@dataclass
class SweepTrialResult:
    """Outcome of one (multiscale, sigma) detection trial.

    Attributes
    ----------
    multiscale : str
        Zarr pyramid level tried (e.g. "3").
    sigma : float
        DoG sigma value tried.
    trial_index : int
        0-based position of this trial in the sweep order.
    success : bool
        True when at least 50% of views reached the IP target.
    n5_output_path : str
        Location where the detection output was written.
    view_metrics : list of ViewIPMetrics
        Per-view QC metrics gathered for this trial.
    """

    multiscale: str
    sigma: float
    trial_index: int
    success: bool
    n5_output_path: str
    view_metrics: List[ViewIPMetrics] = field(default_factory=list)

    @property
    def total_ip_count(self) -> int:
        """Sum of interest points over every view."""
        return sum(vm.ip_count for vm in self.view_metrics)

    @property
    def views_meeting_target(self) -> int:
        """How many views reached the IP target."""
        return sum(vm.meets_target for vm in self.view_metrics)

    @property
    def num_views(self) -> int:
        """Count of views analyzed in this trial."""
        return len(self.view_metrics)

    @property
    def success_rate(self) -> float:
        """Fraction of views meeting the target, in [0.0, 1.0]."""
        n = len(self.view_metrics)
        return self.views_meeting_target / n if n else 0.0

    @property
    def mean_ip_count(self) -> float:
        """Average interest-point count per view."""
        n = len(self.view_metrics)
        return self.total_ip_count / n if n else 0.0

    @property
    def mean_density(self) -> float:
        """Average spatial density over the trial's views."""
        n = len(self.view_metrics)
        return sum(vm.density for vm in self.view_metrics) / n if n else 0.0

    def to_metric_list(self) -> list:
        """Serialize trial-level metrics as labeled dicts.

        Returns
        -------
        list of dict
            One dict per metric, each carrying 'name', 'value',
            'description' and the trial_index as context.
        """
        # Data-driven table keeps name/value/description triples adjacent.
        rows = (
            ("multiscale", self.multiscale,
             "Zarr pyramid level used for detection"),
            ("sigma", self.sigma,
             "DoG sigma parameter for blob detection"),
            ("success", self.success,
             "Whether >= 50% of views met the IP target"),
            ("num_views", self.num_views,
             "Number of tile views analyzed"),
            ("views_meeting_target", self.views_meeting_target,
             "Number of views with IP count >= target"),
            ("success_rate", round(self.success_rate, 4),
             "Fraction of views meeting target (0.0-1.0)"),
            ("total_ip_count", self.total_ip_count,
             "Total interest points across all views"),
            ("mean_ip_count", round(self.mean_ip_count, 2),
             "Mean interest points per view"),
            ("mean_density", round(self.mean_density, 8),
             "Mean spatial density (IPs per unit volume)"),
        )
        return [
            {
                "name": name,
                "value": value,
                "description": description,
                "trial_index": self.trial_index,
            }
            for name, value, description in rows
        ]

    def to_dict(self) -> dict:
        """Full serialization including per-view metrics."""
        return {
            "multiscale": self.multiscale,
            "sigma": self.sigma,
            "trial_index": self.trial_index,
            "success": self.success,
            "n5_output_path": self.n5_output_path,
            "metrics": self.to_metric_list(),
            "view_metrics": [vm.to_dict() for vm in self.view_metrics],
        }
def get_selected_trial(self) -> "Optional[SweepTrialResult]":
    """Return the first trial flagged successful, or None when none passed."""
    # Trials are already in sweep order, so the first success is the winner.
    return next((trial for trial in self._trials if trial.success), None)
def get_all_view_metrics_flat(self) -> list:
    """Flatten per-view metrics from every trial into one list.

    Each element is the view's metric dict augmented with trial context
    (multiscale, sigma, trial_index, success), so rows are self-describing
    for cross-trial comparison and plotting.

    Returns
    -------
    list of dict
        One entry per (trial, view) pair, in sweep order.
    """
    return [
        {
            **vm.to_dict(),
            "trial_multiscale": trial.multiscale,
            "trial_sigma": trial.sigma,
            "trial_index": trial.trial_index,
            "trial_success": trial.success,
        }
        for trial in self._trials
        for vm in trial.view_metrics
    ]
def to_dict(self) -> dict:
    """Serialize this view's QC metrics to a JSON-compatible dict.

    Coordinate tuples become lists; float statistics are rounded
    (density to 8 places, intensity stats to 4) for compact JSON.
    """
    result = {
        "view_id": self.view_id,
        "ip_count": self.ip_count,
        "spatial_extent_xyz": list(self.spatial_extent_xyz),
        "spatial_std_xyz": list(self.spatial_std_xyz),
    }
    result["density"] = round(self.density, 8)
    # The four intensity statistics share a naming scheme, so fill them
    # from their attribute names instead of spelling each out.
    for stat in ("mean", "std", "min", "max"):
        key = f"intensity_{stat}"
        result[key] = round(getattr(self, key), 4)
    result["meets_target"] = self.meets_target
    return result
+ """ + return [ + { + "name": "ip_count", + "value": self.ip_count, + "description": "Total interest points detected in this view", + "view_id": self.view_id, + }, + { + "name": "spatial_extent_x", + "value": round(self.spatial_extent_xyz[0], 2), + "description": "Range of IP X coordinates (max - min, pixels)", + "view_id": self.view_id, + }, + { + "name": "spatial_extent_y", + "value": round(self.spatial_extent_xyz[1], 2), + "description": "Range of IP Y coordinates (max - min, pixels)", + "view_id": self.view_id, + }, + { + "name": "spatial_extent_z", + "value": round(self.spatial_extent_xyz[2], 2), + "description": "Range of IP Z coordinates (max - min, pixels)", + "view_id": self.view_id, + }, + { + "name": "spatial_std_x", + "value": round(self.spatial_std_xyz[0], 4), + "description": "Std dev of IP X coordinates (pixels)", + "view_id": self.view_id, + }, + { + "name": "spatial_std_y", + "value": round(self.spatial_std_xyz[1], 4), + "description": "Std dev of IP Y coordinates (pixels)", + "view_id": self.view_id, + }, + { + "name": "spatial_std_z", + "value": round(self.spatial_std_xyz[2], 4), + "description": "Std dev of IP Z coordinates (pixels)", + "view_id": self.view_id, + }, + { + "name": "density", + "value": round(self.density, 8), + "description": "Interest points per unit volume (bounding box)", + "view_id": self.view_id, + }, + { + "name": "intensity_mean", + "value": round(self.intensity_mean, 4), + "description": "Mean intensity of detected interest points", + "view_id": self.view_id, + }, + { + "name": "intensity_std", + "value": round(self.intensity_std, 4), + "description": "Standard deviation of IP intensities", + "view_id": self.view_id, + }, + { + "name": "intensity_min", + "value": round(self.intensity_min, 4), + "description": "Minimum IP intensity", + "view_id": self.view_id, + }, + { + "name": "intensity_max", + "value": round(self.intensity_max, 4), + "description": "Maximum IP intensity", + "view_id": self.view_id, + }, + { + "name": 
"meets_target", + "value": self.meets_target, + "description": "Whether IP count meets the target threshold", + "view_id": self.view_id, + }, + ] + + +def _get_ip_count_from_attributes(attrs_path: Path) -> Optional[int]: + """Read IP count from attributes.json (fast path, no array loading). + + Parameters + ---------- + attrs_path : Path + Path to attributes.json inside interestpoints/id/. + + Returns + ------- + int or None + IP count from the dimensions field, or None on error. + """ + try: + with open(attrs_path, "r") as f: + attrs = json.load(f) + return attrs.get("dimensions", [0])[-1] + except (json.JSONDecodeError, OSError) as e: + logger.warning(f"Could not read attributes from {attrs_path}: {e}") + return None + + +def _read_n5_array(n5_store_path: str, dataset_rel_path: str) -> Optional[np.ndarray]: + """Read an array from an N5 store. + + Parameters + ---------- + n5_store_path : str + Path to the .n5 directory or parent directory containing N5 data. + dataset_rel_path : str + Relative path within the store to the dataset. + + Returns + ------- + np.ndarray or None + The array data, or None if not found. + """ + try: + store = zarr.N5Store(n5_store_path) + root = zarr.open(store, mode="r") + if dataset_rel_path in root: + return root[dataset_rel_path][:] + return None + except Exception as e: + logger.warning(f"Could not read {dataset_rel_path} from {n5_store_path}: {e}") + return None + + +def _find_n5_store_and_rel_path(full_path: str) -> Optional[tuple]: + """Split a full path into N5 store path and relative dataset path. + + Parameters + ---------- + full_path : str + Full path like '/scratch/.../interestpoints.n5/tpId_0_.../beads/interestpoints/loc'. + + Returns + ------- + tuple of (str, str) or None + (n5_store_path, dataset_rel_path), or None if no .n5 found. 
+ """ + parts = full_path.split("/") + for i, part in enumerate(parts): + if part.endswith(".n5"): + store_path = "/".join(parts[: i + 1]) + rel_path = "/".join(parts[i + 1 :]) + return store_path, rel_path + return None + + +def compute_view_metrics( + n5_base_path: str, + view_dir_name: str, + target_interest_points: int, + compute_spatial: bool = True, +) -> Optional[ViewIPMetrics]: + """Compute QC metrics for a single view's IP detection output. + + Parameters + ---------- + n5_base_path : str + Path to the output directory containing tpId_ subdirectories. + May contain an interestpoints.n5 subdirectory, or the tpId_ dirs + may be directly under this path. + view_dir_name : str + Directory name like "tpId_0_viewSetupId_5". + target_interest_points : int + Threshold for the meets_target flag. + compute_spatial : bool + If True, read full loc/intensities arrays for spatial metrics. + If False, only read attributes.json for IP count (faster). + + Returns + ------- + ViewIPMetrics or None + Computed metrics, or None if data could not be read. 
+ """ + base = Path(n5_base_path) + + # Find the interestpoints location — may be under interestpoints.n5/ or beads/ + # Try common patterns + candidates = [ + base / view_dir_name / "beads" / "interestpoints", + base / "interestpoints.n5" / view_dir_name / "beads" / "interestpoints", + ] + + ip_dir = None + for candidate in candidates: + if candidate.exists(): + ip_dir = candidate + break + + if ip_dir is None: + logger.debug(f"No interestpoints directory found for {view_dir_name} in {n5_base_path}") + return None + + # Fast path: get count from attributes.json + attrs_path = ip_dir / "id" / "attributes.json" + ip_count = _get_ip_count_from_attributes(attrs_path) + if ip_count is None: + ip_count = 0 + + if not compute_spatial or ip_count == 0: + return ViewIPMetrics( + view_id=view_dir_name, + ip_count=ip_count, + spatial_extent_xyz=(0.0, 0.0, 0.0), + spatial_std_xyz=(0.0, 0.0, 0.0), + density=0.0, + intensity_mean=0.0, + intensity_std=0.0, + intensity_min=0.0, + intensity_max=0.0, + meets_target=ip_count >= target_interest_points, + ) + + # Full path: read loc and intensities arrays + loc_path = str(ip_dir / "loc") + intensities_path = str(ip_dir / "intensities") + + loc_result = _find_n5_store_and_rel_path(loc_path) + int_result = _find_n5_store_and_rel_path(intensities_path) + + loc_data = None + int_data = None + + if loc_result: + loc_data = _read_n5_array(loc_result[0], loc_result[1]) + if int_result: + int_data = _read_n5_array(int_result[0], int_result[1]) + + # If N5 store approach didn't work, try direct zarr open + if loc_data is None: + try: + store = zarr.N5Store(str(ip_dir.parent.parent.parent / "interestpoints.n5")) + root = zarr.open(store, mode="r") + rel_base = f"{view_dir_name}/beads/interestpoints" + if f"{rel_base}/loc" in root: + loc_data = root[f"{rel_base}/loc"][:] + if f"{rel_base}/intensities" in root: + int_data = root[f"{rel_base}/intensities"][:] + except Exception: + pass + + # Compute spatial metrics from loc array + if loc_data is 
not None and len(loc_data) > 0: + ip_count = len(loc_data) + extent = np.ptp(loc_data, axis=0) + spatial_extent = tuple(float(v) for v in extent) + spatial_std = tuple(float(v) for v in np.std(loc_data, axis=0)) + volume = float(np.prod(np.maximum(extent, 1.0))) + density = ip_count / volume + else: + spatial_extent = (0.0, 0.0, 0.0) + spatial_std = (0.0, 0.0, 0.0) + density = 0.0 + + # Compute intensity metrics + if int_data is not None and len(int_data) > 0: + intensity_mean = float(np.mean(int_data)) + intensity_std = float(np.std(int_data)) + intensity_min = float(np.min(int_data)) + intensity_max = float(np.max(int_data)) + else: + intensity_mean = 0.0 + intensity_std = 0.0 + intensity_min = 0.0 + intensity_max = 0.0 + + return ViewIPMetrics( + view_id=view_dir_name, + ip_count=ip_count, + spatial_extent_xyz=spatial_extent, + spatial_std_xyz=spatial_std, + density=density, + intensity_mean=intensity_mean, + intensity_std=intensity_std, + intensity_min=intensity_min, + intensity_max=intensity_max, + meets_target=ip_count >= target_interest_points, + ) + + +def compute_all_view_metrics( + n5_base_path: str, + target_interest_points: int, + compute_spatial: bool = True, +) -> List[ViewIPMetrics]: + """Compute QC metrics for all views found under n5_base_path. + + Discovers tpId_*_viewSetupId_* directories automatically by + searching recursively for attributes.json files (matching the + pattern used by R2R's _check_ip_detection_success). + + Parameters + ---------- + n5_base_path : str + Path to the output directory. + target_interest_points : int + Threshold for meets_target flag. + compute_spatial : bool + If True, read full arrays for spatial/intensity metrics. + + Returns + ------- + list of ViewIPMetrics + One entry per discovered view, sorted by view_id. 
+ """ + base = Path(n5_base_path) + metrics = [] + + # Discover view directories by finding attributes.json files + view_dirs = set() + for attrs_file in base.rglob("attributes.json"): + if attrs_file.parent.name != "id": + continue + if "interestpoints" not in attrs_file.parent.parent.name: + continue + # Walk up to find tpId_*_viewSetupId_* directory + for part in attrs_file.parts: + if part.startswith("tpId_") and "viewSetupId" in part: + view_dirs.add(part) + break + + for view_dir in sorted(view_dirs): + result = compute_view_metrics( + n5_base_path=n5_base_path, + view_dir_name=view_dir, + target_interest_points=target_interest_points, + compute_spatial=compute_spatial, + ) + if result is not None: + metrics.append(result) + + logger.info( + f"Computed metrics for {len(metrics)} views from {n5_base_path}" + ) + return metrics diff --git a/Rhapso/matching/load_and_transform_points.py b/Rhapso/matching/load_and_transform_points.py index 2fd92e5..aa079db 100644 --- a/Rhapso/matching/load_and_transform_points.py +++ b/Rhapso/matching/load_and_transform_points.py @@ -8,12 +8,25 @@ Load and Transform Points loads interest points from n5 and transforms them into global space """ +# Name of the per-split-tile translation ViewTransform emitted by +# ``Rhapso.splitting.split_images`` for BigStitcher provenance. +# Detection stores IPs in L0 world voxel coords (see +# ``Rhapso/detection/image_reader.py`` L195-214), so this translation +# is redundant with what is already baked into the stored IPs and +# MUST NOT be composed again during matching — doing so produces +# double-translated correspondences and a k × tile_step residual +# gradient across the moving-tile grid. 
+SPLIT_TILE_TRANSFORM_NAME = "Image Splitting" + class LoadAndTransformPoints: - def __init__(self, data_global, xml_input_path, n5_output_path, match_type): + def __init__(self, data_global, xml_input_path, n5_output_path, match_type, + pair_with_view=None): self.data_global = data_global self.xml_input_path = xml_input_path self.n5_output_path = n5_output_path self.match_type = match_type + # If set, only generate pairs that include this view (tp, setup) + self.pair_with_view = pair_with_view def transform_interest_points(self, points, transformation_matrix): """ @@ -59,7 +72,15 @@ def _parse_affine_matrix(self, affine_text): def get_transformation_matrix(self, view_id, view_registrations): """ - Compose all affine ViewTransforms for a given view (timepoint, setup) + Compose affine ViewTransforms for a given view (timepoint, setup). + + Every transform is composed except the reserved + ``SPLIT_TILE_TRANSFORM_NAME`` ("Image Splitting"). That entry + records the split tile's world-grid position for BigStitcher + provenance, but the per-tile translation has already been + baked into stored IP coords by detection + (``image_reader.py`` L195-214 + ``difference_of_gaussian.py`` + L318). Applying it again double-translates split-tile IPs. 
""" try: transforms = view_registrations.get(view_id, []) @@ -70,6 +91,9 @@ def get_transformation_matrix(self, view_id, view_registrations): final_matrix = np.eye(4) for i, transform in enumerate(transforms): + if transform.get("name") == SPLIT_TILE_TRANSFORM_NAME: + continue + affine_str = transform.get("affine") if not affine_str: print(f"⚠️ No affine string in transform {i+1} for view {view_id}") @@ -447,6 +471,15 @@ def process_pair(view_a, view_b, label, view_ids_global, view_registrations): setup['pairs'] = self.expand_pairs_with_labels(setup['pairs'], view_ids_global) + # Filter pairs to only include those containing pair_with_view + if self.pair_with_view is not None: + anchor = self.pair_with_view + setup['pairs'] = [ + (va, vb, label) for va, vb, label in setup['pairs'] + if va == anchor or vb == anchor + ] + print(f"[LoadAndTransformPoints] Filtered to {len(setup['pairs'])} pairs containing {anchor}") + # launch Ray tasks futures = [ process_pair.remote(view_a, view_b, label, view_ids_global, view_registrations) diff --git a/Rhapso/matching/ransac_matching.py b/Rhapso/matching/ransac_matching.py index fb55c73..1eaeeb6 100644 --- a/Rhapso/matching/ransac_matching.py +++ b/Rhapso/matching/ransac_matching.py @@ -265,18 +265,27 @@ def compute_ransac(self, candidates): return best_inliers, best_model def create_candidates(self, desc_a, desc_b): - match_list = [] - - for a in range(1): - for b in range(1): + """Enumerate all subset pairs (SubsetMatcher style). + + Each descriptor's ``subsets`` field contains C(n+r, n) subsets + of relative neighbor vectors. We pair every subset from A with + every subset from B (index-paired within each subset) and return + all combinations so ``descriptor_distance`` can pick the best. 
+ """ + subsets_a = desc_a['subsets'] + subsets_b = desc_b['subsets'] + match_list = [] + for sa in subsets_a: + for sb in subsets_b: + # sa and sb are arrays of shape (num_neighbors, 3) + # relative vectors, sorted by distance from basis point matches = [] - for i in range(self.num_required_neighbors): - point_match = (desc_a['relative_descriptors'][i], desc_b['relative_descriptors'][i]) + for i in range(len(sa)): + point_match = (sa[i], sb[i]) matches.append(point_match) - match_list.append(matches) - + return match_list def descriptor_distance(self, desc_a, desc_b): @@ -320,7 +329,7 @@ def create_simple_point_descriptors(self, tree, points_array, idx, num_required_ elif len(neighbors) > num_required_neighbors: idx_sets = matcher["neighbors"] - relative_vectors = neighbors - basis_point + relative_vectors = neighbors - basis_point # Final descriptor representation (as dict) descriptor = { @@ -329,7 +338,7 @@ def create_simple_point_descriptors(self, tree, points_array, idx, num_required_ "neighbors": neighbors, "relative_descriptors": relative_vectors, "matcher": matcher, - "subsets": np.stack([neighbors[list(c)] for c in idx_sets]) + "subsets": np.stack([relative_vectors[list(c)] for c in idx_sets]) } descriptors.append(descriptor) diff --git a/Rhapso/matching/xml_parser.py b/Rhapso/matching/xml_parser.py index 25eb954..6f67a8a 100644 --- a/Rhapso/matching/xml_parser.py +++ b/Rhapso/matching/xml_parser.py @@ -7,7 +7,13 @@ class XMLParserMatching: def __init__(self, xml_input_path, input_type): - self.xml_input_path = xml_input_path + # Accept single path (str) or multiple paths (list) + if isinstance(xml_input_path, str): + self.xml_input_paths = [xml_input_path] + else: + self.xml_input_paths = list(xml_input_path) + # Keep for backwards compat (used by LoadAndTransformPoints) + self.xml_input_path = self.xml_input_paths[0] self.input_type = input_type self.data_global = None @@ -120,7 +126,10 @@ def parse_image_loader(self, root): view_setup = il.get("setup") 
timepoint = il.get('timepoint') if 'timepoint' in il else il.get('tp') file_path = (il.get("path") or il.findtext("path") or "").strip() - channel = file_path.split("_ch_", 1)[1].split(".ome.zarr", 1)[0] + if "_ch_" in file_path: + channel = file_path.split("_ch_", 1)[1].split(".ome.zarr", 1)[0] + else: + channel = "0" # default channel for non-exaSPIM paths image_loader_data.append( { @@ -292,11 +301,40 @@ def get_xml_content(self): return xml_content + def _merge_parsed(self, base, other): + """Merge two parsed data_global dicts.""" + # imageLoader: list concat + base['imageLoader'].extend(other['imageLoader']) + # viewRegistrations: dict merge (keyed by (tp, setup)) + base['viewRegistrations'].update(other['viewRegistrations']) + # viewsInterestPoints: dict merge (keyed by (tp, setup)) + base['viewsInterestPoints'].update(other['viewsInterestPoints']) + # viewSetup: merge sub-dicts + base['viewSetup']['byId'].update(other['viewSetup']['byId']) + base['viewSetup']['viewSizes'].update( + other['viewSetup']['viewSizes'] + ) + base['viewSetup']['viewVoxelSizes'].update( + other['viewSetup']['viewVoxelSizes'] + ) + return base + def run(self): """ - Executes the entry point of the script. + Parse all input XMLs and merge into a single data_global. 
""" + # Parse first XML + self.xml_input_path = self.xml_input_paths[0] xml_content = self.get_xml_content() data_global = self.parse(xml_content) + + # Parse and merge additional XMLs + for path in self.xml_input_paths[1:]: + self.xml_input_path = path + xml_content = self.get_xml_content() + other = self.parse(xml_content) + data_global = self._merge_parsed(data_global, other) + + self.data_global = data_global return data_global diff --git a/Rhapso/multiscale/array_and_chunk_prep.py b/Rhapso/multiscale/array_and_chunk_prep.py index b7104f4..deab44a 100644 --- a/Rhapso/multiscale/array_and_chunk_prep.py +++ b/Rhapso/multiscale/array_and_chunk_prep.py @@ -9,10 +9,24 @@ """ class ArrayAndChunkPrep: - def __init__(self, chunk_size: List[int], xml_path, dim: int = 5) -> None: + def __init__(self, chunk_size: List[int], xml_path=None, dim: int = 5, + voxel_size: List[float] = None) -> None: + """ + Parameters + ---------- + voxel_size : list[float], optional + Explicit ZYX voxel size (in micrometers). When provided, skips + the BigStitcher XML read. Useful when the caller already has + the base-level physical spacing in hand (e.g. from an OME-NGFF + multiscales entry). + xml_path : str, optional + BigStitcher XML path. Required if ``voxel_size`` is not given; + ignored otherwise. 
+ """ self.chunk_size = chunk_size self.xml_path = xml_path self.dim = dim + self.voxel_size = voxel_size def voxel_size_zyx_from_xml(self) -> list[float]: if self.xml_path.startswith("s3://"): @@ -50,9 +64,16 @@ def run(self, data: ArrayLike): """ Entry point """ - voxel_size = self.voxel_size_zyx_from_xml() + if self.voxel_size is not None: + voxel_size = [float(v) for v in self.voxel_size] + elif self.xml_path is not None: + voxel_size = self.voxel_size_zyx_from_xml() + else: + raise ValueError( + "ArrayAndChunkPrep requires either voxel_size or xml_path" + ) arr = self._pad_array_n_d(data) dataset_shape = self._compute_dataset_shape(arr) full_chunks = self._clamp_chunks(dataset_shape) - + return arr, dataset_shape, full_chunks, voxel_size \ No newline at end of file diff --git a/Rhapso/multiscale/ome_metadata.py b/Rhapso/multiscale/ome_metadata.py index 7963b5f..5b1dfc5 100644 --- a/Rhapso/multiscale/ome_metadata.py +++ b/Rhapso/multiscale/ome_metadata.py @@ -23,7 +23,7 @@ def __init__(self, zarr_path: str, dataset_shape: List[int], scale_factor, voxel self.n_lvls = n_lvls self.chunk_size = chunk_size self.channel_names = channel_names - self.origin = origin or [0.0, 0.0, 0.0] + self.origin = origin # None means no translation entries in output .zattrs @staticmethod def is_s3_path(path: str) -> bool: diff --git a/Rhapso/multiscale/pyramid_executor.py b/Rhapso/multiscale/pyramid_executor.py index 246bf7a..6608c57 100644 --- a/Rhapso/multiscale/pyramid_executor.py +++ b/Rhapso/multiscale/pyramid_executor.py @@ -11,14 +11,24 @@ """ class PyramidExecutor: - def __init__(self, n_lvls: int, scale_factors, chunk_size: Tuple[int, ...], block_shape_zyx: Tuple[int, int, int], - zarr_path: str, base_level: int) -> None: + def __init__(self, n_lvls: int, scale_factors, chunk_size: Tuple[int, ...], block_shape_zyx: Tuple[int, int, int], + zarr_path: str, base_level: int, reducer=None) -> None: + """ + Parameters + ---------- + reducer : callable, optional + Function 
``(array, window_size) -> downsampled_array`` applied + per block. Defaults to :meth:`windowed_mean`. Use + :meth:`windowed_min` / :meth:`windowed_max` for segmentation + or mask-style data where label/value blending is unwanted. + """ self.n_lvls = n_lvls self.scale_factors = scale_factors self.chunk_size = chunk_size self.block_shape_zyx = block_shape_zyx self.zarr_path = zarr_path self.base_level = base_level + self.reducer = reducer if reducer is not None else PyramidExecutor.windowed_mean @staticmethod def reshape_windowed(array: npt.NDArray[Any], window_size: Sequence[int]) -> npt.NDArray[Any]: @@ -40,6 +50,24 @@ def windowed_mean(array: npt.NDArray[Any], window_size: Sequence[int], **kwargs: ) return result + @staticmethod + def windowed_min(array: npt.NDArray[Any], window_size: Sequence[int], **kwargs: Any) -> npt.NDArray[Any]: + """Per-block minimum. Preserves label boundaries for integer segmentation.""" + reshaped = PyramidExecutor.reshape_windowed(array, window_size) + result: npt.NDArray[Any] = reshaped.min( + axis=tuple(range(1, reshaped.ndim, 2)), **kwargs + ) + return result + + @staticmethod + def windowed_max(array: npt.NDArray[Any], window_size: Sequence[int], **kwargs: Any) -> npt.NDArray[Any]: + """Per-block maximum. 
Preserves non-background labels for sparse masks.""" + reshaped = PyramidExecutor.reshape_windowed(array, window_size) + result: npt.NDArray[Any] = reshaped.max( + axis=tuple(range(1, reshaped.ndim, 2)), **kwargs + ) + return result + def store(self, channel_group, src_level: int, dst_level: int, src_shape: Tuple[int, ...], dst_shape: Tuple[int, ...], block_shape_zyx: Tuple[int, int, int], scale_factors_zyx: Tuple[int, int, int]) -> None: """ @@ -105,7 +133,10 @@ def store(self, channel_group, src_level: int, dst_level: int, src_shape: Tuple[ ) futures.append( - process_block_instruction_remote.remote(src_level, dst_level, src_slices, dst_slices, sz, sy, sx, channel_group) + process_block_instruction_remote.remote( + src_level, dst_level, src_slices, dst_slices, + sz, sy, sx, channel_group, self.reducer, + ) ) x0 = x1 @@ -181,10 +212,10 @@ def run(self, channel_group) -> None: @ray.remote def process_block_instruction_remote(src_level: int, dst_level: int, src_slices: Tuple[slice, ...], dst_slices: Tuple[slice, ...], - sz: int, sy: int, sx: int, channel_group): + sz: int, sy: int, sx: int, channel_group, reducer): src_arr = channel_group[str(src_level)] dst_arr = channel_group[str(dst_level)] src_block = np.asarray(src_arr[src_slices]) window_size = (1, 1, sz, sy, sx) - dst_block = PyramidExecutor.windowed_mean(src_block, window_size=window_size) + dst_block = reducer(src_block, window_size=window_size) dst_arr[dst_slices] = dst_block \ No newline at end of file diff --git a/Rhapso/pipelines/ray/aws/alignment_pipeline.py b/Rhapso/pipelines/ray/aws/alignment_pipeline.py index 515616d..9b1dcab 100644 --- a/Rhapso/pipelines/ray/aws/alignment_pipeline.py +++ b/Rhapso/pipelines/ray/aws/alignment_pipeline.py @@ -27,6 +27,7 @@ " combine_distance=cfg[\\\"combine_distance\\\"],\n" " chunks_per_bound=cfg[\\\"chunks_per_bound\\\"], run_type=cfg[\\\"detection_run_type\\\"],\n" " max_spots=cfg[\\\"max_spots\\\"], median_filter=cfg[\\\"median_filter\\\"],\n" + " 
min_peak_intensity=cfg.get(\\\"min_peak_intensity\\\"),\n" ")\n" "ipd.run()\n" "PY\n" diff --git a/Rhapso/pipelines/ray/interest_point_detection.py b/Rhapso/pipelines/ray/interest_point_detection.py index 1af6aa9..ff9c893 100644 --- a/Rhapso/pipelines/ray/interest_point_detection.py +++ b/Rhapso/pipelines/ray/interest_point_detection.py @@ -9,13 +9,15 @@ from Rhapso.detection.save_interest_points import SaveInterestPoints import boto3 import ray +import os # This class implements the interest point detection pipeline class InterestPointDetection: - def __init__(self, dsxy, dsz, min_intensity, max_intensity, sigma, threshold, file_type, xml_file_path, - image_file_prefix, xml_output_file_path, n5_output_file_prefix, combine_distance, chunks_per_bound, run_type, - max_spots, median_filter): + def __init__(self, dsxy, dsz, min_intensity, max_intensity, sigma, threshold, file_type, xml_file_path, + image_file_prefix, xml_output_file_path, n5_output_file_prefix, combine_distance, chunks_per_bound, run_type, + max_spots, median_filter, overlapping_only=True, min_peak_intensity=None, + num_cpus_per_task=1): self.dsxy = dsxy self.dsz = dsz self.min_intensity = min_intensity @@ -32,6 +34,13 @@ def __init__(self, dsxy, dsz, min_intensity, max_intensity, sigma, threshold, fi self.run_type = run_type self.max_spots = max_spots self.median_filter = median_filter + self.overlapping_only = overlapping_only + self.min_peak_intensity = min_peak_intensity + # Per-task CPU reservation. Caps detection concurrency to + # floor(num_cpus / num_cpus_per_task) — the only knob to bound + # peak resident memory when one DoG task can balloon to many GB. + # Default 1 preserves legacy behavior (concurrency = num_cpus). 
+ self.num_cpus_per_task = max(1, int(num_cpus_per_task)) def detection(self): # Get XML file @@ -57,9 +66,10 @@ def detection(self): print("Transforms models have been created") # Use view transform matrices to find areas of overlap - overlap_detection = OverlapDetection(view_transform_matrices, dataframes, self.dsxy, self.dsz, self.image_file_prefix, self.file_type) + overlap_detection = OverlapDetection(view_transform_matrices, dataframes, self.dsxy, self.dsz, self.image_file_prefix, self.file_type, overlapping_only=self.overlapping_only) overlapping_area, new_dsxy, new_dsz, level, max_interval_size, mip_map_downsample = overlap_detection.run() print("Overlap detection is done") + print(f"[InterestPointDetection] Determined level={level}, new_dsxy={new_dsxy}, new_dsz={new_dsz}, mip_map_downsample={mip_map_downsample}") # Implement image chunking strategy as list of metadata metadata_loader = MetadataBuilder(dataframes, overlapping_area, self.image_file_prefix, self.file_type, new_dsxy, new_dsz, @@ -67,15 +77,27 @@ def detection(self): image_chunk_metadata = metadata_loader.run() print("Metadata has loaded") - # Use Ray to distribute peak detection to image chunking metadata + # Initialize Ray if not already running, using all available CPUs + if not ray.is_initialized(): + num_cpus = os.cpu_count() + ray_temp_dir = "/scratch/ray" + os.makedirs(ray_temp_dir, exist_ok=True) + ray.init(num_cpus=num_cpus, _temp_dir=ray_temp_dir) + print(f"[InterestPointDetection] Initialized Ray with {num_cpus} CPUs, temp_dir={ray_temp_dir}") + + # Use Ray to distribute peak detection to image chunking metadata @ray.remote def process_peak_detection_task(chunk_metadata, new_dsxy, new_dsz, min_intensity, max_intensity, sigma, threshold, - median_filter, mip_map_downsample): + median_filter, mip_map_downsample, min_peak_intensity): + import sys try: - difference_of_gaussian = DifferenceOfGaussian(min_intensity, max_intensity, sigma, threshold, median_filter, mip_map_downsample) + 
print(f"[Ray] Starting task for {chunk_metadata.get('view_id')}", flush=True) + difference_of_gaussian = DifferenceOfGaussian(min_intensity, max_intensity, sigma, threshold, median_filter, mip_map_downsample, + min_peak_intensity=min_peak_intensity) image_fetcher = ImageReader(self.file_type) view_id, interval, image_chunk, offset, lb = image_fetcher.run(chunk_metadata, new_dsxy, new_dsz) interest_points = difference_of_gaussian.run(image_chunk, offset, lb) + print(f"[Ray] Task complete for {view_id}, detected {len(interest_points['interest_points'])} points", flush=True) return { 'view_id': view_id, @@ -84,13 +106,23 @@ def process_peak_detection_task(chunk_metadata, new_dsxy, new_dsz, min_intensity 'intensities': interest_points['intensities'] } except Exception as e: + import traceback + print(f"[Ray] ERROR in task: {e}", flush=True) + print(traceback.format_exc(), flush=True) return {'error': str(e), 'view_id': chunk_metadata.get('view_id', 'unknown')} - # Submit tasks to Ray - futures = [process_peak_detection_task.remote(chunk_metadata, new_dsxy, new_dsz, self.min_intensity, self.max_intensity, - self.sigma, self.threshold, self.median_filter, mip_map_downsample) + # Submit tasks to Ray. `num_cpus` reservation per task caps + # concurrency to floor(num_cpus / num_cpus_per_task), bounding + # peak resident memory when DoG amplifies the input chunk. 
+ futures = [process_peak_detection_task.options(num_cpus=self.num_cpus_per_task).remote( + chunk_metadata, new_dsxy, new_dsz, self.min_intensity, self.max_intensity, + self.sigma, self.threshold, self.median_filter, mip_map_downsample, + self.min_peak_intensity) for chunk_metadata in image_chunk_metadata ] + print(f"[InterestPointDetection] Submitted {len(image_chunk_metadata)} tasks " + f"with num_cpus={self.num_cpus_per_task} per task " + f"(max concurrency = {os.cpu_count() // self.num_cpus_per_task})") # Gather and process results results = ray.get(futures) @@ -112,7 +144,12 @@ def process_peak_detection_task(chunk_metadata, new_dsxy, new_dsz, min_intensity self.dsxy, self.dsz, self.min_intensity, self.max_intensity, self.sigma, self.threshold) save_interest_points.run() print("Interest points saved") - + + # Shutdown Ray to release resources for other frameworks (e.g. Dask) + if ray.is_initialized(): + ray.shutdown() + print("[InterestPointDetection] Ray shutdown complete") + def run(self): self.detection() diff --git a/Rhapso/pipelines/ray/interest_point_matching.py b/Rhapso/pipelines/ray/interest_point_matching.py index 8d1072c..93995e0 100644 --- a/Rhapso/pipelines/ray/interest_point_matching.py +++ b/Rhapso/pipelines/ray/interest_point_matching.py @@ -5,10 +5,17 @@ import ray class InterestPointMatching: - def __init__(self, xml_input_path, n5_output_path, input_type, match_type, num_neighbors, redundancy, significance, - search_radius, num_required_neighbors, model_min_matches, inlier_threshold, min_inlier_ratio, num_iterations, - regularization_weight, image_file_prefix): - self.xml_input_path = xml_input_path + def __init__(self, xml_input_path, n5_output_path, input_type, match_type, num_neighbors, redundancy, significance, + search_radius, num_required_neighbors, model_min_matches, inlier_threshold, min_inlier_ratio, num_iterations, + regularization_weight, image_file_prefix, pair_with_view=None): + # Accept single path (str) or multiple paths (list) 
+ if isinstance(xml_input_path, str): + self.xml_input_path = xml_input_path + self.xml_input_paths = [xml_input_path] + else: + self.xml_input_paths = list(xml_input_path) + self.xml_input_path = self.xml_input_paths[0] + self.pair_with_view = pair_with_view self.n5_output_path = n5_output_path self.input_type = input_type self.match_type = match_type @@ -25,13 +32,14 @@ def __init__(self, xml_input_path, n5_output_path, input_type, match_type, num_n self.image_file_prefix = image_file_prefix def match(self): - # Load XML - parser = XMLParserMatching(self.xml_input_path, self.input_type) + # Load XML(s) + parser = XMLParserMatching(self.xml_input_paths, self.input_type) data_global = parser.run() print("XML loaded and parsed") # Load and transform points - data_loader = LoadAndTransformPoints(data_global, self.xml_input_path, self.n5_output_path, self.match_type) + data_loader = LoadAndTransformPoints(data_global, self.xml_input_path, self.n5_output_path, self.match_type, + pair_with_view=self.pair_with_view) process_pairs, view_registrations = data_loader.run() print("Points loaded and transformed into global space") diff --git a/Rhapso/pipelines/ray/local/alignment_pipeline.py b/Rhapso/pipelines/ray/local/alignment_pipeline.py index 3bff875..25fa98f 100644 --- a/Rhapso/pipelines/ray/local/alignment_pipeline.py +++ b/Rhapso/pipelines/ray/local/alignment_pipeline.py @@ -32,6 +32,7 @@ run_type=config['detection_run_type'], max_spots=config['max_spots'], median_filter=config['median_filter'], + min_peak_intensity=config.get('min_peak_intensity'), ) # INTEREST POINT MATCHING RIGID diff --git a/Rhapso/pipelines/ray/multiscale.py b/Rhapso/pipelines/ray/multiscale.py index 84f3f22..ff59f16 100644 --- a/Rhapso/pipelines/ray/multiscale.py +++ b/Rhapso/pipelines/ray/multiscale.py @@ -10,8 +10,24 @@ """ class MultiScale: - def __init__(self, xml_path, zarr_path: str, chunk_size: List[int], n_lvls: int, scale_factor, - target_block_size_mb: int, base_level: int): + def 
__init__(self, zarr_path: str, chunk_size: List[int], n_lvls: int, scale_factor, + target_block_size_mb: int, base_level: int, xml_path=None, + voxel_size: List[float] = None, reducer=None): + """ + Parameters + ---------- + voxel_size : list[float], optional + Explicit ZYX voxel size (in micrometers) for the base level. + When provided, overrides reading voxel size from a + BigStitcher XML. Required if ``xml_path`` is not given. + xml_path : str, optional + BigStitcher XML path. Ignored when ``voxel_size`` is given. + reducer : callable, optional + Per-block downsampling function forwarded to + :class:`PyramidExecutor`. Defaults to windowed-mean. Use + ``PyramidExecutor.windowed_min`` / ``windowed_max`` for + integer segmentation or mask data. + """ self.xml_path = xml_path self.zarr_path = zarr_path self.chunk_size = chunk_size @@ -19,18 +35,25 @@ def __init__(self, xml_path, zarr_path: str, chunk_size: List[int], n_lvls: int, self.scale_factor = scale_factor self.target_block_size_mb = target_block_size_mb self.base_level = base_level + self.voxel_size = voxel_size + self.reducer = reducer def multiscale(self) -> None: array = da.from_zarr(f"{self.zarr_path}/{self.base_level}") print(f"[MultiScale] Loading base level from {self.zarr_path}/{self.base_level}") # Normalize to TCZYX + clamp chunks - prep = ArrayAndChunkPrep(self.chunk_size, self.xml_path, dim=5) + prep = ArrayAndChunkPrep( + self.chunk_size, xml_path=self.xml_path, dim=5, + voxel_size=self.voxel_size, + ) array, dataset_shape, chunk_size, voxel_size = prep.run(array) print(f"[MultiScale] Prepared array shape={dataset_shape}, chunks={chunk_size}") - # Open root + channel group and write OME metadata - ome = OMEMetadata(self.zarr_path, list(dataset_shape), self.scale_factor, voxel_size, self.n_lvls, list(chunk_size)) + # Open root + channel group and write OME metadata. 
+ # origin=None suppresses half-pixel translation injection so the + # output .zattrs has pure-scale entries matching the fixed image. + ome = OMEMetadata(self.zarr_path, list(dataset_shape), self.scale_factor, voxel_size, self.n_lvls, list(chunk_size), origin=None) channel_group, stack_name, scale_factors = ome.run() print(f"[MultiScale] Using channel group '{stack_name}'") @@ -41,7 +64,10 @@ def multiscale(self) -> None: print(f"[MultiScale] Block shape ZYX={block_shape_zyx}") # Execute multiscale pyramid - executor = PyramidExecutor(self.n_lvls, scale_factors, tuple(chunk_size), block_shape_zyx, self.zarr_path, self.base_level) + executor = PyramidExecutor( + self.n_lvls, scale_factors, tuple(chunk_size), block_shape_zyx, + self.zarr_path, self.base_level, reducer=self.reducer, + ) executor.run(channel_group) print("[MultiScale] Pyramid build complete") diff --git a/Rhapso/pipelines/ray/solver.py b/Rhapso/pipelines/ray/solver.py index 4520b3c..4b83205 100644 --- a/Rhapso/pipelines/ray/solver.py +++ b/Rhapso/pipelines/ray/solver.py @@ -8,6 +8,7 @@ from Rhapso.solver.connected_graphs import ConnectedGraphs from Rhapso.solver.concatenate_models import ConcatenateModels from Rhapso.solver.save_results import SaveResults +import xml.etree.ElementTree as ET import boto3 """ @@ -15,12 +16,18 @@ """ class Solver: - def __init__(self, xml_file_path_output, n5_input_path, xml_file_path, run_type, relative_threshold, absolute_threshold, - min_matches, damp, regularization_weight, max_iterations, max_allowed_error, max_plateauwidth, metrics_output_path, - fixed_tile): + def __init__(self, xml_file_path_output, n5_input_path, xml_file_path, run_type, relative_threshold, absolute_threshold, + min_matches, damp, regularization_weight, max_iterations, max_allowed_error, max_plateauwidth, metrics_output_path, + fixed_tile, skip_global_optimization=False): self.xml_file_path_output = xml_file_path_output self.n5_input_path = n5_input_path - self.xml_file_path = xml_file_path + # 
Accept single path (str) or multiple paths (list) + if isinstance(xml_file_path, str): + self.xml_file_path = xml_file_path + self.xml_file_paths = [xml_file_path] + else: + self.xml_file_paths = list(xml_file_path) + self.xml_file_path = self.xml_file_paths[0] self.run_type = run_type self.relative_threshold = relative_threshold self.absolute_threshold = absolute_threshold @@ -32,23 +39,54 @@ def __init__(self, xml_file_path_output, n5_input_path, xml_file_path, run_type, self.max_plateauwidth = max_plateauwidth self.metrics_output_path = metrics_output_path self.fixed_tile = fixed_tile + self.skip_global_optimization = skip_global_optimization self.groups = None self.s3 = boto3.client('s3') - def solve(self): - # Get XML file - if self.xml_file_path.startswith("s3://"): - no_scheme = self.xml_file_path.replace("s3://", "", 1) + def _read_xml(self, path): + """Read XML content from local path or S3.""" + if path.startswith("s3://"): + no_scheme = path.replace("s3://", "", 1) bucket, key = no_scheme.split("/", 1) s3 = boto3.client("s3") response = s3.get_object(Bucket=bucket, Key=key) - xml_file = response["Body"].read().decode("utf-8") - else: - with open(self.xml_file_path, "r", encoding="utf-8") as f: - xml_file = f.read() + return response["Body"].read().decode("utf-8") + else: + with open(path, "r", encoding="utf-8") as f: + return f.read() + + def _merge_xml_strings(self, xml_contents): + """Merge multiple XML strings into one for SaveResults output. + + Takes the first XML as base and appends ViewRegistrations + and ViewInterestPointsFiles from subsequent XMLs. 
+ """ + if len(xml_contents) == 1: + return xml_contents[0] + + base_root = ET.fromstring(xml_contents[0]) + base_vr = base_root.find("ViewRegistrations") + base_vip = base_root.find("ViewInterestPoints") + + for xml_str in xml_contents[1:]: + other_root = ET.fromstring(xml_str) + other_vr = other_root.find("ViewRegistrations") + if other_vr is not None and base_vr is not None: + for vr in other_vr.findall("ViewRegistration"): + base_vr.append(vr) + other_vip = other_root.find("ViewInterestPoints") + if other_vip is not None and base_vip is not None: + for vi in other_vip.findall("ViewInterestPointsFile"): + base_vip.append(vi) + + return ET.tostring(base_root, encoding="unicode") - # Load XML data into dataframes - processor = XMLToDataFrameSolver(xml_file) + def solve(self): + # Get XML file(s) + xml_contents = [self._read_xml(p) for p in self.xml_file_paths] + + # Load XML data into dataframes + processor = XMLToDataFrameSolver(xml_contents) dataframes = processor.run() print("XML loaded") @@ -79,13 +117,18 @@ def solve(self): tc = pre_align_tiles.run(tiles) print("Tiles are pre-aligned") - # Update all points with transform models and iterate through all tiles (views) and optimize alignment - global_optimization = GlobalOptimization(tc, self.relative_threshold, self.absolute_threshold, self.min_matches, self.damp, - self.regularization_weight, self.max_iterations, self.max_allowed_error, - self.max_plateauwidth, self.run_type, self.metrics_output_path) - tiles, validation_stats = global_optimization.run() - print("Global optimization complete") - + if self.skip_global_optimization: + tiles = tc + validation_stats = {} + print("Global optimization skipped") + else: + # Update all points with transform models and iterate through all tiles (views) and optimize alignment + global_optimization = GlobalOptimization(tc, self.relative_threshold, self.absolute_threshold, self.min_matches, self.damp, + self.regularization_weight, self.max_iterations, 
self.max_allowed_error, + self.max_plateauwidth, self.run_type, self.metrics_output_path) + tiles, validation_stats = global_optimization.run() + print("Global optimization complete") + if(self.run_type == "split-affine"): # Combine splits into groups @@ -116,7 +159,8 @@ def solve(self): print("Models and metrics have been combined") # Save results to xml - one new affine matrix per view registration - save_results = SaveResults(tiles, xml_file, self.xml_file_path_output, self.run_type, validation_stats, self.n5_input_path) + merged_xml = self._merge_xml_strings(xml_contents) + save_results = SaveResults(tiles, merged_xml, self.xml_file_path_output, self.run_type, validation_stats, self.n5_input_path) save_results.run() print("Results have been saved") diff --git a/Rhapso/pipelines/ray/split_dataset.py b/Rhapso/pipelines/ray/split_dataset.py index 6054cfd..caba016 100644 --- a/Rhapso/pipelines/ray/split_dataset.py +++ b/Rhapso/pipelines/ray/split_dataset.py @@ -5,6 +5,7 @@ from Rhapso.split_dataset.save_points import SavePoints import boto3 import ray +import os class SplitDataset: def __init__(self, xml_file_path, xml_output_file_path, n5_path, point_density, min_points, max_points, error, exclude_radius, @@ -48,6 +49,14 @@ def split(self): save_xml.run() print("XML saved") + # Initialize Ray if not already running + if not ray.is_initialized(): + num_cpus = os.cpu_count() + ray_temp_dir = "/scratch/ray" + os.makedirs(ray_temp_dir, exist_ok=True) + ray.init(num_cpus=num_cpus, _temp_dir=ray_temp_dir) + print(f"[SplitDataset] Initialized Ray with {num_cpus} CPUs, temp_dir={ray_temp_dir}") + @ray.remote def distribute_points_saving(label_entries, n5_path): save_points = SavePoints(label_entries, n5_path) diff --git a/Rhapso/pipelines/utils.py b/Rhapso/pipelines/utils.py new file mode 100644 index 0000000..0b33c38 --- /dev/null +++ b/Rhapso/pipelines/utils.py @@ -0,0 +1,26 @@ +""" +Utility functions for pipelines +""" + + +def fetch_local_xml(file_path): + """ + Read 
XML content from a local file. + + Parameters + ---------- + file_path : str + Path to the XML file + + Returns + ------- + str + XML file contents + """ + try: + with open(file_path, "r", encoding="utf-8") as file: + return file.read() + except FileNotFoundError: + raise FileNotFoundError(f"Could not find XML file at '{file_path}'") + except Exception as e: + raise RuntimeError(f"Error reading XML file at '{file_path}': {e}") diff --git a/Rhapso/solver/xml_to_dataframe_solver.py b/Rhapso/solver/xml_to_dataframe_solver.py index 6f6e471..6d6e681 100644 --- a/Rhapso/solver/xml_to_dataframe_solver.py +++ b/Rhapso/solver/xml_to_dataframe_solver.py @@ -7,7 +7,13 @@ class XMLToDataFrameSolver: def __init__(self, xml_file): - self.xml_content = xml_file + # Accept single XML string or list of XML strings + if isinstance(xml_file, str): + self.xml_contents = [xml_file] + else: + self.xml_contents = list(xml_file) + # Keep for backwards compat + self.xml_content = self.xml_contents[0] def parse_image_loader_zarr(self, root): """ @@ -195,19 +201,31 @@ def check_length(self, root): length = False return length + def _parse_single(self, xml_content): + """Parse a single XML string into DataFrames.""" + root = ET.fromstring(xml_content) + try: + image_loader = self.route_image_loader(root) + except Exception: + image_loader = pd.DataFrame() + return { + "image_loader": image_loader, + "view_setups": self.parse_view_setups(root), + "view_registrations": self.parse_view_registrations(root), + "view_interest_points": self.parse_view_interest_points(root), + } + def run(self): """ - Executes the entry point of the script. + Parse all input XMLs and merge into combined DataFrames. 
""" - root = ET.fromstring(self.xml_content) - image_loader_df = self.route_image_loader(root) - view_setups_df = self.parse_view_setups(root) - view_registrations_df = self.parse_view_registrations(root) - view_interest_points_df = self.parse_view_interest_points(root) + result = self._parse_single(self.xml_contents[0]) - return { - "image_loader": image_loader_df, - "view_setups": view_setups_df, - "view_registrations": view_registrations_df, - "view_interest_points": view_interest_points_df, - } + for xml_content in self.xml_contents[1:]: + other = self._parse_single(xml_content) + for key in result: + result[key] = pd.concat( + [result[key], other[key]], ignore_index=True + ) + + return result diff --git a/Rhapso/split_dataset/compute_grid_rules.py b/Rhapso/split_dataset/compute_grid_rules.py index bfc611c..1f99429 100644 --- a/Rhapso/split_dataset/compute_grid_rules.py +++ b/Rhapso/split_dataset/compute_grid_rules.py @@ -23,11 +23,10 @@ def closest_larger_long_divisible_by(self, a, b): return int(a + b - (a % b)) - def find_min_step_size(self): + def find_min_step_size(self, lowest_resolution=(1.0, 1.0, 1.0)): """ Compute the minimal integer step size per axis (X,Y,Z) that is compatible with the chosen lowest resolution """ - lowest_resolution=(64.0, 64.0, 64.0) min_step_size = [1, 1, 1] for d, r in enumerate(lowest_resolution): diff --git a/Rhapso/split_dataset/save_xml.py b/Rhapso/split_dataset/save_xml.py index 14b354d..1336880 100644 --- a/Rhapso/split_dataset/save_xml.py +++ b/Rhapso/split_dataset/save_xml.py @@ -577,9 +577,12 @@ def _norm_id(raw): outer_timepoints = ch break if outer_timepoints is None: - outer_timepoints = ET.Element('Timepoints', {'type': 'pattern'}) - ip = ET.SubElement(outer_timepoints, 'integerpattern') - ip.text = "0" + tps = sorted({int(v['old_view'][0]) for v in self.self_definition if v['old_view'][0] is not None}) + first_tp = str(tps[0]) if tps else "0" + last_tp = str(tps[-1]) if tps else "0" + outer_timepoints = 
ET.Element('Timepoints', {'type': 'range'}) + ET.SubElement(outer_timepoints, 'first').text = first_tp + ET.SubElement(outer_timepoints, 'last').text = last_tp # place right after ViewSetups children = list(outer_seq) insert_idx = children.index(view_setups) + 1 if view_setups in children else len(children) diff --git a/Rhapso/split_dataset/split_images.py b/Rhapso/split_dataset/split_images.py index 0126528..6e27b71 100644 --- a/Rhapso/split_dataset/split_images.py +++ b/Rhapso/split_dataset/split_images.py @@ -8,16 +8,16 @@ import math class SplitImages: - def __init__(self, target_image_size, target_overlap, min_step_size, data_gloabl, n5_path, point_density, min_points, max_points, + def __init__(self, target_image_size, target_overlap, min_step_size, data_global, n5_path, point_density, min_points, max_points, error, excludeRadius): self.target_image_size = target_image_size self.target_overlap = target_overlap self.min_step_size = min_step_size - self.data_global = data_gloabl - self.image_loader_df = data_gloabl['image_loader'] - self.view_setups_df = data_gloabl['view_setups'] - self.view_registrations_df = data_gloabl['view_registrations'] - self.view_interest_points_df = data_gloabl['view_interest_points'] + self.data_global = data_global + self.image_loader_df = data_global['image_loader'] + self.view_setups_df = data_global['view_setups'] + self.view_registrations_df = data_global['view_registrations'] + self.view_interest_points_df = data_global['view_interest_points'] self.n5_path = n5_path self.point_density = point_density self.min_points = min_points @@ -96,21 +96,33 @@ def split_dims(self, input, i, final_size, overlap): to_val = 0 from_val = input_min[i] - while to_val < input[i]: + while to_val < input[i] - 1: to_val = min(input[i], from_val + final_size - 1) dim_intervals.append((from_val, to_val)) from_val = to_val - overlap + 1 return dim_intervals - def last_image_size(self, l, s, o): - num = l - 2 * (s - o) - o - den = s - o - rem = num % 
den if num >= 0 else -((-num) % den)
-        size = o + rem
-        if size < 0:
-            size = l + size
-        return size
+    def last_image_size(self, L, S, O):
+        """
+        Parameters
+        ----------
+        L : int
+            The length of the input dimension.
+        S : int
+            The target split-tile size in this dimension.
+        O : int
+            The size of overlap in this dimension.
+        """
+
+        stride = S - O
+        if not (0 <= O < S):
+            raise ValueError("Require 0 <= O < S")
+        if L <= 0:
+            raise ValueError("Require L > 0")
+
+        start_last = ((max(L - S, 0)) // stride) * stride
+        return L - start_last  # will be S when it fits perfectly
 
     def distribute_intervals_fixed_overlap(self, input):
         input = list(map(int, input.split()))
@@ -127,7 +139,7 @@ def distribute_intervals_fixed_overlap(self, input):
             length = input[i]
 
             if length <= self.target_image_size[i]:
-                pass
+                dim_intervals.append((0, length - 1))
             else:
                 l = length
 
@@ -338,139 +350,143 @@ def split_images(self, timepoints, interest_points, fake_label):
                     }
 
                     new_registrations[(new_view_id_key)] = new_view_registration
-
-                    new_v_ip_l = []
-
-                    old_v_ip_l = {
-                        'points': interest_points[old_view_id],
-                        'setup': old_id,
-                        'timepoint': t,
-                    }
-
-                    id = 0
-                    new_ip1 = []
-                    old_ip_l1 = old_v_ip_l['points']
-                    old_ip_1 = deepcopy(old_ip_l1['points'])
-
-                    for ip in old_ip_1:
-                        if self.contains(ip, interval):
-                            l = deepcopy(ip)
-                            for j in range(len(interval[0])):
-                                l[j] -= interval[0][j]
-
-                            new_ip1.append((id, l))
-                            id += 1
-                    new_ip_l1 = {
-                        'base_directory': old_ip_l1['base_path'],
-                        'corresponding_interest_points': None,
-                        'interest_points': new_ip1,
-                        'modified_corresponding_interest_points': None,
-                        'modified_interest_points': None,
-                        'n5_path': f"interestpoints.n5/tpId_{t}_viewSetupId_{new_view_id['setup']}/beads_split",
-                        'xml_n5_path': f"tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}",
-                        "parameters": old_ip_l1['parameters_split']
-                    }
-
-                    new_v_ip_l.append({
-                        'label': "beads_split",
-                        'ip_list': new_ip_l1
-                    })
-
-                    new_ip = []
-                    id = 0
-
-                    for j in range(i):
-                        other_interval = 
intervals[j] - intersection = self.intersect(interval, other_interval) + new_v_ip_l = [] + if old_view_id in interest_points: + old_v_ip_l = { + 'points': interest_points[old_view_id], + 'setup': old_id, + 'timepoint': t, + } + + id = 0 + new_ip1 = [] + old_ip_l1 = old_v_ip_l['points'] + old_ip_1 = deepcopy(old_ip_l1['points']) - if not self.is_empty(intersection): - other_setup = interval_to_view_setup[(tuple(other_interval[0]), tuple(other_interval[1]))] - other_view_id = f"timepoint: {t}, setup: {other_setup['id']}" - other_ip_list = new_interest_points[other_view_id] - - n = len(interval[0]) - num_pixels = 1 - - for k in range(n): - num_pixels *= (intersection[1][k] - intersection[0][k] + 1) - - num_points = min(self.max_points, max(self.min_points, math.ceil(self.point_density * num_pixels / (100.0*100.0*100.0)))) - other_points = (next((x for x in other_ip_list if x.get("label") == fake_label), {"ip_list": {}})["ip_list"].get("interest_points") or []) - other_id = len(other_points) - - tree2 = None - search2 = None + for ip in old_ip_1: + if self.contains(ip, interval): + l = deepcopy(ip) + for j in range(len(interval[0])): + l[j] -= interval[0][j] + + new_ip1.append((id, l)) + id += 1 + + new_ip_l1 = { + 'base_directory': old_ip_l1['base_path'], + 'corresponding_interest_points': None, + 'interest_points': new_ip1, + 'modified_corresponding_interest_points': None, + 'modified_interest_points': None, + 'n5_path': f"interestpoints.n5/tpId_{t}_viewSetupId_{new_view_id['setup']}/beads_split", + 'xml_n5_path': f"tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}", + "parameters": old_ip_l1['parameters_split'] + } + + new_v_ip_l.append({ + 'label': "beads_split", + 'ip_list': new_ip_l1 + }) + + if self.max_points > 0: + new_ip = [] + id = 0 - if self.exclude_radius > 0: - other_ip_global = [] + for j in range(i): + other_interval = intervals[j] + intersection = self.intersect(interval, other_interval) - for k, ip in enumerate(other_points): - l = 
deepcopy(ip[1]) + if not self.is_empty(intersection): + other_setup = interval_to_view_setup[(tuple(other_interval[0]), tuple(other_interval[1]))] + other_view_id = f"timepoint: {t}, setup: {other_setup['id']}" + other_ip_list = new_interest_points[other_view_id] - for m in range(n): - l[m] = l[m] + other_interval[0][m] - - other_ip_global.append((k, l)) + n = len(interval[0]) + num_pixels = 1 - if len(other_ip_global) > 0: - coords = np.vstack([l for _, l in other_ip_global]) - tree2 = cKDTree(coords) + for k in range(n): + num_pixels *= (intersection[1][k] - intersection[0][k] + 1) + + num_points = min(self.max_points, max(self.min_points, math.ceil(self.point_density * num_pixels / (100.0*100.0*100.0)))) + other_points = (next((x for x in other_ip_list if x.get("label") == fake_label), {"ip_list": {}})["ip_list"].get("interest_points") or []) + other_id = len(other_points) - def search2(q_point_global, radius=self.exclude_radius): - idxs = tree2.query_ball_point(np.asarray(q_point_global, float), radius) - return [other_ip_global[k] for k in idxs] - else: tree2 = None search2 = None - - else: - tree2 = None - search2 = None - - tmp = [0.0] * n - for k in range(num_points): - p = [0.0] * n - op = [0.0] * n - - for d in range(n): - l = rnd.random() * (intersection[1][d] - intersection[0][d] + 1) + intersection[0][d] - p[d] = (l + (rnd.random() - 0.5) * self.error) - interval[0][d] - op[d] = (l + (rnd.random() - 0.5) * self.error) - other_interval[0][d] - tmp[d] = l - - num_neighbors = 0 - if self.exclude_radius > 0: - tmp_ip = (0, np.asarray(tmp, dtype=float)) + if self.exclude_radius > 0: + other_ip_global = [] + + for k, ip in enumerate(other_points): + l = deepcopy(ip[1]) + + for m in range(n): + l[m] = l[m] + other_interval[0][m] + + other_ip_global.append((k, l)) + + if len(other_ip_global) > 0: + coords = np.vstack([l for _, l in other_ip_global]) + tree2 = cKDTree(coords) + + def search2(q_point_global, radius=self.exclude_radius): + idxs = 
tree2.query_ball_point(np.asarray(q_point_global, float), radius) + return [other_ip_global[k] for k in idxs] + else: + tree2 = None + search2 = None - if search2 is not None: - neighbors = search2(tmp_ip[1], self.exclude_radius) - num_neighbors += len(neighbors) - - if num_neighbors == 0: - new_ip.append((id, p)) - other_points.append((other_id, op)) - id += 1 - other_id += 1 + else: + tree2 = None + search2 = None + + tmp = [0.0] * n + + for k in range(num_points): + p = [0.0] * n + op = [0.0] * n + + for d in range(n): + l = rnd.random() * (intersection[1][d] - intersection[0][d] + 1) + intersection[0][d] + p[d] = (l + (rnd.random() - 0.5) * self.error) - interval[0][d] + op[d] = (l + (rnd.random() - 0.5) * self.error) - other_interval[0][d] + tmp[d] = l + + num_neighbors = 0 + if self.exclude_radius > 0: + tmp_ip = (0, np.asarray(tmp, dtype=float)) + + if search2 is not None: + neighbors = search2(tmp_ip[1], self.exclude_radius) + num_neighbors += len(neighbors) + + if num_neighbors == 0: + new_ip.append((id, p)) + other_points.append((other_id, op)) + id += 1 + other_id += 1 + + next(x for x in other_ip_list if x.get("label") == fake_label)["ip_list"]["interest_points"] = other_points - next(x for x in other_ip_list if x.get("label") == fake_label)["ip_list"]["interest_points"] = other_points - - new_ip_l = { - 'base_directory': old_ip_l1['base_path'], - 'corresponding_interest_points': None, - 'interest_points': new_ip, - 'modified_corresponding_interest_points': None, - 'modified_interest_points': None, - 'n5_path': f"interestpoints.n5/tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}", - 'xml_n5_path': f"tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}", - "parameters": old_ip_l1['parameters_fake'] - } - - new_v_ip_l.append({ - 'label': fake_label, - 'ip_list': new_ip_l - }) + new_ip_l = { + 'base_directory': old_ip_l1['base_path'], + 'corresponding_interest_points': None, + 'interest_points': new_ip, + 
'modified_corresponding_interest_points': None, + 'modified_interest_points': None, + 'n5_path': f"interestpoints.n5/tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}", + 'xml_n5_path': f"tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}", + "parameters": old_ip_l1['parameters_fake'] + } + + new_v_ip_l.append({ + 'label': fake_label, + 'ip_list': new_ip_l + }) + + if len(new_v_ip_l) > 0: + new_interest_points[new_view_id_key] = new_v_ip_l self.setup_definition.append({ 'interval': interval, @@ -484,7 +500,6 @@ def search2(q_point_global, radius=self.exclude_radius): 'old_models': transform_list }) - new_interest_points[new_view_id_key] = new_v_ip_l new_id += 1 return new_interest_points @@ -493,6 +508,9 @@ def load_interest_points(self, fake_label): full_path = self.n5_path + "interestpoints.n5" interest_points = {} + if self.view_interest_points_df.empty: + return {} + if full_path.startswith("s3://"): path = full_path.rstrip("/") s3 = s3fs.S3FileSystem(anon=False) diff --git a/Rhapso/split_dataset/xml_to_dataframe_split.py b/Rhapso/split_dataset/xml_to_dataframe_split.py index abe13aa..a4e9a23 100644 --- a/Rhapso/split_dataset/xml_to_dataframe_split.py +++ b/Rhapso/split_dataset/xml_to_dataframe_split.py @@ -17,10 +17,17 @@ def parse_image_loader_zarr(self, root): for il in root.findall(".//ImageLoader/zgroups/zgroup"): view_setup = il.get("setup") - timepoint = il.get("timepoint") - file_path = il.find("path").text if il.find("path") is not None else None - - channel = file_path.split("_ch_", 1)[1].split(".ome.zarr", 1)[0] + timepoint = il.get("tp") or il.get("timepoint") + file_path = il.get("path") + if file_path is None: + element_string = ET.tostring(il, encoding='unicode') + raise ValueError(f"zgroup element missing 'path' attribute: {element_string}") + + # default to channel 0 if not parseable from the path name + try: + channel = file_path.split("_ch_", 1)[1].split(".ome.zarr", 1)[0] + except (IndexError, AttributeError): + channel = 0 
image_loader_data.append( { diff --git a/setup.py b/setup.py index 0f075d8..42ac315 100644 --- a/setup.py +++ b/setup.py @@ -19,11 +19,11 @@ }, packages=find_packages(), install_requires=[ - 'pandas', + 'pandas==3.0.2', 'dask[array]==2024.12.1', 'zarr==2.18.3', 'scipy==1.13.1', - 'scikit-image', + 'scikit-image==0.22.0', 'bioio==1.3.0', 'bioio-tifffile==1.0.0', 'tifffile==2025.1.10', @@ -33,14 +33,14 @@ 'matplotlib==3.10.0', 'memory-profiler==0.61.0', 's3fs==2024.12.0', - 'scikit-learn', + 'scikit-learn==1.8.0', 'click==8.2.1', - 'ray[default]==2.9.1', - 'tensorstore', - 'xmltodict', - 'nptyping', - "setuptools==68.2.2" + 'ray==2.54.1', + 'tensorstore==0.1.82', + 'xmltodict==1.0.4', + 'nptyping==2.5.0', + "setuptools==71.0.4" ], python_requires='>=3.10', classifiers=[ diff --git a/tests/XML_test_data/dataset_split.xml b/tests/XML_test_data/dataset_split.xml new file mode 100644 index 0000000..9cc492f --- /dev/null +++ b/tests/XML_test_data/dataset_split.xml @@ -0,0 +1,150 @@ + + + . + + + + 0 + Tile 0 + 500 500 100 + + um + 1.0 1.0 1.0 + + + 0 + + + + 1 + Tile 1 + 500 500 100 + + um + 1.0 1.0 1.0 + + + 1 + + + + 2 + Tile 2 + 500 500 100 + + um + 1.0 1.0 1.0 + + + 2 + + + + 3 + Tile 3 + 500 500 100 + + um + 1.0 1.0 1.0 + + + 3 + + + + + 0 + 0 + + + + + + + calibration + 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 + + + Image Splitting + 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 + + + + + calibration + 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 + + + Image Splitting + 1.0 0.0 0.0 300.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 + + + + + calibration + 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 + + + Image Splitting + 1.0 0.0 0.0 0.0 0.0 1.0 0.0 300.0 0.0 0.0 1.0 0.0 + + + + + calibration + 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 + + + Image Splitting + 1.0 0.0 0.0 300.0 0.0 1.0 0.0 300.0 0.0 0.0 1.0 0.0 + + + + + + s3://test-bucket/SPIM.ome.zarr/ + + + + + + + + 0 + Source Image + 800 800 100 + + um + 1.0 1.0 1.0 + + + + + + + 0 + 0 + 0 0 0 + 499 499 
99 + + + 1 + 0 + 300 0 0 + 799 499 99 + + + 2 + 0 + 0 300 0 + 499 799 99 + + + 3 + 0 + 300 300 0 + 799 799 99 + + + + diff --git a/tests/test_data_prep/test_xml_to_dataframe.py b/tests/test_data_prep/test_xml_to_dataframe.py index 6118b84..fee6318 100644 --- a/tests/test_data_prep/test_xml_to_dataframe.py +++ b/tests/test_data_prep/test_xml_to_dataframe.py @@ -67,28 +67,26 @@ def test_parse_view_interest_points(self): xml_content = fetch_local_xml(self.xml_content_standard) self.parser = XMLToDataFrame(xml_content) root = ET.fromstring(xml_content) - df = self.parser.parse_view_interest_points(root, "data_prep") + df = self.parser.parse_view_interest_points(root) self.assertTrue(df.empty) def test_run(self): xml_content = fetch_local_xml(self.xml_content_standard) self.parser = XMLToDataFrame(xml_content) - result = self.parser.run("data_prep") + result = self.parser.run() self.assertIn("image_loader", result) self.assertIn("view_setups", result) self.assertIn("view_registrations", result) self.assertIn("view_interest_points", result) def test_interest_points_already_exist(self): + """Test that existing interest points are parsed correctly""" xml_content = fetch_local_xml(self.xml_content_interestPoints) self.parser = XMLToDataFrame(xml_content) root = ET.fromstring(xml_content) - with self.assertRaises(Exception) as context: - self.parser.parse_view_interest_points(root, "data_prep") - self.assertEqual( - str(context.exception), - "There should be no interest points in this file yet.", - ) + df = self.parser.parse_view_interest_points(root) + # Should parse existing interest points without raising an exception + self.assertIsInstance(df, pd.DataFrame) def test_no_labels(self): xml_content = fetch_local_xml(self.xml_content_no_tags) @@ -130,6 +128,58 @@ def test_no_file_mapping_exists(self): self.parser.parse_image_loader_tiff(root) self.assertEqual(str(context.exception), "There are no files in this XML") + def test_parse_image_loader_split_zarr(self): + """Test 
split zarr parsing with 4 tiles""" + xml_path = "tests/XML_test_data/dataset_split.xml" + xml_content = fetch_local_xml(xml_path) + self.parser = XMLToDataFrame(xml_content) + result = self.parser.run() + df = result['image_loader'] + + # Should have 4 rows (one per split tile) + self.assertEqual(len(df), 4) + + # Check required columns exist + expected_cols = {'view_setup', 'timepoint', 'crop_min', 'crop_max', 'zarr_base_path', 'file_path'} + self.assertTrue(expected_cols.issubset(set(df.columns))) + + # Check values for each tile + self.assertEqual(df.iloc[0]['view_setup'], '0') + self.assertEqual(df.iloc[0]['crop_min'], '0 0 0') + self.assertEqual(df.iloc[0]['crop_max'], '499 499 99') + self.assertEqual(df.iloc[0]['zarr_base_path'], 's3://test-bucket/SPIM.ome.zarr/') + + self.assertEqual(df.iloc[1]['view_setup'], '1') + self.assertEqual(df.iloc[1]['crop_min'], '300 0 0') + self.assertEqual(df.iloc[1]['crop_max'], '799 499 99') + + self.assertEqual(df.iloc[2]['view_setup'], '2') + self.assertEqual(df.iloc[2]['crop_min'], '0 300 0') + self.assertEqual(df.iloc[2]['crop_max'], '499 799 99') + + self.assertEqual(df.iloc[3]['view_setup'], '3') + self.assertEqual(df.iloc[3]['crop_min'], '300 300 0') + self.assertEqual(df.iloc[3]['crop_max'], '799 799 99') + + # All rows should have same file_path and timepoint + self.assertEqual(df.iloc[0]['file_path'], df.iloc[3]['file_path']) + self.assertEqual(df.iloc[0]['timepoint'], df.iloc[3]['timepoint']) + + def test_parse_view_setups_split(self): + """Test that outer ViewSetups are parsed correctly for split XML""" + xml_path = "tests/XML_test_data/dataset_split.xml" + xml_content = fetch_local_xml(xml_path) + self.parser = XMLToDataFrame(xml_content) + root = ET.fromstring(xml_content) + df = self.parser.parse_view_setups(root) + + # Should have 4 rows (outer ViewSetups only, not the inner one) + self.assertEqual(len(df), 4) + + # Check IDs + ids = sorted([df.iloc[i]['id'] for i in range(len(df))]) + self.assertEqual(ids, 
['0', '1', '2', '3']) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_detection/test_difference_of_gaussian.py b/tests/test_detection/test_difference_of_gaussian.py index 0d1ee8b..c4880e1 100644 --- a/tests/test_detection/test_difference_of_gaussian.py +++ b/tests/test_detection/test_difference_of_gaussian.py @@ -13,7 +13,8 @@ class TestDifferenceOfGaussian(unittest.TestCase): def setUp(self): self.dog = DifferenceOfGaussian( - min_intensity=0, max_intensity=255, sigma=1.0, threshold=0.5 + min_intensity=0, max_intensity=255, sigma=1.0, threshold=0.5, + median_filter=False, mip_map_downsample=np.eye(4), ) self.image = np.random.rand(10, 10, 10) * 255 @@ -58,9 +59,70 @@ def test_upsample_coordinates(self): [856.6701082186948, 416.01488311517676, 3.4227515981883694], ] ) - upsampled_points = self.dog.upsample_coordinates(points, 2, 1) + upsampled_points = self.dog.upsample_coordinates(points) self.assertEqual(len(upsampled_points), len(points)) + def test_lower_bounds_added_after_upsample(self): + """Regression: apply_lower_bounds must happen AFTER + upsample_coordinates so that lb (L0 coords) is added to + L0-upsampled peaks, not to downsampled-space peaks that then + get multiplied by the pyramid scale factor. + + Prior bug: lb was added before upsample, inflating the lb + component by the scale factor and producing Z-banded IPs + when chunks_per_bound > 1. 
+ """ + scale = 16 + half_pixel = 0.5 * (scale - 1) + mip_map_downsample = np.array([ + [scale, 0, 0, half_pixel], + [0, scale, 0, half_pixel], + [0, 0, scale, half_pixel], + [0, 0, 0, 1], + ], dtype=float) + + dog = DifferenceOfGaussian( + min_intensity=0, max_intensity=255, sigma=1.0, + threshold=0.5, median_filter=False, + mip_map_downsample=mip_map_downsample, + ) + + # Synthetic peak at downsampled coord (5, 3, 2) + peaks = np.array([[5.0, 3.0, 2.0]], dtype=np.float32) + lb = [0, 0, 99] # L0 coords — chunk starts at Z=99 + offset = 0 + + # Correct transform order: upsample then add lb + result = dog.upsample_coordinates(peaks) + result = dog.apply_lower_bounds(result, lb) + result = dog.apply_offset(result, offset) + + expected_x = 5.0 * scale + half_pixel + 0 + expected_y = 3.0 * scale + half_pixel + 0 + expected_z = 2.0 * scale + half_pixel + 99 + + np.testing.assert_allclose( + result[0], + [expected_x, expected_y, expected_z], + atol=1e-4, + err_msg="lb must be added AFTER upsample to avoid inflating " + "the offset by the pyramid scale factor", + ) + + # Verify the OLD (buggy) order would give a wrong answer + buggy = dog.apply_lower_bounds(peaks.copy(), lb) + buggy = dog.apply_offset(buggy, offset) + buggy = dog.upsample_coordinates(buggy) + buggy_z = float(buggy[0, 2]) + correct_z = float(result[0, 2]) + # Buggy Z = (2 + 99) * 16 + 7.5 = 1623.5 vs correct 138.5 + self.assertNotAlmostEqual( + buggy_z, correct_z, places=1, + msg="Old transform order must differ from correct order " + "when lb has non-zero Z", + ) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_detection/test_image_reader.py b/tests/test_detection/test_image_reader.py new file mode 100644 index 0000000..e0c10cc --- /dev/null +++ b/tests/test_detection/test_image_reader.py @@ -0,0 +1,185 @@ +import unittest +from unittest.mock import patch, MagicMock +import dask.array as da +import numpy as np + +from Rhapso.detection.image_reader import ImageReader + + +class 
TestImageReader(unittest.TestCase): + def test_fetch_image_data_crop_applied_before_downsampling(self): + """Test that crop is applied after transpose, before downsampling""" + reader = ImageReader(file_type='zarr') + + # Create a mock record with crop bounds + record = { + 'view_id': 'timepoint: 0, setup: 0', + 'file_path': 's3://bucket/test.zarr/0', + 'interval_key': ((0, 0, 0), (50, 50, 25), (51, 51, 26)), + 'offset': 0, + 'lb': (0, 0, 0), + 'crop_min': [2, 2, 2], + 'crop_max': [7, 7, 7] + } + + # Mock the zarr opening to return a known dask array (10x10x10) + mock_array = da.ones((1, 1, 10, 10, 10), dtype=np.float32) + + with patch('zarr.open') as mock_zarr, \ + patch('s3fs.S3FileSystem'), \ + patch('s3fs.S3Map'), \ + patch('dask.array.from_zarr', return_value=mock_array): + + # Call fetch_image_data + view_id, interval_key, chunk, offset, lower_bound = reader.fetch_image_data( + record, dsxy=1, dsz=1 + ) + + # Verify crop was applied: array should be [2:8, 2:8, 2:8] = 6x6x6 + self.assertEqual(chunk.shape, (6, 6, 6)) + self.assertEqual(view_id, 'timepoint: 0, setup: 0') + + def test_fetch_image_data_without_crop(self): + """Test backward compatibility: records without crop fields work normally""" + reader = ImageReader(file_type='zarr') + + # Record without crop fields + record = { + 'view_id': 'timepoint: 0, setup: 0', + 'file_path': 's3://bucket/test.zarr/0', + 'interval_key': ((0, 0, 0), (50, 50, 25), (51, 51, 26)), + 'offset': 0, + 'lb': (0, 0, 0), + 'crop_min': None, + 'crop_max': None + } + + # Mock the zarr opening to return a known dask array + mock_array = da.ones((1, 1, 10, 10, 10), dtype=np.float32) + + with patch('zarr.open') as mock_zarr, \ + patch('s3fs.S3FileSystem'), \ + patch('s3fs.S3Map'), \ + patch('dask.array.from_zarr', return_value=mock_array): + + # Call fetch_image_data - should not raise an error + view_id, interval_key, chunk, offset, lower_bound = reader.fetch_image_data( + record, dsxy=1, dsz=1 + ) + + # Should succeed without crop + 
self.assertEqual(view_id, 'timepoint: 0, setup: 0') + + def test_fetch_image_data_tiff_no_crop_error(self): + """Test that tiff mode without crop works (no changes to tiff path)""" + reader = ImageReader(file_type='tiff') + + record = { + 'view_id': 'timepoint: 0, setup: 0', + 'file_path': '/path/to/test.tif', + 'interval_key': ((0, 0, 0), (50, 50, 25), (51, 51, 26)), + 'offset': 0, + 'lb': (0, 0, 0), + 'crop_min': None, + 'crop_max': None + } + + # Mock the BioImage reader + mock_bioimage = MagicMock() + mock_dask_array = da.ones((1, 1, 1, 10, 10, 10), dtype=np.float32) + mock_bioimage.get_dask_stack.return_value = mock_dask_array + + with patch('Rhapso.detection.image_reader.CustomBioImage', return_value=mock_bioimage): + # Should not raise an error + view_id, interval_key, chunk, offset, lower_bound = reader.fetch_image_data( + record, dsxy=1, dsz=1 + ) + + self.assertEqual(view_id, 'timepoint: 0, setup: 0') + + def test_fetch_image_data_crop_bounds_validation(self): + """Test that crop bounds exceeding array dimensions raise clear error""" + reader = ImageReader(file_type='zarr') + + # Record with crop_max exceeding array dimensions + record = { + 'view_id': 'timepoint: 0, setup: 0', + 'file_path': 's3://bucket/test.zarr/0', + 'interval_key': ((0, 0, 0), (50, 50, 25), (51, 51, 26)), + 'offset': 0, + 'lb': (0, 0, 0), + 'crop_min': [0, 0, 0], + 'crop_max': [15, 5, 5] # Exceeds dimension 0 (10x10x10 array) + } + + # Mock the zarr opening to return a known dask array (10x10x10) + mock_array = da.ones((1, 1, 10, 10, 10), dtype=np.float32) + + with patch('zarr.open') as mock_zarr, \ + patch('s3fs.S3FileSystem'), \ + patch('s3fs.S3Map'), \ + patch('dask.array.from_zarr', return_value=mock_array): + + # Should raise ValueError with clear message + with self.assertRaises(ValueError) as context: + reader.fetch_image_data(record, dsxy=1, dsz=1) + + error_msg = str(context.exception) + self.assertIn('crop_max[0]=15 exceeds array dimension 0', error_msg) + 
self.assertIn('(shape=10)', error_msg) + + def test_fetch_image_data_negative_crop_min(self): + """Test that negative crop_min values raise clear error""" + reader = ImageReader(file_type='zarr') + + record = { + 'view_id': 'timepoint: 0, setup: 0', + 'file_path': 's3://bucket/test.zarr/0', + 'interval_key': ((0, 0, 0), (50, 50, 25), (51, 51, 26)), + 'offset': 0, + 'lb': (0, 0, 0), + 'crop_min': [-1, 0, 0], + 'crop_max': [5, 5, 5] + } + + mock_array = da.ones((1, 1, 10, 10, 10), dtype=np.float32) + + with patch('zarr.open') as mock_zarr, \ + patch('s3fs.S3FileSystem'), \ + patch('s3fs.S3Map'), \ + patch('dask.array.from_zarr', return_value=mock_array): + + with self.assertRaises(ValueError) as context: + reader.fetch_image_data(record, dsxy=1, dsz=1) + + self.assertIn('crop_min[0]=-1 is negative', str(context.exception)) + + def test_fetch_image_data_crop_min_greater_than_crop_max(self): + """Test that crop_min > crop_max raises clear error""" + reader = ImageReader(file_type='zarr') + + record = { + 'view_id': 'timepoint: 0, setup: 0', + 'file_path': 's3://bucket/test.zarr/0', + 'interval_key': ((0, 0, 0), (50, 50, 25), (51, 51, 26)), + 'offset': 0, + 'lb': (0, 0, 0), + 'crop_min': [5, 0, 0], + 'crop_max': [3, 5, 5] # crop_min[0] > crop_max[0] + } + + mock_array = da.ones((1, 1, 10, 10, 10), dtype=np.float32) + + with patch('zarr.open') as mock_zarr, \ + patch('s3fs.S3FileSystem'), \ + patch('s3fs.S3Map'), \ + patch('dask.array.from_zarr', return_value=mock_array): + + with self.assertRaises(ValueError) as context: + reader.fetch_image_data(record, dsxy=1, dsz=1) + + self.assertIn('crop_min[0]=5 > crop_max[0]=3', str(context.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_detection/test_metadata_builder.py b/tests/test_detection/test_metadata_builder.py new file mode 100644 index 0000000..39ff8a2 --- /dev/null +++ b/tests/test_detection/test_metadata_builder.py @@ -0,0 +1,237 @@ +import unittest +import pandas as pd +import 
numpy as np + +from Rhapso.detection.metadata_builder import MetadataBuilder + + +class TestMetadataBuilder(unittest.TestCase): + def test_build_paths_split_uses_zarr_base_path(self): + """Test that split mode uses zarr_base_path for file path construction""" + # Create mock image_loader_df with split columns + image_loader_df = pd.DataFrame({ + 'view_setup': ['0', '1'], + 'timepoint': ['0', '0'], + 'file_path': ['Tile_X_0000_Y_0000_Z_0000_ch_405.zarr', 'Tile_X_0000_Y_0000_Z_0000_ch_405.zarr'], + 'crop_min': ['0 0 0', '300 0 0'], + 'crop_max': ['499 499 99', '799 499 99'], + 'zarr_base_path': ['s3://bucket/SPIM.ome.zarr/', 's3://bucket/SPIM.ome.zarr/'] + }) + + # Mock overlapping_area + overlapping_area = { + 'timepoint: 0, setup: 0': [{'lower_bound': np.array([0, 0, 0]), 'upper_bound': np.array([100, 100, 50])}], + 'timepoint: 0, setup: 1': [{'lower_bound': np.array([50, 0, 0]), 'upper_bound': np.array([150, 100, 50])}] + } + + dataframes = {'image_loader': image_loader_df} + builder = MetadataBuilder( + dataframes=dataframes, + overlapping_area=overlapping_area, + image_file_prefix='s3://bucket/SPIM.ome.zarr/', + file_type='zarr', + dsxy=1.0, + dsz=1.0, + chunks_per_bound=1, + sigma=1.0, + run_type='ray', + level=0 + ) + builder.build_paths() + + # Check that file_path uses zarr_base_path + self.assertTrue( + 'zarr' in builder.metadata[0]['file_path'], + f"File path should contain zarr path: {builder.metadata[0]['file_path']}" + ) + + def test_build_paths_split_passes_crop_bounds(self): + """Test that crop bounds are included in metadata records""" + image_loader_df = pd.DataFrame({ + 'view_setup': ['0'], + 'timepoint': ['0'], + 'file_path': ['Tile_X_0000_Y_0000_Z_0000_ch_405.zarr'], + 'crop_min': ['0 0 0'], + 'crop_max': ['499 499 99'], + 'zarr_base_path': ['s3://bucket/SPIM.ome.zarr/'] + }) + + overlapping_area = { + 'timepoint: 0, setup: 0': [{'lower_bound': np.array([0, 0, 0]), 'upper_bound': np.array([100, 100, 50])}] + } + + dataframes = {'image_loader': 
image_loader_df} + builder = MetadataBuilder( + dataframes=dataframes, + overlapping_area=overlapping_area, + image_file_prefix='s3://bucket/SPIM.ome.zarr/', + file_type='zarr', + dsxy=1.0, + dsz=1.0, + chunks_per_bound=0, # No chunking + sigma=1.0, + run_type='ray', + level=0 + ) + builder.build_paths() + + # Check that crop_min and crop_max are in metadata + self.assertIn('crop_min', builder.metadata[0]) + self.assertIn('crop_max', builder.metadata[0]) + self.assertEqual(builder.metadata[0]['crop_min'], [0, 0, 0]) + self.assertEqual(builder.metadata[0]['crop_max'], [499, 499, 99]) + + def test_build_paths_split_scales_crop_bounds_by_level(self): + """Test that crop bounds are scaled by 2^level""" + image_loader_df = pd.DataFrame({ + 'view_setup': ['0'], + 'timepoint': ['0'], + 'file_path': ['Tile_X_0000_Y_0000_Z_0000_ch_405.zarr'], + 'crop_min': ['300 0 0'], + 'crop_max': ['799 499 99'], + 'zarr_base_path': ['s3://bucket/SPIM.ome.zarr/'] + }) + + overlapping_area = { + 'timepoint: 0, setup: 0': [{'lower_bound': np.array([0, 0, 0]), 'upper_bound': np.array([100, 100, 50])}] + } + + dataframes = {'image_loader': image_loader_df} + # level=2 means scale by 2^2 = 4 + builder = MetadataBuilder( + dataframes=dataframes, + overlapping_area=overlapping_area, + image_file_prefix='s3://bucket/SPIM.ome.zarr/', + file_type='zarr', + dsxy=1.0, + dsz=1.0, + chunks_per_bound=0, + sigma=1.0, + run_type='ray', + level=2 + ) + builder.build_paths() + + # 300 // 4 = 75, 799 // 4 = 199, etc. + self.assertEqual(builder.metadata[0]['crop_min'], [75, 0, 0]) + self.assertEqual(builder.metadata[0]['crop_max'], [199, 124, 24]) + + def test_chunked_metadata_no_z_band_double_add(self): + """Regression: with chunks_per_bound>1, applying lb + offset to + chunk-local peaks must equal the absolute parent-frame coord — + i.e. no double-add of the chunk shift. 
Prior bug: lb=parent_lb + and offset=z together added the chunk shift twice, producing + Z-banded IPs (one missing band per chunk after the first). + """ + image_loader_df = pd.DataFrame({ + 'view_setup': ['0'], + 'timepoint': ['0'], + 'file_path': ['Tile_X_0000_Y_0000_Z_0000_ch_405.zarr'], + 'crop_min': ['0 0 0'], + 'crop_max': ['99 99 999'], + 'zarr_base_path': ['s3://bucket/SPIM.ome.zarr/'] + }) + # Single un-split parent region spanning Z=0..999 (XYZ-ordered + # bounds match what OverlapDetection emits). + overlapping_area = { + 'timepoint: 0, setup: 0': [{ + 'lower_bound': np.array([0, 0, 0]), + 'upper_bound': np.array([99, 99, 999]), + }], + } + dataframes = {'image_loader': image_loader_df} + builder = MetadataBuilder( + dataframes=dataframes, + overlapping_area=overlapping_area, + image_file_prefix='s3://bucket/SPIM.ome.zarr/', + file_type='zarr', + dsxy=1.0, + dsz=1.0, + chunks_per_bound=4, + sigma=1.0, + run_type='ray', + level=0, + ) + builder.build_paths() + + # Expect 4 chunked metadata records. + self.assertEqual(len(builder.metadata), 4) + + # For each chunk, verify the contract that produces correct + # global coords: peaks_global = peak_local + lb + offset. + # With our fix, lb already encodes the parent-frame chunk-start + # and offset is 0, so a peak at local (0,0,0) should map to the + # chunk's expected global Z (within `overlap` tolerance for the + # halo expansion). + prev_chunk_z_start = -1 + for entry in builder.metadata: + actual_lb = entry['lb'] + offset = entry['offset'] + interval_lb = entry['interval_key'][0] + + # offset must be 0 — the chunk shift is already in `lb`. + self.assertEqual( + offset, 0, + f"offset should be 0 (chunk-shift lives in lb); got {offset}", + ) + # `lb` must equal the chunk's interval lower bound. Same + # tuple — no separate parent-vs-chunk frames. + self.assertEqual(tuple(actual_lb), tuple(interval_lb)) + + # Chunk Z-starts must be monotonically increasing (sanity). 
+ chunk_z_start = actual_lb[2] + self.assertGreater(chunk_z_start, prev_chunk_z_start) + prev_chunk_z_start = chunk_z_start + + # Simulate the DoG recombine step: upsample first (identity at + # level 0), then add lb (L0 coords), then add offset (0). + # A synthetic peak at chunk-local (0,0,0) in chunk i should + # land at global Z = interval_key[0][2]. + for entry in builder.metadata: + local_peak_xyz = np.array([[0.0, 0.0, 0.0]], dtype=np.float32) + # upsample_coordinates is identity at level 0 + upsampled_peak = local_peak_xyz.copy() + lb_xyz = np.array(entry['lb'], dtype=np.float32) + global_peak = upsampled_peak + lb_xyz # apply_lower_bounds + global_peak[:, 2] += entry['offset'] # apply_offset (now 0) + self.assertEqual( + float(global_peak[0, 2]), + float(entry['interval_key'][0][2]), + "peak at chunk-local (0,0,0) must map to chunk's " + "parent-frame Z without double-adding the offset", + ) + + def test_build_paths_regular_zarr_no_crop(self): + """Test backward compatibility: regular zarr has no crop fields""" + image_loader_df = pd.DataFrame({ + 'view_setup': ['0'], + 'timepoint': ['0'], + 'file_path': ['test.zarr'] + }) + + overlapping_area = { + 'timepoint: 0, setup: 0': [{'lower_bound': np.array([0, 0, 0]), 'upper_bound': np.array([100, 100, 50])}] + } + + dataframes = {'image_loader': image_loader_df} + builder = MetadataBuilder( + dataframes=dataframes, + overlapping_area=overlapping_area, + image_file_prefix='s3://bucket/', + file_type='zarr', + dsxy=1.0, + dsz=1.0, + chunks_per_bound=0, + sigma=1.0, + run_type='ray', + level=0 + ) + builder.build_paths() + + # Regular zarr should have None for crop fields + self.assertIsNone(builder.metadata[0]['crop_min']) + self.assertIsNone(builder.metadata[0]['crop_max']) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_detection/test_overlap_detection.py b/tests/test_detection/test_overlap_detection.py index e4f9e88..f6b7ad1 100644 --- a/tests/test_detection/test_overlap_detection.py 
+++ b/tests/test_detection/test_overlap_detection.py @@ -1,5 +1,4 @@ import unittest -<<<<<<< HEAD import numpy as np import pandas as pd @@ -55,14 +54,6 @@ def test_find_overlapping_area_empty_dataframe(self): with self.assertRaises(ValueError) as context: self.od.find_overlapping_area() self.assertEqual(str(context.exception), "Image Loader dataframe is empty.") -======= - - -class TestOverlapDetecttion(unittest.TestCase): - - def setUp(self): - pass ->>>>>>> main if __name__ == "__main__": diff --git a/tests/test_evaluaton/test_alignment_threshold.py b/tests/test_evaluation/test_alignment_threshold.py similarity index 100% rename from tests/test_evaluaton/test_alignment_threshold.py rename to tests/test_evaluation/test_alignment_threshold.py diff --git a/tests/test_evaluation/test_detection_qc.py b/tests/test_evaluation/test_detection_qc.py new file mode 100644 index 0000000..df70c5b --- /dev/null +++ b/tests/test_evaluation/test_detection_qc.py @@ -0,0 +1,360 @@ +"""Tests for Rhapso.evaluation.detection_qc module.""" +import json +import os +import tempfile +from pathlib import Path + +import numpy as np +import pytest +import zarr + +from Rhapso.evaluation.detection_qc.view_metrics import ( + ViewIPMetrics, + compute_view_metrics, + compute_all_view_metrics, + _get_ip_count_from_attributes, +) +from Rhapso.evaluation.detection_qc.sweep_analyzer import ( + SweepTrialResult, + SweepAnalyzer, +) +from Rhapso.evaluation.detection_qc.plotting import ( + generate_all_plots, + plot_sweep_ip_counts, + plot_sweep_success_rates, +) + + +@pytest.fixture +def mock_n5_dir(tmp_path): + """Create a mock N5 directory structure with IP detection outputs.""" + views = { + "tpId_0_viewSetupId_0": { + "loc": np.array([[10.0, 20.0, 5.0], [30.0, 40.0, 10.0], [50.0, 60.0, 15.0]], dtype=np.float64), + "intensities": np.array([100.0, 200.0, 300.0], dtype=np.float32), + }, + "tpId_0_viewSetupId_1": { + "loc": np.array([[5.0, 10.0, 2.0], [15.0, 25.0, 8.0]], dtype=np.float64), + 
"intensities": np.array([150.0, 250.0], dtype=np.float32), + }, + } + + n5_dir = tmp_path / "interestpoints.n5" + n5_store = zarr.N5Store(str(n5_dir)) + root = zarr.open(n5_store, mode="w") + + for view_name, data in views.items(): + loc = data["loc"] + intensities = data["intensities"] + ip_count = len(loc) + + group_path = f"{view_name}/beads/interestpoints" + root.create_dataset(f"{group_path}/loc", data=loc, chunks=loc.shape) + root.create_dataset(f"{group_path}/intensities", data=intensities, chunks=intensities.shape) + root.create_dataset(f"{group_path}/id", data=np.arange(ip_count, dtype=np.uint64), chunks=(ip_count,)) + + # Write attributes.json for the id dataset (fast count path) + id_dir = n5_dir / view_name / "beads" / "interestpoints" / "id" + id_dir.mkdir(parents=True, exist_ok=True) + attrs_path = id_dir / "attributes.json" + with open(attrs_path, "w") as f: + json.dump({"dimensions": [3, ip_count]}, f) + + return tmp_path + + +@pytest.fixture +def mock_empty_n5_dir(tmp_path): + """Create a mock N5 directory with no views.""" + (tmp_path / "interestpoints.n5").mkdir() + return tmp_path + + +@pytest.fixture +def sample_view_metrics(): + """Sample ViewIPMetrics for testing.""" + return [ + ViewIPMetrics( + view_id="tpId_0_viewSetupId_0", + ip_count=500, + spatial_extent_xyz=(100.0, 80.0, 30.0), + spatial_std_xyz=(25.0, 20.0, 8.0), + density=0.002, + intensity_mean=200.0, + intensity_std=50.0, + intensity_min=50.0, + intensity_max=400.0, + meets_target=True, + ), + ViewIPMetrics( + view_id="tpId_0_viewSetupId_1", + ip_count=100, + spatial_extent_xyz=(60.0, 50.0, 20.0), + spatial_std_xyz=(15.0, 12.0, 5.0), + density=0.0017, + intensity_mean=180.0, + intensity_std=40.0, + intensity_min=60.0, + intensity_max=350.0, + meets_target=False, + ), + ] + + +@pytest.fixture +def sample_trials(sample_view_metrics): + """Sample sweep trials for testing.""" + return [ + SweepTrialResult( + multiscale="5", + sigma=4.0, + trial_index=0, + success=False, + 
n5_output_path="/scratch/trial_0", + view_metrics=[ + ViewIPMetrics( + view_id="v0", ip_count=50, + spatial_extent_xyz=(10.0, 10.0, 5.0), + spatial_std_xyz=(3.0, 3.0, 1.5), + density=0.1, intensity_mean=100.0, intensity_std=20.0, + intensity_min=50.0, intensity_max=200.0, meets_target=False, + ), + ], + ), + SweepTrialResult( + multiscale="4", + sigma=3.0, + trial_index=1, + success=False, + n5_output_path="/scratch/trial_1", + view_metrics=[ + ViewIPMetrics( + view_id="v0", ip_count=200, + spatial_extent_xyz=(40.0, 35.0, 15.0), + spatial_std_xyz=(10.0, 9.0, 4.0), + density=0.01, intensity_mean=150.0, intensity_std=30.0, + intensity_min=60.0, intensity_max=300.0, meets_target=False, + ), + ], + ), + SweepTrialResult( + multiscale="3", + sigma=2.5, + trial_index=2, + success=True, + n5_output_path="/scratch/trial_2", + view_metrics=sample_view_metrics, + ), + ] + + +# --- ViewIPMetrics tests --- + + +class TestViewIPMetrics: + def test_to_dict(self, sample_view_metrics): + d = sample_view_metrics[0].to_dict() + assert d["view_id"] == "tpId_0_viewSetupId_0" + assert d["ip_count"] == 500 + assert d["meets_target"] is True + assert isinstance(d["spatial_extent_xyz"], list) + + def test_to_metric_list(self, sample_view_metrics): + metrics = sample_view_metrics[0].to_metric_list() + assert isinstance(metrics, list) + assert len(metrics) == 13 + names = {m["name"] for m in metrics} + assert "ip_count" in names + assert "density" in names + assert "meets_target" in names + for m in metrics: + assert "name" in m + assert "value" in m + assert "description" in m + assert "view_id" in m + + def test_frozen(self, sample_view_metrics): + with pytest.raises(AttributeError): + sample_view_metrics[0].ip_count = 999 + + +class TestGetIPCountFromAttributes: + def test_valid_attributes(self, tmp_path): + attrs_path = tmp_path / "attributes.json" + with open(attrs_path, "w") as f: + json.dump({"dimensions": [3, 42]}, f) + assert _get_ip_count_from_attributes(attrs_path) == 42 + + def 
test_missing_file(self, tmp_path): + result = _get_ip_count_from_attributes(tmp_path / "nonexistent.json") + assert result is None + + def test_malformed_json(self, tmp_path): + attrs_path = tmp_path / "attributes.json" + attrs_path.write_text("{bad json") + assert _get_ip_count_from_attributes(attrs_path) is None + + +class TestComputeAllViewMetrics: + def test_discovers_views(self, mock_n5_dir): + metrics = compute_all_view_metrics( + str(mock_n5_dir), target_interest_points=2, compute_spatial=False, + ) + assert len(metrics) == 2 + view_ids = {m.view_id for m in metrics} + assert "tpId_0_viewSetupId_0" in view_ids + assert "tpId_0_viewSetupId_1" in view_ids + + def test_meets_target_flag(self, mock_n5_dir): + metrics = compute_all_view_metrics( + str(mock_n5_dir), target_interest_points=3, compute_spatial=False, + ) + counts = {m.view_id: m.ip_count for m in metrics} + targets = {m.view_id: m.meets_target for m in metrics} + assert targets["tpId_0_viewSetupId_0"] is True # 3 IPs == 3 target + assert targets["tpId_0_viewSetupId_1"] is False # 2 IPs < 3 target + + def test_empty_dir(self, mock_empty_n5_dir): + metrics = compute_all_view_metrics( + str(mock_empty_n5_dir), target_interest_points=10, + ) + assert metrics == [] + + +# --- SweepTrialResult tests --- + + +class TestSweepTrialResult: + def test_properties(self, sample_trials): + success_trial = sample_trials[2] + assert success_trial.total_ip_count == 600 + assert success_trial.views_meeting_target == 1 + assert success_trial.num_views == 2 + assert success_trial.success_rate == 0.5 + assert success_trial.mean_ip_count == 300.0 + + def test_empty_trial(self): + trial = SweepTrialResult( + multiscale="5", sigma=4.0, trial_index=0, + success=False, n5_output_path="/scratch/empty", + ) + assert trial.total_ip_count == 0 + assert trial.success_rate == 0.0 + assert trial.mean_ip_count == 0.0 + assert trial.mean_density == 0.0 + + def test_to_dict(self, sample_trials): + d = sample_trials[0].to_dict() + assert 
d["multiscale"] == "5" + assert d["sigma"] == 4.0 + assert "metrics" in d + assert "view_metrics" in d + + def test_to_metric_list(self, sample_trials): + metrics = sample_trials[2].to_metric_list() + assert len(metrics) == 9 + names = {m["name"] for m in metrics} + assert "success_rate" in names + assert "total_ip_count" in names + + +# --- SweepAnalyzer tests --- + + +class TestSweepAnalyzer: + def test_get_selected_trial(self, sample_trials): + analyzer = SweepAnalyzer(sample_trials, target_interest_points=300) + selected = analyzer.get_selected_trial() + assert selected is not None + assert selected.trial_index == 2 + assert selected.multiscale == "3" + + def test_no_success(self): + trials = [ + SweepTrialResult( + multiscale="5", sigma=4.0, trial_index=0, + success=False, n5_output_path="/scratch/0", + ), + ] + analyzer = SweepAnalyzer(trials, target_interest_points=500) + assert analyzer.get_selected_trial() is None + + def test_get_summary_structure(self, sample_trials): + analyzer = SweepAnalyzer(sample_trials, target_interest_points=300) + summary = analyzer.get_summary() + + assert "summary_metrics" in summary + assert "trials" in summary + assert len(summary["trials"]) == 3 + + metrics = summary["summary_metrics"] + names = {m["name"]: m for m in metrics} + assert names["num_trials_attempted"]["value"] == 3 + assert names["num_trials_succeeded"]["value"] == 1 + assert names["first_success_trial_index"]["value"] == 2 + assert names["selected_multiscale"]["value"] == "3" + assert names["selected_sigma"]["value"] == 2.5 + + def test_get_summary_all_failed(self): + trials = [ + SweepTrialResult( + multiscale="5", sigma=4.0, trial_index=0, + success=False, n5_output_path="/scratch/0", + ), + ] + analyzer = SweepAnalyzer(trials, target_interest_points=500) + summary = analyzer.get_summary() + names = {m["name"]: m for m in summary["summary_metrics"]} + assert names["selected_multiscale"]["value"] is None + assert names["selected_sigma"]["value"] is None + 
assert names["first_success_trial_index"]["value"] is None + + def test_get_all_view_metrics_flat(self, sample_trials): + analyzer = SweepAnalyzer(sample_trials, target_interest_points=300) + flat = analyzer.get_all_view_metrics_flat() + assert len(flat) == 4 # 1 + 1 + 2 views + assert all("trial_multiscale" in entry for entry in flat) + assert all("trial_sigma" in entry for entry in flat) + + def test_summary_is_json_serializable(self, sample_trials): + analyzer = SweepAnalyzer(sample_trials, target_interest_points=300) + summary = analyzer.get_summary() + # Should not raise + json_str = json.dumps(summary) + assert isinstance(json_str, str) + + +# --- Plotting tests --- + + +class TestPlotting: + def test_generate_all_plots(self, sample_trials, tmp_path): + paths = generate_all_plots( + trials=sample_trials, + target_interest_points=300, + output_dir=str(tmp_path / "plots"), + ) + assert len(paths) >= 2 + for p in paths: + assert os.path.exists(p) + assert p.endswith(".png") + + def test_plot_sweep_ip_counts(self, sample_trials, tmp_path): + path = plot_sweep_ip_counts( + sample_trials, target_interest_points=300, + output_dir=str(tmp_path), + ) + assert os.path.exists(path) + + def test_plot_sweep_success_rates(self, sample_trials, tmp_path): + path = plot_sweep_success_rates( + sample_trials, output_dir=str(tmp_path), + ) + assert os.path.exists(path) + + def test_empty_trials(self, tmp_path): + paths = generate_all_plots( + trials=[], target_interest_points=300, + output_dir=str(tmp_path / "plots"), + ) + assert paths == [] diff --git a/tests/test_evaluaton/test_kde.py b/tests/test_evaluation/test_kde.py similarity index 100% rename from tests/test_evaluaton/test_kde.py rename to tests/test_evaluation/test_kde.py diff --git a/tests/test_evaluaton/test_matching_description_stats.py b/tests/test_evaluation/test_matching_description_stats.py similarity index 100% rename from tests/test_evaluaton/test_matching_description_stats.py rename to 
tests/test_evaluation/test_matching_description_stats.py diff --git a/tests/test_evaluaton/test_threshold.py b/tests/test_evaluation/test_threshold.py similarity index 100% rename from tests/test_evaluaton/test_threshold.py rename to tests/test_evaluation/test_threshold.py diff --git a/tests/test_evaluaton/test_voxel_vis.py b/tests/test_evaluation/test_voxel_vis.py similarity index 100% rename from tests/test_evaluaton/test_voxel_vis.py rename to tests/test_evaluation/test_voxel_vis.py diff --git a/tests/test_evaluaton/test_voxelization.py b/tests/test_evaluation/test_voxelization.py similarity index 100% rename from tests/test_evaluaton/test_voxelization.py rename to tests/test_evaluation/test_voxelization.py diff --git a/tests/test_matching/test_load_and_transform_points.py b/tests/test_matching/test_load_and_transform_points.py new file mode 100644 index 0000000..adf063a --- /dev/null +++ b/tests/test_matching/test_load_and_transform_points.py @@ -0,0 +1,120 @@ +import unittest + +import numpy as np + +from Rhapso.matching.load_and_transform_points import ( + SPLIT_TILE_TRANSFORM_NAME, + LoadAndTransformPoints, +) + + +def _translation_transform(name, tx, ty, tz): + """Build a view_registrations entry (name + 3x4 affine string) for a + pure translation (tx, ty, tz).""" + affine = f"1.0 0.0 0.0 {tx} 0.0 1.0 0.0 {ty} 0.0 0.0 1.0 {tz}" + return {"type": "affine", "name": name, "affine": affine} + + +def _scale_transform(name, sx, sy, sz): + affine = f"{sx} 0.0 0.0 0.0 0.0 {sy} 0.0 0.0 0.0 0.0 {sz} 0.0" + return {"type": "affine", "name": name, "affine": affine} + + +class TestImageSplittingSkip(unittest.TestCase): + """The 'Image Splitting' ViewTransform must NOT be composed by + ``get_transformation_matrix``. Detection has already baked the + split-tile's world translation into the stored N5 IP coords. 
+ """ + + def setUp(self): + self.loader = LoadAndTransformPoints( + data_global={}, + xml_input_path="/dev/null", + n5_output_path="", + match_type="split_affine", + ) + + def test_image_splitting_transform_is_skipped(self): + view_id = (0, 1) + view_registrations = { + view_id: [ + _scale_transform("calibration", 1.0, 1.0, 3.866), + _translation_transform( + SPLIT_TILE_TRANSFORM_NAME, 384.0, 0.0, 0.0 + ), + ] + } + M = self.loader.get_transformation_matrix(view_id, view_registrations) + + # Only calibration should apply; translation from Image Splitting + # must be skipped. Check translation column = 0. + np.testing.assert_array_almost_equal(M[:3, 3], [0.0, 0.0, 0.0]) + # Scale column reflects calibration. + np.testing.assert_array_almost_equal( + np.diag(M[:3, :3]), [1.0, 1.0, 3.866] + ) + + def test_other_named_translation_still_applied(self): + """Non-'Image Splitting' transforms must still compose. This + guards against over-broad filtering.""" + view_id = (0, 1) + view_registrations = { + view_id: [ + _scale_transform("calibration", 1.0, 1.0, 3.866), + _translation_transform("Stitching Solver", 10.0, 20.0, 30.0), + ] + } + M = self.loader.get_transformation_matrix(view_id, view_registrations) + np.testing.assert_array_almost_equal( + M[:3, 3], [10.0, 20.0, 30.0 * 3.866] + ) + + def test_correspondence_delta_invariant_to_split_translation(self): + """End-to-end: a moving point at world (920, 234, 139.1) and a + fixed point at the identical world position must produce a + zero delta after transform composition, regardless of whether + the moving view has an 'Image Splitting' translation in its + XML chain. Without the skip, the delta would equal the split + translation. 
+ """ + mov_view = (0, 1) + fix_view = (0, 16) + view_registrations = { + mov_view: [ + _scale_transform("calibration", 1.0, 1.0, 1.0), + _translation_transform( + SPLIT_TILE_TRANSFORM_NAME, 384.0, 0.0, 0.0 + ), + ], + fix_view: [ + _scale_transform("calibration", 1.0, 1.0, 1.0), + ], + } + + mov_pts_world = np.array([[920.0, 234.0, 139.1]]) + fix_pts_world = np.array([[920.0, 234.0, 139.1]]) + + M_mov = self.loader.get_transformation_matrix(mov_view, view_registrations) + M_fix = self.loader.get_transformation_matrix(fix_view, view_registrations) + + mov_t = self.loader.transform_interest_points(mov_pts_world, M_mov) + fix_t = self.loader.transform_interest_points(fix_pts_world, M_fix) + + delta = mov_t[0] - fix_t[0] + np.testing.assert_array_almost_equal(delta, [0.0, 0.0, 0.0]) + + def test_case_sensitive_name_match(self): + """The reserved name is an exact string match; a subtly + different name (e.g. lowercase) should NOT be skipped.""" + view_id = (0, 1) + view_registrations = { + view_id: [ + _translation_transform("image splitting", 384.0, 0.0, 0.0), + ] + } + M = self.loader.get_transformation_matrix(view_id, view_registrations) + np.testing.assert_array_almost_equal(M[:3, 3], [384.0, 0.0, 0.0]) + + +if __name__ == "__main__": + unittest.main()