Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions param/parameterized.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numbers
import operator
import os
import pathlib
import re
import sys
import types
Expand Down Expand Up @@ -2337,17 +2338,61 @@ class Comparator:
To use the Comparator simply call the is_equal function.
"""

# Maximum number of elements for array/frame equality checks.
# Above this threshold the comparator gives up and returns False
# (triggering watchers) rather than doing an expensive comparison.
array_max_size = 1_000_000

equalities: dict[type | tuple[type, ...] | Callable[[t.Any], bool], Callable[[t.Any, t.Any], bool]] = {
numbers.Number: operator.eq,
str: operator.eq,
bytes: operator.eq,
type(None): operator.eq,
pathlib.PurePath: operator.eq,
lambda o: hasattr(o, '_infinitely_iterable'): operator.eq, # Time
lambda o: type(o).__module__.startswith('numpy') and all(hasattr(o, a) for a in ('shape', 'dtype', 'size')): lambda a, b: Comparator._array_equal(a, b),
lambda o: type(o).__module__.startswith('pandas') and all(hasattr(o, a) for a in ('shape', 'size', 'equals')): lambda a, b: Comparator._pandas_equal(a, b),
}
gen_equalities = {
_dt_types: operator.eq
}

@staticmethod
def _array_equal(obj1, obj2):
"""Equality check for numpy arrays with a size cutoff."""
if obj1 is obj2:
return True
Comment on lines +2361 to +2364
if type(obj1) is not type(obj2):
return False
try:
if obj1.shape != obj2.shape or obj1.dtype != obj2.dtype:
return False
if obj1.size > Comparator.array_max_size:
return False
except AttributeError:
return False
try:
import numpy as np
return bool(np.array_equal(obj1, obj2))
except (ImportError, ValueError, TypeError):
return False

@staticmethod
def _pandas_equal(obj1, obj2):
"""Equality check for pandas DataFrame/Series with a size cutoff."""
if obj1 is obj2:
return True
if type(obj1) is not type(obj2):
return False
try:
if obj1.shape != obj2.shape:
return False
if obj1.size > Comparator.array_max_size:
return False
return bool(obj1.equals(obj2))
except (AttributeError, ValueError, TypeError):
return False

@classmethod
def is_equal(cls, obj1, obj2):
equals = cls.equalities.copy()
Expand Down
130 changes: 129 additions & 1 deletion tests/testcomparator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import decimal
import pathlib

import pytest

Expand Down Expand Up @@ -30,15 +31,142 @@
'dict': {'a': 1, 'b': 2},
'date': _today,
'datetime': _now,
'pathlib.Path': pathlib.Path('/tmp/test'),
'pathlib.PurePosixPath': pathlib.PurePosixPath('/tmp/test'),
}

if np:
_supported.update({
'np.datetime64': np.datetime64(_now),
'np.array_int': np.array([1, 2, 3]),
'np.array_float': np.array([1.0, 2.0]),
'np.array_2d': np.zeros((3, 4)),
})
if pd:
_supported.update({'pd.Timestamp': pd.Timestamp(_now)})
_supported.update({
'pd.Timestamp': pd.Timestamp(_now),
'pd.Series': pd.Series([1, 2, 3]),
'pd.DataFrame': pd.DataFrame({'a': [1, 2], 'b': [3, 4]}),
})

@pytest.mark.parametrize('obj', _supported.values(), ids=_supported.keys())
def test_comparator_equal(obj):
assert Comparator.is_equal(obj, obj)


# ---- pathlib tests ----

def test_path_equal():
assert Comparator.is_equal(pathlib.Path('/a/b'), pathlib.Path('/a/b'))

def test_path_not_equal():
assert not Comparator.is_equal(pathlib.Path('/a/b'), pathlib.Path('/a/c'))

def test_purepath_equal():
assert Comparator.is_equal(pathlib.PurePosixPath('/x'), pathlib.PurePosixPath('/x'))


# ---- numpy tests ----

@pytest.mark.skipif(np is None, reason='numpy not available')
class TestComparatorNumpy:

def test_array_equal(self):
a = np.array([1, 2, 3])
b = np.array([1, 2, 3])
assert Comparator.is_equal(a, b)

def test_array_not_equal(self):
a = np.array([1, 2, 3])
b = np.array([1, 2, 4])
assert not Comparator.is_equal(a, b)

def test_array_different_shape(self):
a = np.array([1, 2, 3])
b = np.array([[1, 2, 3]])
assert not Comparator.is_equal(a, b)

def test_array_different_dtype(self):
a = np.array([1, 2], dtype=np.int32)
b = np.array([1, 2], dtype=np.float64)
assert not Comparator.is_equal(a, b)

def test_array_identity(self):
a = np.array([1, 2, 3])
assert Comparator.is_equal(a, a)

def test_array_large_skips(self):
"""Arrays larger than array_max_size should return False."""
old = Comparator.array_max_size
try:
Comparator.array_max_size = 5
a = np.arange(10)
b = np.arange(10)
assert not Comparator.is_equal(a, b)
finally:
Comparator.array_max_size = old

def test_array_with_nan(self):
a = np.array([1.0, float('nan'), 3.0])
b = np.array([1.0, float('nan'), 3.0])
# np.array_equal treats NaN == NaN as False
assert not Comparator.is_equal(a, b)


# ---- pandas tests ----

@pytest.mark.skipif(pd is None, reason='pandas not available')
class TestComparatorPandas:

def test_series_equal(self):
a = pd.Series([1, 2, 3])
b = pd.Series([1, 2, 3])
assert Comparator.is_equal(a, b)

def test_series_not_equal(self):
a = pd.Series([1, 2, 3])
b = pd.Series([1, 2, 4])
assert not Comparator.is_equal(a, b)

def test_dataframe_equal(self):
a = pd.DataFrame({'x': [1, 2], 'y': [3, 4]})
b = pd.DataFrame({'x': [1, 2], 'y': [3, 4]})
assert Comparator.is_equal(a, b)

def test_dataframe_not_equal(self):
a = pd.DataFrame({'x': [1, 2]})
b = pd.DataFrame({'x': [1, 3]})
assert not Comparator.is_equal(a, b)

def test_dataframe_different_shape(self):
a = pd.DataFrame({'x': [1, 2]})
b = pd.DataFrame({'x': [1, 2], 'y': [3, 4]})
assert not Comparator.is_equal(a, b)

def test_dataframe_with_nan(self):
a = pd.DataFrame({'x': [1.0, float('nan')]})
b = pd.DataFrame({'x': [1.0, float('nan')]})
# pd.DataFrame.equals treats NaN as equal
assert Comparator.is_equal(a, b)

def test_series_large_skips(self):
"""Series larger than array_max_size should return False."""
old = Comparator.array_max_size
try:
Comparator.array_max_size = 5
a = pd.Series(range(10))
b = pd.Series(range(10))
assert not Comparator.is_equal(a, b)
finally:
Comparator.array_max_size = old
Comment thread
philippjfr marked this conversation as resolved.

def test_dataframe_large_skips(self):
"""DataFrames larger than array_max_size should return False."""
old = Comparator.array_max_size
try:
Comparator.array_max_size = 5
a = pd.DataFrame({'x': range(10)})
b = pd.DataFrame({'x': range(10)})
assert not Comparator.is_equal(a, b)
finally:
Comparator.array_max_size = old