diff --git a/param/parameterized.py b/param/parameterized.py index d27b7831..c6574940 100644 --- a/param/parameterized.py +++ b/param/parameterized.py @@ -17,6 +17,7 @@ import numbers import operator import os +import pathlib import re import sys import types @@ -2337,17 +2338,61 @@ class Comparator: To use the Comparator simply call the is_equal function. """ + # Maximum number of elements for array/frame equality checks. + # Above this threshold the comparator gives up and returns False + # (triggering watchers) rather than doing an expensive comparison. + array_max_size = 1_000_000 + equalities: dict[type | tuple[type, ...] | Callable[[t.Any], bool], Callable[[t.Any, t.Any], bool]] = { numbers.Number: operator.eq, str: operator.eq, bytes: operator.eq, type(None): operator.eq, + pathlib.PurePath: operator.eq, lambda o: hasattr(o, '_infinitely_iterable'): operator.eq, # Time + lambda o: type(o).__module__.startswith('numpy') and all(hasattr(o, a) for a in ('shape', 'dtype', 'size')): lambda a, b: Comparator._array_equal(a, b), + lambda o: type(o).__module__.startswith('pandas') and all(hasattr(o, a) for a in ('shape', 'size', 'equals')): lambda a, b: Comparator._pandas_equal(a, b), } gen_equalities = { _dt_types: operator.eq } + @staticmethod + def _array_equal(obj1, obj2): + """Equality check for numpy arrays with a size cutoff.""" + if obj1 is obj2: + return True + if type(obj1) is not type(obj2): + return False + try: + if obj1.shape != obj2.shape or obj1.dtype != obj2.dtype: + return False + if obj1.size > Comparator.array_max_size: + return False + except AttributeError: + return False + try: + import numpy as np + return bool(np.array_equal(obj1, obj2)) + except (ImportError, ValueError, TypeError): + return False + + @staticmethod + def _pandas_equal(obj1, obj2): + """Equality check for pandas DataFrame/Series with a size cutoff.""" + if obj1 is obj2: + return True + if type(obj1) is not type(obj2): + return False + try: + if obj1.shape != obj2.shape: + return False + if obj1.size > Comparator.array_max_size: + return False + return bool(obj1.equals(obj2)) + except (AttributeError, ValueError, TypeError): + return False + @classmethod def is_equal(cls, obj1, obj2): equals = cls.equalities.copy() diff --git a/tests/testcomparator.py b/tests/testcomparator.py index 20faf522..0e428dba 100644 --- a/tests/testcomparator.py +++ b/tests/testcomparator.py @@ -1,5 +1,6 @@ import datetime import decimal +import pathlib import pytest @@ -30,15 +31,142 @@ 'dict': {'a': 1, 'b': 2}, 'date': _today, 'datetime': _now, + 'pathlib.Path': pathlib.Path('/tmp/test'), + 'pathlib.PurePosixPath': pathlib.PurePosixPath('/tmp/test'), } if np: _supported.update({ 'np.datetime64': np.datetime64(_now), + 'np.array_int': np.array([1, 2, 3]), + 'np.array_float': np.array([1.0, 2.0]), + 'np.array_2d': np.zeros((3, 4)), }) if pd: - _supported.update({'pd.Timestamp': pd.Timestamp(_now)}) + _supported.update({ + 'pd.Timestamp': pd.Timestamp(_now), + 'pd.Series': pd.Series([1, 2, 3]), + 'pd.DataFrame': pd.DataFrame({'a': [1, 2], 'b': [3, 4]}), + }) @pytest.mark.parametrize('obj', _supported.values(), ids=_supported.keys()) def test_comparator_equal(obj): assert Comparator.is_equal(obj, obj) + + +# ---- pathlib tests ---- + +def test_path_equal(): + assert Comparator.is_equal(pathlib.Path('/a/b'), pathlib.Path('/a/b')) + +def test_path_not_equal(): + assert not Comparator.is_equal(pathlib.Path('/a/b'), pathlib.Path('/a/c')) + +def test_purepath_equal(): + assert Comparator.is_equal(pathlib.PurePosixPath('/x'), pathlib.PurePosixPath('/x')) + + +# ---- numpy tests ---- + +@pytest.mark.skipif(np is None, reason='numpy not available') +class TestComparatorNumpy: + + def test_array_equal(self): + a = np.array([1, 2, 3]) + b = np.array([1, 2, 3]) + assert Comparator.is_equal(a, b) + + def test_array_not_equal(self): + a = np.array([1, 2, 3]) + b = np.array([1, 2, 4]) + assert not Comparator.is_equal(a, b) + + def test_array_different_shape(self): + a = np.array([1, 2, 3]) + b = np.array([[1, 2, 3]]) + assert not Comparator.is_equal(a, b) + + def test_array_different_dtype(self): + a = np.array([1, 2], dtype=np.int32) + b = np.array([1, 2], dtype=np.float64) + assert not Comparator.is_equal(a, b) + + def test_array_identity(self): + a = np.array([1, 2, 3]) + assert Comparator.is_equal(a, a) + + def test_array_large_skips(self): + """Arrays larger than array_max_size should return False.""" + old = Comparator.array_max_size + try: + Comparator.array_max_size = 5 + a = np.arange(10) + b = np.arange(10) + assert not Comparator.is_equal(a, b) + finally: + Comparator.array_max_size = old + + def test_array_with_nan(self): + a = np.array([1.0, float('nan'), 3.0]) + b = np.array([1.0, float('nan'), 3.0]) + # np.array_equal treats NaN == NaN as False + assert not Comparator.is_equal(a, b) + + +# ---- pandas tests ---- + +@pytest.mark.skipif(pd is None, reason='pandas not available') +class TestComparatorPandas: + + def test_series_equal(self): + a = pd.Series([1, 2, 3]) + b = pd.Series([1, 2, 3]) + assert Comparator.is_equal(a, b) + + def test_series_not_equal(self): + a = pd.Series([1, 2, 3]) + b = pd.Series([1, 2, 4]) + assert not Comparator.is_equal(a, b) + + def test_dataframe_equal(self): + a = pd.DataFrame({'x': [1, 2], 'y': [3, 4]}) + b = pd.DataFrame({'x': [1, 2], 'y': [3, 4]}) + assert Comparator.is_equal(a, b) + + def test_dataframe_not_equal(self): + a = pd.DataFrame({'x': [1, 2]}) + b = pd.DataFrame({'x': [1, 3]}) + assert not Comparator.is_equal(a, b) + + def test_dataframe_different_shape(self): + a = pd.DataFrame({'x': [1, 2]}) + b = pd.DataFrame({'x': [1, 2], 'y': [3, 4]}) + assert not Comparator.is_equal(a, b) + + def test_dataframe_with_nan(self): + a = pd.DataFrame({'x': [1.0, float('nan')]}) + b = pd.DataFrame({'x': [1.0, float('nan')]}) + # pd.DataFrame.equals treats NaN as equal + assert Comparator.is_equal(a, b) + + def test_series_large_skips(self): + """Series larger than array_max_size should return False.""" + old = Comparator.array_max_size + try: + Comparator.array_max_size = 5 + a = pd.Series(range(10)) + b = pd.Series(range(10)) + assert not Comparator.is_equal(a, b) + finally: + Comparator.array_max_size = old + + def test_dataframe_large_skips(self): + """DataFrames larger than array_max_size should return False.""" + old = Comparator.array_max_size + try: + Comparator.array_max_size = 5 + a = pd.DataFrame({'x': range(10)}) + b = pd.DataFrame({'x': range(10)}) + assert not Comparator.is_equal(a, b) + finally: + Comparator.array_max_size = old