Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 104 additions & 18 deletions hexrd/core/imageseries/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,31 @@
NOTES
-----
* Perhaps we should rename min -> minimum and max -> maximum to avoid
conflicting with the python built-ins
conflicting with the python built-ins (likewise for sum)
"""

from __future__ import annotations

from typing import Iterator, TYPE_CHECKING

import numpy as np
import numpy.typing as npt

from psutil import virtual_memory

if TYPE_CHECKING:
from .imageseriesabc import ImageSeriesABC

# Either a raw array (frames along axis 0) or an imageseries adapter.
ImageInput = ImageSeriesABC | np.ndarray

# Default Buffer Size: half of available memory
vmem = virtual_memory()
STATS_BUFFER = int(0.5 * vmem.available)
del vmem


def max(ims, nframes=0):
def max(ims: ImageInput, nframes: int = 0) -> np.ndarray:
"""maximum over frames"""
nf = _nframes(ims, nframes)
img = ims[0]
Expand All @@ -43,7 +54,9 @@ def max(ims, nframes=0):
return img


def max_iter(ims, nchunk, nframes=0):
def max_iter(
ims: ImageInput, nchunk: int, nframes: int = 0
) -> Iterator[np.ndarray]:
"""iterator for max function"""
nf = _nframes(ims, nframes)
stops = _chunk_stops(nf, nchunk)
Expand All @@ -64,7 +77,7 @@ def max_iter(ims, nchunk, nframes=0):
stop = stops[s0]


def min(ims, nframes=0):
def min(ims: ImageInput, nframes: int = 0) -> np.ndarray:
"""minimum over frames"""
nf = _nframes(ims, nframes)
img = ims[0]
Expand All @@ -73,7 +86,9 @@ def min(ims, nframes=0):
return img


def min_iter(ims, nchunk, nframes=0):
def min_iter(
ims: ImageInput, nchunk: int, nframes: int = 0
) -> Iterator[np.ndarray]:
"""iterator for min function"""
nf = _nframes(ims, nframes)
stops = _chunk_stops(nf, nchunk)
Expand All @@ -93,16 +108,51 @@ def min_iter(ims, nchunk, nframes=0):
stop = stops[s0]


def average(ims, nframes=0):
"""average over frames"""
def sum(ims: ImageInput, nframes: int = 0) -> np.ndarray:
"""sum over frames

Accumulates in float64 so the total can exceed the input dtype's range
(e.g. uint16/uint32) without overflowing.
"""
img, _ = _accumulate(ims, nframes, np.float64)
return img


def sum_iter(
ims: ImageInput, nchunk: int, nframes: int = 0
) -> Iterator[np.ndarray]:
"""iterator for sum function

Note: sum accumulates in float64 even if the images are integer-typed.
"""
nf = _nframes(ims, nframes)
img = ims[0].astype(np.float32)
stops = _chunk_stops(nf, nchunk)
s0, stop = 0, stops[0]
img = ims[0].astype(np.float64)
if stop == 0:
if nf > 1:
s0, stop = 1, stops[1]
# Copy so later in-place accumulation can't mutate the yielded array.
yield img.copy()

for i in range(1, nf):
img += ims[i]
if i >= stop:
if (i + 1) < nf:
s0 += 1
stop = stops[s0]
yield img.copy()


def average(ims: ImageInput, nframes: int = 0) -> np.ndarray:
"""average over frames"""
img, nf = _accumulate(ims, nframes, np.float32)
return img / nf


def average_iter(ims, nchunk, nframes=0):
def average_iter(
ims: ImageInput, nchunk: int, nframes: int = 0
) -> Iterator[np.ndarray]:
"""average over frames

Note: average returns a float even if images are uint
Expand All @@ -125,7 +175,9 @@ def average_iter(ims, nchunk, nframes=0):
yield img / (i + 1)


def percentile(ims, pctl, nframes=0):
def percentile(
ims: ImageInput, pctl: float, nframes: int = 0
) -> np.ndarray:
"""percentile function over frames

ims - the imageseries
Expand All @@ -136,7 +188,13 @@ def percentile(ims, pctl, nframes=0):
return np.percentile(_toarray(ims, nf), pctl, axis=0).astype(np.float32)


def percentile_iter(ims, pctl, nchunks, nframes=0, use_buffer=True):
def percentile_iter(
ims: ImageInput,
pctl: float,
nchunks: int,
nframes: int = 0,
use_buffer: bool = True,
) -> Iterator[np.ndarray]:
"""iterator for percentile function"""
nf = _nframes(ims, nframes)
nr, nc = ims.shape
Expand All @@ -153,23 +211,46 @@ def percentile_iter(ims, pctl, nchunks, nframes=0, use_buffer=True):
yield img.astype(np.float32)


def median(ims, nframes=0):
def median(ims: ImageInput, nframes: int = 0) -> np.ndarray:
return percentile(ims, 50, nframes=nframes)


def median_iter(ims, nchunks, nframes=0, use_buffer=True):
return percentile_iter(ims, 50, nchunks, nframes=nframes, use_buffer=use_buffer)
def median_iter(
ims: ImageInput,
nchunks: int,
nframes: int = 0,
use_buffer: bool = True,
) -> Iterator[np.ndarray]:
return percentile_iter(
ims, 50, nchunks, nframes=nframes, use_buffer=use_buffer
)


# ==================== Utilities
#
def _nframes(ims, nframes):
def _accumulate(
ims: ImageInput, nframes: int, dtype: npt.DTypeLike
) -> tuple[np.ndarray, int]:
"""sum frames into a single image of the given accumulator dtype

Returns the accumulated image and the number of frames used. The dtype
controls precision/overflow behavior (e.g. float32 for average, float64
for sum).
"""
nf = _nframes(ims, nframes)
img = ims[0].astype(dtype)
for i in range(1, nf):
img += ims[i]
return img, nf


def _nframes(ims: ImageInput, nframes: int) -> int:
"""number of frames to use: len(ims) or specified number"""
mynf = len(ims)
return np.min((mynf, nframes)) if nframes > 0 else mynf


def _chunk_stops(n, nchunks):
def _chunk_stops(n: int, nchunks: int) -> npt.NDArray[np.int_]:
"""Return yield points

n -- number of items to be chunked (e.g. frames/rows)
Expand All @@ -186,7 +267,12 @@ def _chunk_stops(n, nchunks):
return np.cumsum(pieces)


def _toarray(ims, nframes, rows=None, buffer=None):
def _toarray(
ims: ImageInput,
nframes: int,
rows: tuple[int, int] | None = None,
buffer: np.ndarray | None = None,
) -> np.ndarray:
"""generate array for either whole imageseries or subset of rows

ims - imageseries
Expand Down Expand Up @@ -226,7 +312,7 @@ def _toarray(ims, nframes, rows=None, buffer=None):
return a


def _alloc_buffer(ims, nf):
def _alloc_buffer(ims: ImageInput, nf: int) -> np.ndarray:
"""Allocate buffer to save as many full frames as possible"""
shp, dt = ims.shape, ims.dtype
framesize = shp[0] * shp[1] * dt.itemsize
Expand Down
38 changes: 38 additions & 0 deletions tests/core/imageseries/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,40 @@ def test_average_iter(mock_ims):
assert np.all(res2 == 4.5)


# --- Sum Tests ---


def test_sum(mock_ims: MagicMock) -> None:
res = stats.sum(mock_ims)
assert np.all(res == 45.0)
# Accumulates in float64 regardless of input dtype.
assert res.dtype == np.float64

res = stats.sum(mock_ims, nframes=5)
assert np.all(res == 10.0)


def test_sum_iter(mock_ims: MagicMock) -> None:
gen = stats.sum_iter(mock_ims, nchunk=2)

res1 = next(gen)
assert np.all(res1 == 10.0)

res2 = next(gen)
assert np.all(res2 == 45.0)

with pytest.raises(StopIteration):
next(gen)


def test_sum_matches_average(mock_ims: MagicMock) -> None:
# sum should equal average * nframes, while average stays float32.
avg = stats.average(mock_ims)
total = stats.sum(mock_ims)
assert np.all(total == avg * len(mock_ims))
assert avg.dtype == np.float32


def test_iterators_stop_zero_edge_case(mock_ims):
nchunks = 10

Expand All @@ -97,6 +131,10 @@ def test_iterators_stop_zero_edge_case(mock_ims):
res0 = next(gen_avg)
assert np.all(res0 == 0.0)

gen_sum = stats.sum_iter(mock_ims, nchunk=nchunks)
res0 = next(gen_sum)
assert np.all(res0 == 0.0)


# --- Percentile & Median Tests ---

Expand Down
40 changes: 40 additions & 0 deletions tests/imageseries/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,46 @@ def test_stats_average(self):
self.assertAlmostEqual(err, 0.0, msg="stats.average failed")
self.assertEqual(is_avg.dtype, np.float32)

def test_stats_sum(self) -> None:
"""imageseries.stats: sum

Compares with numpy sum
"""
a, is_a = make_array_ims()
is_sum = stats.sum(is_a)
np_sum = np.sum(a, axis=0)
err = np.linalg.norm(np_sum - is_sum)
self.assertAlmostEqual(err, 0.0, msg="stats.sum failed")
# Accumulates in float64 so totals can exceed the input dtype range.
self.assertEqual(is_sum.dtype, np.float64)

def test_stats_sum_chunked(self) -> None:
"""imageseries.stats: chunked sum"""
a, is_a = make_array_ims()
a_sum = stats.sum(is_a)

# Run with 1 chunk
for issum1 in stats.sum_iter(is_a, 1):
pass
err = np.linalg.norm(a_sum - issum1)
self.assertAlmostEqual(err, 0.0, msg="stats.sum failed (1 chunk)")

# Run with 2 chunks
for issum2 in stats.sum_iter(is_a, 2):
pass
err = np.linalg.norm(a_sum - issum2)
self.assertAlmostEqual(err, 0.0, msg="stats.sum failed (2 chunks)")

def test_stats_sum_average_consistency(self) -> None:
"""imageseries.stats: sum == average * nframes"""
a, is_a = make_array_ims()
is_sum = stats.sum(is_a)
is_avg = stats.average(is_a)
err = np.linalg.norm(is_sum - is_avg * len(is_a))
self.assertAlmostEqual(err, 0.0, msg="sum/average inconsistent")
# average stays float32 even though sum is float64.
self.assertEqual(is_avg.dtype, np.float32)

def test_stats_median(self):
"""imageseries.stats: median"""
a, is_a = make_array_ims()
Expand Down
Loading