Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ New Features
some more attributes for inspection by scikit-learn: ``__sklearn_tags__``,
``classes_``, ``_estimator_type``. :pr:`1931` by :user:`Jérôme Dockès
<jeromedockes>`.
- :class:`TableReport` now has a `n_jobs` parameter enabling parallel
computing to be used in the processing of individual columns. :pr:`1949` by
:user:`Eloi Massoulié <emassoulie>`.

Changes
-------
Expand Down
33 changes: 17 additions & 16 deletions skrub/_reporting/_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import sys

from joblib import Parallel, delayed

from .. import _column_associations, _config
from .. import _dataframe as sbd
from . import _plotting, _sample_table, _utils
Expand All @@ -28,6 +30,7 @@ def summarize_dataframe(
max_top_slice_size=5,
max_bottom_slice_size=5,
verbose=1,
n_jobs=None,
):
"""Collect information about a dataframe, used to produce reports.

Expand Down Expand Up @@ -87,6 +90,7 @@ def summarize_dataframe(
max_bottom_slice_size=max_bottom_slice_size,
),
}
columns = []
if title is not None:
summary["title"] = title
if order_by is not None:
Expand All @@ -97,23 +101,20 @@ def summarize_dataframe(
else:
order_by_idx = sbd.column_names(df).index(order_by)
order_by_column = sbd.col_by_idx(df, order_by_idx)
for position in range(sbd.shape(df)[1]):
if verbose > 0:
print(
f"Processing column {position + 1: >3} / {n_columns}",
file=sys.stderr,
end="\r",
flush=True,
)
summary["columns"].append(
_summarize_column(
sbd.col_by_idx(df, position),
position,
dataframe_summary=summary,
with_plots=with_plots,
order_by_column=order_by_column,
)

columns = Parallel(n_jobs=n_jobs, verbose=verbose, backend="loky")(
delayed(_summarize_column)(
sbd.col_by_idx(df, position),
position,
dataframe_summary={"n_rows": summary["n_rows"]},
with_plots=with_plots,
order_by_column=order_by_column,
)
for position in range(sbd.shape(df)[1])
)

summary["columns"] = columns

if verbose > 0:
print(flush=True, file=sys.stderr)

Expand Down
3 changes: 3 additions & 0 deletions skrub/_reporting/_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def __init__(
max_plot_columns=None,
max_association_columns=None,
open_tab="table",
n_jobs=None,
):
if isinstance(dataframe, np.ndarray):
if dataframe.ndim == 1:
Expand Down Expand Up @@ -243,6 +244,7 @@ def __init__(
sbd.to_frame(dataframe) if sbd.is_column(dataframe) else dataframe
)
self.n_columns = sbd.shape(self.dataframe)[1]
self.n_jobs = n_jobs

def _set_minimal_mode(self):
"""Put the report in minimal mode.
Expand Down Expand Up @@ -289,6 +291,7 @@ def _summary(self):
with_plots=with_plots,
with_associations=with_associations,
title=self.title,
n_jobs=self.n_jobs,
**self._summary_kwargs,
)

Expand Down
29 changes: 29 additions & 0 deletions skrub/_reporting/tests/test_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import warnings
from pathlib import Path

import joblib
import numpy as np
import pytest
from sklearn.utils import Bunch
from sklearn.utils._testing import skip_if_no_parallel

from skrub import TableReport, ToDatetime
from skrub import _dataframe as sbd
Expand Down Expand Up @@ -389,6 +391,33 @@ def test_numpy_array_columns(input_array, expected_columns):
assert report._summary["n_columns"] == expected_columns


@skip_if_no_parallel
def test_parallelism(df_module):
df = df_module.make_dataframe(
dict(
a=[1, 2, 3, 4],
b=["one", "two", "three", "four"],
c=[11.1, 11.2, 11.3, 11.4],
)
)

report = TableReport(df, verbose=0)
columns = report._summary["columns"]

with joblib.parallel_backend("loky"):
for n_jobs in [None, 2, -1]:
parallel_report = TableReport(df, n_jobs=n_jobs, verbose=0)
parallel_columns = parallel_report._summary["columns"]

assert len(columns) == len(parallel_columns)
for i in range(len(columns)):
assert columns[i].keys() == parallel_columns[i].keys()
assert columns[i]["name"] == parallel_columns[i]["name"]
assert columns[i]["dtype"] == parallel_columns[i]["dtype"]

assert parallel_report.n_jobs == n_jobs


def _pyarrow_available():
try:
import pyarrow # noqa: F401
Expand Down
Loading