diff --git a/src/supervision/metrics/mota.py b/src/supervision/metrics/mota.py new file mode 100644 index 0000000000..782d49534f --- /dev/null +++ b/src/supervision/metrics/mota.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + + +@dataclass +class MOTAResult: + mota: float + num_false_positives: int + num_false_negatives: int + num_id_switches: int + num_ground_truth: int + num_frames: int + + def __str__(self): + return ( + f"MOTAResult(mota={self.mota:.4f}, " + f"fp={self.num_false_positives}, " + f"fn={self.num_false_negatives}, " + f"idsw={self.num_id_switches})" + ) + + +@dataclass +class MOTPResult: + motp: float + total_iou: float + num_matches: int + num_frames: int + + def __str__(self): + return f"MOTPResult(motp={self.motp:.4f}, matches={self.num_matches})" + + +@dataclass +class MOTMetricsResult: + mota: float + motp: float + num_false_positives: int + num_false_negatives: int + num_id_switches: int + num_ground_truth: int + num_matches: int + num_frames: int + + def __str__(self): + return ( + f"MOTMetricsResult(mota={self.mota:.4f}, " + f"motp={self.motp:.4f}, " + f"fp={self.num_false_positives}, " + f"fn={self.num_false_negatives}, " + f"idsw={self.num_id_switches})" + ) + + +def _compute_iou_matrix(boxes_a, boxes_b): + x1 = np.maximum(boxes_a[:, 0][:, None], boxes_b[:, 0][None, :]) + y1 = np.maximum(boxes_a[:, 1][:, None], boxes_b[:, 1][None, :]) + x2 = np.minimum(boxes_a[:, 2][:, None], boxes_b[:, 2][None, :]) + y2 = np.minimum(boxes_a[:, 3][:, None], boxes_b[:, 3][None, :]) + intersection = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1) + area_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1]) + area_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1]) + union = area_a[:, None] + area_b[None, :] - intersection + return np.where(union > 0, intersection / union, 0.0) + + +def _greedy_match(iou_matrix, threshold): + matches = [] + matched_gt = 
set() + matched_pred = set() + num_gt, num_pred = iou_matrix.shape + flat_indices = np.argsort(-iou_matrix.ravel()) + for flat_idx in flat_indices: + gt_idx = int(flat_idx // num_pred) + pred_idx = int(flat_idx % num_pred) + if iou_matrix[gt_idx, pred_idx] < threshold: + break + if gt_idx in matched_gt or pred_idx in matched_pred: + continue + matches.append((gt_idx, pred_idx)) + matched_gt.add(gt_idx) + matched_pred.add(pred_idx) + return matched_gt, matched_pred, matches + + +def _validate_tracker_ids(ground_truth, predictions): + if ground_truth.tracker_id is None: + raise ValueError("ground_truth.tracker_id must be set for tracking metrics.") + if predictions.tracker_id is None: + raise ValueError("predictions.tracker_id must be set for tracking metrics.") + + +class MultiObjectTrackingAccuracy: + def __init__(self, iou_threshold=0.5): + if not 0 < iou_threshold <= 1: + raise ValueError(f"iou_threshold must be in (0, 1], got {iou_threshold}") + self.iou_threshold = iou_threshold + self.reset() + + def reset(self): + self._false_positives = 0 + self._false_negatives = 0 + self._id_switches = 0 + self._ground_truth_count = 0 + self._num_frames = 0 + self._last_match = {} + + def update(self, ground_truth, predictions): + _validate_tracker_ids(ground_truth, predictions) + self._num_frames += 1 + num_gt = len(ground_truth) + num_pred = len(predictions) + self._ground_truth_count += num_gt + if num_gt == 0 and num_pred == 0: + return + if num_gt == 0: + self._false_positives += num_pred + return + if num_pred == 0: + self._false_negatives += num_gt + return + iou_matrix = _compute_iou_matrix(ground_truth.xyxy, predictions.xyxy) + _, _, matches = _greedy_match(iou_matrix, self.iou_threshold) + self._false_negatives += num_gt - len(matches) + self._false_positives += num_pred - len(matches) + for gt_idx, pred_idx in matches: + gt_id = int(ground_truth.tracker_id[gt_idx]) + pred_id = int(predictions.tracker_id[pred_idx]) + if gt_id in self._last_match: + if 
self._last_match[gt_id] != pred_id: + self._id_switches += 1 + self._last_match[gt_id] = pred_id + + def compute(self): + if self._ground_truth_count == 0: + raise ValueError("No ground truth objects found. Call update() first.") + total_errors = self._false_negatives + self._false_positives + self._id_switches + mota = 1.0 - total_errors / self._ground_truth_count + return MOTAResult( + mota=mota, + num_false_positives=self._false_positives, + num_false_negatives=self._false_negatives, + num_id_switches=self._id_switches, + num_ground_truth=self._ground_truth_count, + num_frames=self._num_frames, + ) + + +class MultiObjectTrackingPrecision: + def __init__(self, iou_threshold=0.5): + if not 0 < iou_threshold <= 1: + raise ValueError(f"iou_threshold must be in (0, 1], got {iou_threshold}") + self.iou_threshold = iou_threshold + self.reset() + + def reset(self): + self._total_iou = 0.0 + self._num_matches = 0 + self._num_frames = 0 + + def update(self, ground_truth, predictions): + _validate_tracker_ids(ground_truth, predictions) + self._num_frames += 1 + num_gt = len(ground_truth) + num_pred = len(predictions) + if num_gt == 0 or num_pred == 0: + return + iou_matrix = _compute_iou_matrix(ground_truth.xyxy, predictions.xyxy) + _, _, matches = _greedy_match(iou_matrix, self.iou_threshold) + for gt_idx, pred_idx in matches: + self._total_iou += iou_matrix[gt_idx, pred_idx] + self._num_matches += 1 + + def compute(self): + if self._num_matches == 0: + raise ValueError( + "No matches found. Call update() with matching detections first." 
+ ) + motp = self._total_iou / self._num_matches + return MOTPResult( + motp=motp, + total_iou=self._total_iou, + num_matches=self._num_matches, + num_frames=self._num_frames, + ) + + +class TrackingMetrics: + def __init__(self, iou_threshold=0.5): + if not 0 < iou_threshold <= 1: + raise ValueError(f"iou_threshold must be in (0, 1], got {iou_threshold}") + self.iou_threshold = iou_threshold + self.reset() + + def reset(self): + self._false_positives = 0 + self._false_negatives = 0 + self._id_switches = 0 + self._ground_truth_count = 0 + self._total_iou = 0.0 + self._num_matches = 0 + self._num_frames = 0 + self._last_match = {} + + def update(self, ground_truth, predictions): + _validate_tracker_ids(ground_truth, predictions) + self._num_frames += 1 + num_gt = len(ground_truth) + num_pred = len(predictions) + self._ground_truth_count += num_gt + if num_gt == 0 and num_pred == 0: + return + if num_gt == 0: + self._false_positives += num_pred + return + if num_pred == 0: + self._false_negatives += num_gt + return + iou_matrix = _compute_iou_matrix(ground_truth.xyxy, predictions.xyxy) + _, _, matches = _greedy_match(iou_matrix, self.iou_threshold) + self._false_negatives += num_gt - len(matches) + self._false_positives += num_pred - len(matches) + for gt_idx, pred_idx in matches: + self._total_iou += iou_matrix[gt_idx, pred_idx] + self._num_matches += 1 + gt_id = int(ground_truth.tracker_id[gt_idx]) + pred_id = int(predictions.tracker_id[pred_idx]) + if gt_id in self._last_match: + if self._last_match[gt_id] != pred_id: + self._id_switches += 1 + self._last_match[gt_id] = pred_id + + def compute(self): + if self._ground_truth_count == 0: + raise ValueError("No ground truth objects found. 
Call update() first.") + total_errors = self._false_negatives + self._false_positives + self._id_switches + mota = 1.0 - total_errors / self._ground_truth_count + motp = self._total_iou / self._num_matches if self._num_matches > 0 else 0.0 + return MOTMetricsResult( + mota=mota, + motp=motp, + num_false_positives=self._false_positives, + num_false_negatives=self._false_negatives, + num_id_switches=self._id_switches, + num_ground_truth=self._ground_truth_count, + num_matches=self._num_matches, + num_frames=self._num_frames, + ) diff --git a/tests/metrics/test_mota.py b/tests/metrics/test_mota.py new file mode 100644 index 0000000000..5a0b0a3d9a --- /dev/null +++ b/tests/metrics/test_mota.py @@ -0,0 +1,307 @@ +import numpy as np +import pytest + +from supervision.detection.core import Detections +from supervision.metrics.mota import ( + MOTAResult, + MOTMetricsResult, + MOTPResult, + MultiObjectTrackingAccuracy, + MultiObjectTrackingPrecision, + TrackingMetrics, +) + + +def _make_detections(xyxy, tracker_ids): + return Detections( + xyxy=np.array(xyxy, dtype=np.float32), + tracker_id=np.array(tracker_ids, dtype=int), + ) + + +def _empty_detections(): + return Detections( + xyxy=np.empty((0, 4), dtype=np.float32), + tracker_id=np.array([], dtype=int), + ) + + +class TestMOTAInit: + def test_default_threshold(self): + metric = MultiObjectTrackingAccuracy() + assert metric.iou_threshold == 0.5 + + def test_custom_threshold(self): + metric = MultiObjectTrackingAccuracy(iou_threshold=0.75) + assert metric.iou_threshold == 0.75 + + def test_invalid_threshold_zero(self): + with pytest.raises(ValueError): + MultiObjectTrackingAccuracy(iou_threshold=0.0) + + def test_invalid_threshold_negative(self): + with pytest.raises(ValueError): + MultiObjectTrackingAccuracy(iou_threshold=-0.1) + + def test_invalid_threshold_above_one(self): + with pytest.raises(ValueError): + MultiObjectTrackingAccuracy(iou_threshold=1.5) + + +class TestMOTAPerfectTracking: + def 
"""Tests for the CLEAR-MOT metrics in ``supervision.metrics.mota``."""

import numpy as np
import pytest

from supervision.detection.core import Detections
from supervision.metrics.mota import (
    MOTAResult,
    MOTMetricsResult,
    MOTPResult,
    MultiObjectTrackingAccuracy,
    MultiObjectTrackingPrecision,
    TrackingMetrics,
)


def _make_detections(xyxy, tracker_ids):
    """Build a ``Detections`` with float32 boxes and integer tracker ids."""
    boxes = np.array(xyxy, dtype=np.float32)
    ids = np.array(tracker_ids, dtype=int)
    return Detections(xyxy=boxes, tracker_id=ids)


def _empty_detections():
    """Build a ``Detections`` holding zero boxes (tracker_id array present)."""
    no_boxes = np.empty((0, 4), dtype=np.float32)
    no_ids = np.array([], dtype=int)
    return Detections(xyxy=no_boxes, tracker_id=no_ids)


class TestMOTAInit:
    """Constructor validation for ``MultiObjectTrackingAccuracy``."""

    def test_default_threshold(self):
        assert MultiObjectTrackingAccuracy().iou_threshold == 0.5

    def test_custom_threshold(self):
        assert MultiObjectTrackingAccuracy(iou_threshold=0.75).iou_threshold == 0.75

    def test_invalid_threshold_zero(self):
        with pytest.raises(ValueError):
            MultiObjectTrackingAccuracy(iou_threshold=0.0)

    def test_invalid_threshold_negative(self):
        with pytest.raises(ValueError):
            MultiObjectTrackingAccuracy(iou_threshold=-0.1)

    def test_invalid_threshold_above_one(self):
        with pytest.raises(ValueError):
            MultiObjectTrackingAccuracy(iou_threshold=1.5)


class TestMOTAPerfectTracking:
    """MOTA is exactly 1.0 when every object is tracked flawlessly."""

    def test_single_frame_perfect_match(self):
        metric = MultiObjectTrackingAccuracy(iou_threshold=0.5)
        boxes = [[10, 10, 50, 50], [100, 100, 150, 150]]
        metric.update(
            _make_detections(boxes, [1, 2]), _make_detections(boxes, [1, 2])
        )
        result = metric.compute()
        assert result.mota == 1.0
        assert result.num_false_positives == 0
        assert result.num_false_negatives == 0
        assert result.num_id_switches == 0

    def test_multi_frame_perfect_match(self):
        metric = MultiObjectTrackingAccuracy(iou_threshold=0.5)
        for _ in range(5):
            metric.update(
                _make_detections([[10, 10, 50, 50]], [1]),
                _make_detections([[10, 10, 50, 50]], [1]),
            )
        result = metric.compute()
        assert result.mota == 1.0
        assert result.num_frames == 5


class TestMOTAFalsePositives:
    """Unmatched predictions count against MOTA as false positives."""

    def test_extra_predictions(self):
        metric = MultiObjectTrackingAccuracy(iou_threshold=0.5)
        metric.update(
            _make_detections([[10, 10, 50, 50]], [1]),
            _make_detections([[10, 10, 50, 50], [200, 200, 250, 250]], [1, 99]),
        )
        result = metric.compute()
        assert result.num_false_positives == 1
        assert result.mota == 0.0

    def test_all_false_positives(self):
        metric = MultiObjectTrackingAccuracy(iou_threshold=0.5)
        metric.update(
            _make_detections([[10, 10, 50, 50]], [1]),
            _make_detections([[200, 200, 250, 250]], [99]),
        )
        # One FN plus one FP against a single GT object drives MOTA to -1.
        assert metric.compute().mota == -1.0


class TestMOTAFalseNegatives:
    """Unmatched ground-truth objects count as false negatives."""

    def test_missed_detections(self):
        metric = MultiObjectTrackingAccuracy(iou_threshold=0.5)
        metric.update(
            _make_detections([[10, 10, 50, 50], [100, 100, 150, 150]], [1, 2]),
            _make_detections([[10, 10, 50, 50]], [1]),
        )
        result = metric.compute()
        assert result.num_false_negatives == 1
        assert result.mota == 0.5

    def test_no_predictions(self):
        metric = MultiObjectTrackingAccuracy(iou_threshold=0.5)
        metric.update(_make_detections([[10, 10, 50, 50]], [1]), _empty_detections())
        assert metric.compute().mota == 0.0


class TestMOTAIDSwitches:
    """A ground-truth track matched to a new prediction id is a switch."""

    def test_identity_switch(self):
        metric = MultiObjectTrackingAccuracy(iou_threshold=0.5)
        metric.update(
            _make_detections([[10, 10, 50, 50]], [1]),
            _make_detections([[10, 10, 50, 50]], [10]),
        )
        metric.update(
            _make_detections([[12, 12, 52, 52]], [1]),
            _make_detections([[12, 12, 52, 52]], [20]),
        )
        result = metric.compute()
        assert result.num_id_switches == 1
        assert result.mota == 0.5

    def test_no_switch_same_id(self):
        metric = MultiObjectTrackingAccuracy(iou_threshold=0.5)
        metric.update(
            _make_detections([[10, 10, 50, 50]], [1]),
            _make_detections([[10, 10, 50, 50]], [10]),
        )
        metric.update(
            _make_detections([[12, 12, 52, 52]], [1]),
            _make_detections([[12, 12, 52, 52]], [10]),
        )
        result = metric.compute()
        assert result.num_id_switches == 0
        assert result.mota == 1.0


class TestMOTAIoUThreshold:
    """Pairs whose overlap falls below the threshold are not matched."""

    def test_low_iou_no_match(self):
        metric = MultiObjectTrackingAccuracy(iou_threshold=0.5)
        metric.update(
            _make_detections([[10, 10, 50, 50]], [1]),
            _make_detections([[45, 45, 90, 90]], [1]),
        )
        assert metric.compute().mota == -1.0

    def test_strict_threshold(self):
        metric = MultiObjectTrackingAccuracy(iou_threshold=0.9)
        metric.update(
            _make_detections([[10, 10, 50, 50]], [1]),
            _make_detections([[15, 15, 55, 55]], [1]),
        )
        assert metric.compute().num_false_negatives == 1


class TestMOTAReset:
    """``reset`` returns the accumulator to its pristine state."""

    def test_reset_clears_state(self):
        metric = MultiObjectTrackingAccuracy(iou_threshold=0.5)
        metric.update(
            _make_detections([[10, 10, 50, 50]], [1]),
            _make_detections([[10, 10, 50, 50]], [1]),
        )
        metric.reset()
        with pytest.raises(ValueError):
            metric.compute()


class TestMOTAMissingTrackerID:
    """``update`` rejects detections without tracker ids."""

    def test_gt_missing(self):
        metric = MultiObjectTrackingAccuracy(iou_threshold=0.5)
        no_ids = Detections(xyxy=np.array([[10, 10, 50, 50]]))
        with pytest.raises(ValueError):
            metric.update(no_ids, _make_detections([[10, 10, 50, 50]], [1]))

    def test_pred_missing(self):
        metric = MultiObjectTrackingAccuracy(iou_threshold=0.5)
        no_ids = Detections(xyxy=np.array([[10, 10, 50, 50]]))
        with pytest.raises(ValueError):
            metric.update(_make_detections([[10, 10, 50, 50]], [1]), no_ids)


class TestMOTPPerfect:
    """MOTP is 1.0 when every match has perfect overlap."""

    def test_exact_overlap(self):
        metric = MultiObjectTrackingPrecision(iou_threshold=0.5)
        metric.update(
            _make_detections([[10, 10, 50, 50]], [1]),
            _make_detections([[10, 10, 50, 50]], [1]),
        )
        assert metric.compute().motp == 1.0

    def test_multi_frame(self):
        metric = MultiObjectTrackingPrecision(iou_threshold=0.5)
        for _ in range(3):
            metric.update(
                _make_detections([[0, 0, 100, 100]], [1]),
                _make_detections([[0, 0, 100, 100]], [1]),
            )
        result = metric.compute()
        assert result.motp == 1.0
        assert result.num_matches == 3


class TestMOTPPartialOverlap:
    """MOTP reflects the mean IoU of matched pairs."""

    def test_offset_boxes(self):
        metric = MultiObjectTrackingPrecision(iou_threshold=0.3)
        metric.update(
            _make_detections([[0, 0, 100, 100]], [1]),
            _make_detections([[50, 0, 150, 100]], [1]),
        )
        # Half-overlapping unit squares: IoU = 5000 / 15000 = 1/3.
        assert abs(metric.compute().motp - 1.0 / 3.0) < 1e-6

    def test_no_match(self):
        metric = MultiObjectTrackingPrecision(iou_threshold=0.5)
        metric.update(
            _make_detections([[0, 0, 100, 100]], [1]),
            _make_detections([[200, 200, 300, 300]], [1]),
        )
        with pytest.raises(ValueError):
            metric.compute()


class TestMOTPReset:
    """``reset`` discards accumulated matches."""

    def test_reset(self):
        metric = MultiObjectTrackingPrecision(iou_threshold=0.5)
        metric.update(
            _make_detections([[10, 10, 50, 50]], [1]),
            _make_detections([[10, 10, 50, 50]], [1]),
        )
        metric.reset()
        with pytest.raises(ValueError):
            metric.compute()


class TestTrackingMetricsPerfect:
    """Combined accumulator agrees with both metrics on perfect input."""

    def test_perfect(self):
        metric = TrackingMetrics(iou_threshold=0.5)
        boxes = [[10, 10, 50, 50], [100, 100, 150, 150]]
        metric.update(
            _make_detections(boxes, [1, 2]), _make_detections(boxes, [1, 2])
        )
        result = metric.compute()
        assert result.mota == 1.0
        assert result.motp == 1.0
        assert result.num_matches == 2


class TestTrackingMetricsMixed:
    """Combined accumulator on sequences containing errors."""

    def test_fp_and_fn(self):
        metric = TrackingMetrics(iou_threshold=0.5)
        metric.update(
            _make_detections([[10, 10, 50, 50], [100, 100, 150, 150]], [1, 2]),
            _make_detections([[10, 10, 50, 50], [200, 200, 250, 250]], [1, 99]),
        )
        result = metric.compute()
        assert result.num_false_negatives == 1
        assert result.num_false_positives == 1
        assert result.mota == 0.0
        assert result.motp == 1.0

    def test_id_switch(self):
        metric = TrackingMetrics(iou_threshold=0.5)
        metric.update(
            _make_detections([[10, 10, 50, 50]], [1]),
            _make_detections([[10, 10, 50, 50]], [10]),
        )
        metric.update(
            _make_detections([[12, 12, 52, 52]], [1]),
            _make_detections([[12, 12, 52, 52]], [20]),
        )
        result = metric.compute()
        assert result.num_id_switches == 1
        assert result.mota == 0.5


class TestTrackingMetricsReset:
    """``reset`` discards all combined-accumulator state."""

    def test_reset(self):
        metric = TrackingMetrics(iou_threshold=0.5)
        metric.update(
            _make_detections([[10, 10, 50, 50]], [1]),
            _make_detections([[10, 10, 50, 50]], [1]),
        )
        metric.reset()
        with pytest.raises(ValueError):
            metric.compute()


class TestResultStrings:
    """Result dataclasses render their headline numbers in __str__."""

    def test_mota_str(self):
        result = MOTAResult(
            mota=0.75,
            num_false_positives=5,
            num_false_negatives=10,
            num_id_switches=2,
            num_ground_truth=68,
            num_frames=20,
        )
        assert "0.75" in str(result)

    def test_motp_str(self):
        result = MOTPResult(motp=0.85, total_iou=8.5, num_matches=10, num_frames=5)
        assert "0.85" in str(result)

    def test_combined_str(self):
        result = MOTMetricsResult(
            mota=0.8,
            motp=0.9,
            num_false_positives=3,
            num_false_negatives=5,
            num_id_switches=1,
            num_ground_truth=45,
            num_matches=36,
            num_frames=10,
        )
        assert "0.8" in str(result)