diff --git a/CHANGES.rst b/CHANGES.rst index db21d87ec..e4ee84e02 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -40,6 +40,10 @@ New Features Additionally, negative numbers indicated with parentheses can be converted to the regular numeric format (``(432)`` becomes ``-432``). :pr:`1772` by :user:`Gabriela Gómez Jiménez `. +- :meth:`TableReport.json` now includes histogram data for numeric and datetime + columns (the bin count and edges, and numbers of low and high outliers). Now + ``json()`` contains all the information shown in the report html rendering, + including the plots. :pr:`2164` by :user:`Jérôme Dockès `. Changes ------- diff --git a/skrub/_reporting/_data/templates/column-summary.html b/skrub/_reporting/_data/templates/column-summary.html index 7c658a092..e7482b5d0 100644 --- a/skrub/_reporting/_data/templates/column-summary.html +++ b/skrub/_reporting/_data/templates/column-summary.html @@ -4,10 +4,10 @@ data-name-repr="{{ column.name.__repr__() }}" data-column-name="{{ column.name }}" data-column-idx="{{ column.idx }}" - {% if column['n_low_outliers'] %} + {% if column['histogram_data'] and column['histogram_data']['n_low_outliers'] %} data-has-low-outliers {% endif %} - {% if column['n_high_outliers'] %} + {% if column['histogram_data'] and column['histogram_data']['n_high_outliers'] %} data-has-high-outliers {% endif %} data-manager="FilterableColumn {% if in_sample_tab %}SampleColumnSummary{% endif %}" diff --git a/skrub/_reporting/_plotting.py b/skrub/_reporting/_plotting.py index 131be0113..bd4c512e4 100644 --- a/skrub/_reporting/_plotting.py +++ b/skrub/_reporting/_plotting.py @@ -11,8 +11,8 @@ import numpy as np from matplotlib import pyplot as plt -from skrub import _dataframe as sbd - +from .. import _dataframe as sbd +from .. import _datetime_encoder from . import _utils __all__ = ["COLORS", "COLOR_0", "histogram", "line", "value_counts"] @@ -192,15 +192,36 @@ def _get_range(values, frac=0.2, factor=3.0): return low, high -def _robust_hist(values, ax, color): +def _robust_hist(col, ax=None, color=None): + col = sbd.drop_nulls(col) + if sbd.is_float(col): + # avoid any issues with pandas nullable dtypes + # (to_numpy can yield a numpy array with object dtype in old pandas + # version if there are inf or nan) + col = sbd.to_float32(col) + values = sbd.to_numpy(col) + if sbd.is_any_date(col): + # numpy histogram does not handle datetimes but matplotlib does, so we + # convert to the total number of seconds since epoch (a float) + np_histogram_values = sbd.to_numpy( + _datetime_encoder.DatetimeEncoder(resolution=None).fit_transform(col) + ).ravel() + else: + np_histogram_values = values low, high = _get_range(values) - inliers = values[(low <= values) & (values <= high)] + inlier_mask = (low <= values) & (values <= high) n_low_outliers = (values < low).sum() n_high_outliers = (high < values).sum() - n, bins, patches = ax.hist(inliers) + result = {"n_low_outliers": n_low_outliers, "n_high_outliers": n_high_outliers} + result["bin_counts"], result["bin_edges"] = np.histogram( + np_histogram_values[inlier_mask] + ) + if ax is None: + return result + n, bins, patches = ax.hist(values[inlier_mask]) n_out = n_low_outliers + n_high_outliers if not n_out: - return 0, 0 + return result width = bins[1] - bins[0] start, stop = bins[0], bins[-1] line_params = dict(color=_RED, linestyle="--", ymax=0.95) @@ -229,28 +250,25 @@ def _robust_hist(values, ax, color): color=_RED, ) ax.set_xlim(start, stop) - return n_low_outliers, n_high_outliers + return result + + +def histogram_data(col): + return _robust_hist(col, ax=None, color=None) @_plot def histogram(col, duration_unit=None, color=COLOR_0): """Histogram for a numeric column.""" - col = sbd.drop_nulls(col) - if sbd.is_float(col): - # avoid any issues with pandas nullable dtypes - # (to_numpy can yield a numpy array with object dtype in old pandas - # version if there are inf or nan) - col = sbd.to_float32(col) - values = sbd.to_numpy(col) fig, ax = plt.subplots() _despine(ax) - n_low_outliers, n_high_outliers = _robust_hist(values, ax, color=color) + histogram_data = _robust_hist(col, ax=ax, color=color) if duration_unit is not None: ax.set_xlabel(f"{duration_unit.capitalize()}s") if sbd.is_any_date(col): _rotate_ticklabels(ax) _adjust_fig_size(fig, ax, 2.0, 1.0) - return _serialize(fig), n_low_outliers, n_high_outliers + return _serialize(fig), histogram_data @_plot diff --git a/skrub/_reporting/_summarize.py b/skrub/_reporting/_summarize.py index b544986e2..e582e07fd 100644 --- a/skrub/_reporting/_summarize.py +++ b/skrub/_reporting/_summarize.py @@ -255,9 +255,12 @@ def _add_datetime_summary(summary, column, with_plots): if with_plots: ( summary["histogram_plot"], - summary["n_low_outliers"], - summary["n_high_outliers"], + summary["histogram_data"], ) = _plotting.histogram(column, color=_plotting.COLORS[0]) + else: + # besides the plots, the bin counts and edges are always stored and + # available in the json output. + summary["histogram_data"] = _plotting.histogram_data(column) def _add_numeric_summary( @@ -289,13 +292,12 @@ def _add_numeric_summary( summary["value_is_constant"] = False summary["quantiles"] = quantiles if not with_plots: + # besides the plots, the bin counts and edges are always stored and + # available in the json output. + summary["histogram_data"] = _plotting.histogram_data(column) return if order_by_column is None: - ( - summary["histogram_plot"], - summary["n_low_outliers"], - summary["n_high_outliers"], - ) = _plotting.histogram( + summary["histogram_plot"], summary["histogram_data"] = _plotting.histogram( column, duration_unit=duration_unit, color=_plotting.COLORS[0] ) else: diff --git a/skrub/_reporting/_utils.py b/skrub/_reporting/_utils.py index e6dd0b44d..6238f07d8 100644 --- a/skrub/_reporting/_utils.py +++ b/skrub/_reporting/_utils.py @@ -115,6 +115,8 @@ def default(self, value): return int(value) if isinstance(value, np.floating): return float(value) + if isinstance(value, np.ndarray): + return value.tolist() raise diff --git a/skrub/_reporting/tests/test_plotting.py b/skrub/_reporting/tests/test_plotting.py index c49eb98aa..601f96526 100644 --- a/skrub/_reporting/tests/test_plotting.py +++ b/skrub/_reporting/tests/test_plotting.py @@ -10,21 +10,21 @@ def test_histogram(): o = rng.uniform(-100, 100, size=10) data = pd.Series(np.concatenate([x, o])) - _, n_low, n_high = _plotting.histogram(data) - assert (n_low, n_high) == (5, 4) + _, hist = _plotting.histogram(data) + assert (hist["n_low_outliers"], hist["n_high_outliers"]) == (5, 4) data = pd.Series(np.concatenate([x, o - 1000])) - _, n_low, n_high = _plotting.histogram(data) - assert (n_low, n_high) == (10, 0) + _, hist = _plotting.histogram(data) + assert (hist["n_low_outliers"], hist["n_high_outliers"]) == (10, 0) data = pd.Series(np.concatenate([x, o + 1000])) - _, n_low, n_high = _plotting.histogram(data) - assert (n_low, n_high) == (0, 10) + _, hist = _plotting.histogram(data) + assert (hist["n_low_outliers"], hist["n_high_outliers"]) == (0, 10) data = pd.Series(x) - _, n_low, n_high = _plotting.histogram(data) - assert (n_low, n_high) == (0, 0) + _, hist = _plotting.histogram(data) + assert (hist["n_low_outliers"], hist["n_high_outliers"]) == (0, 0) data = pd.Series([0.0]) - _, n_low, n_high = _plotting.histogram(data) - assert (n_low, n_high) == (0, 0) + _, hist = _plotting.histogram(data) + assert (hist["n_low_outliers"], hist["n_high_outliers"]) == (0, 0) diff --git a/skrub/_reporting/tests/test_summarize.py b/skrub/_reporting/tests/test_summarize.py index 245465239..bde2c4df9 100644 --- a/skrub/_reporting/tests/test_summarize.py +++ b/skrub/_reporting/tests/test_summarize.py @@ -84,6 +84,9 @@ def test_summarize( 0.75: 33.6, 1.0: 78.3, } + if order_by is None: + assert len(summary["columns"][5]["histogram_data"]["bin_counts"]) == 10 + assert len(summary["columns"][5]["histogram_data"]["bin_edges"]) == 11 assert summary["columns"][7]["null_count"] == 9 assert summary["columns"][7]["nulls_level"] == "warning" assert summary["columns"][8]["null_count"] == 17 diff --git a/skrub/_reporting/tests/test_utils.py b/skrub/_reporting/tests/test_utils.py index c2f6b7a68..23dcf571e 100644 --- a/skrub/_reporting/tests/test_utils.py +++ b/skrub/_reporting/tests/test_utils.py @@ -104,7 +104,7 @@ def test_json_encoder(): d = {"a": x[0], "b": y[0]} assert json.dumps(d, cls=_utils.JSONEncoder) == '{"a": 1, "b": 1.0}' with pytest.raises(TypeError, match=".*JSON serializable"): - json.dumps({"a": x}, cls=_utils.JSONEncoder) + json.dumps({"a": np}, cls=_utils.JSONEncoder) def test_svg_to_img_src():