diff --git a/CHANGES.rst b/CHANGES.rst index 40c6d1230..def453bff 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -54,6 +54,9 @@ Changes :pr:`2096` by :user:`Ayesha Siddiqua `. - The :class:`TableReport` can now be exported in markdown format with ``.markdown``. :pr:`2048` by :user:`Riccardo Cappuzzo `. +- The :class:`DropUninformative` transformer gained a ``threshold`` parameter: when ``drop_if_constant=True``, numeric columns without nulls + whose variance is at most ``threshold`` are now dropped, similarly to scikit-learn's ``VarianceThreshold`` transformer. + :pr:`2155` by :user:`Janne de Melo Santana `, :user:`Xixi Khamsane`, :user:`Rim El Khader`. Bugfixes -------- diff --git a/skrub/_drop_uninformative.py b/skrub/_drop_uninformative.py index e20e07e93..21119831a 100644 --- a/skrub/_drop_uninformative.py +++ b/skrub/_drop_uninformative.py @@ -85,10 +85,12 @@ def __init__( drop_if_constant=False, drop_if_unique=False, drop_null_fraction=1.0, + threshold=0.0, ): self.drop_if_constant = drop_if_constant self.drop_if_unique = drop_if_unique self.drop_null_fraction = drop_null_fraction + self.threshold = threshold def _check_params(self): if not isinstance(self.drop_if_constant, bool): @@ -127,8 +129,19 @@ def _drop_if_too_many_nulls(self, column): def _drop_if_constant(self, column): if self.drop_if_constant: - if (sbd.n_unique(column) == 1) and (self._null_count == 0): + if sbd.is_numeric(column) == 1 and ( + self._null_count == 0 + ): # if numeric or boolean + if ( + sbd.std(column) ** 2 <= self.threshold + ): # check if passes the threshold + return True + else: + return False + elif (sbd.n_unique(column) == 1) and (self._null_count == 0): + # use the original logic to deal with the other cases return True + return False def _drop_if_unique(self, column): diff --git a/skrub/tests/test_drop_uninformative.py b/skrub/tests/test_drop_uninformative.py index e3f04cacd..e2f0b062b 100644 --- a/skrub/tests/test_drop_uninformative.py +++ b/skrub/tests/test_drop_uninformative.py @@ -130,6 +130,7 @@ def drop_if_constant_table(df_module): "const", None, ], +
"low_variance": [0.01, 0.02, 0.05], } ) @@ -141,6 +142,7 @@ def drop_if_constant_table(df_module): (dict(drop_if_constant=True), "constant_float", []), (dict(drop_if_constant=True), "constant_float_with_nulls", [2.5, 2.5, np.nan]), (dict(drop_if_constant=True), "constant_str", []), + (dict(drop_if_constant=True, threshold=0.5), "low_variance", []), ( dict(drop_if_constant=True), "constant_str_with_nulls",