From e92cd1312dc63a9a3356b3a19cc803aa087367ab Mon Sep 17 00:00:00 2001 From: jtleider Date: Sat, 4 Aug 2018 15:28:18 -0500 Subject: [PATCH 1/4] Support multilevel groupby --- tableone.py | 79 +++++++++++++++++++++++++++--------------------- test_tableone.py | 75 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 34 deletions(-) diff --git a/tableone.py b/tableone.py index 9b75a03..aa0a252 100644 --- a/tableone.py +++ b/tableone.py @@ -33,8 +33,8 @@ class TableOne(object): List of columns in the dataset to be included in the final table. categorical : list, optional List of columns that contain categorical variables. - groupby : str, optional - Optional column for stratifying the final table (default: None). + groupby : list, optional + Optional columns for stratifying the final table (default: None). nonnormal : list, optional List of columns that contain non-normal variables (default: None). pval : bool, optional @@ -83,9 +83,9 @@ def __init__(self, data, columns=None, categorical=None, groupby=None, # check input arguments if not groupby: - groupby = '' - elif groupby and type(groupby) == list: - groupby = groupby[0] + groupby = [] + elif groupby and isinstance(groupby, str): + groupby = [groupby] if not nonnormal: nonnormal=[] @@ -115,7 +115,7 @@ def __init__(self, data, columns=None, categorical=None, groupby=None, self._columns = list(columns) self._isnull = isnull - self._continuous = [c for c in columns if c not in categorical + [groupby]] + self._continuous = [c for c in columns if c not in categorical + groupby] self._categorical = categorical self._nonnormal = nonnormal self._pval = pval @@ -131,13 +131,13 @@ def __init__(self, data, columns=None, categorical=None, groupby=None, # output column names that cannot be contained in a groupby self._reserved_columns = ['isnull', 'pval', 'ptest', 'pval (adjusted)'] if self._groupby: - self._groupbylvls = sorted(data.groupby(groupby).groups.keys()) + self._groups = 
data.groupby(groupby).groups # check that the group levels do not include reserved words - for level in self._groupbylvls: + for level in data.groupby(self._groupby[0]).groups: if level in self._reserved_columns: raise InputError('Group level contained "{}", a reserved keyword for tableone.'.format(level)) else: - self._groupbylvls = ['overall'] + self._groups = {'overall': data.index} # forgive me jraffa if self._pval: @@ -380,13 +380,15 @@ def _create_cont_describe(self,data): if self._groupby: # add the groupby column back - cont_data = cont_data.merge(data[[self._groupby]], + cont_data = cont_data.merge(data[self._groupby], left_index=True, right_index=True) # group and aggregate data df_cont = pd.pivot_table(cont_data, - columns=[self._groupby], + columns=self._groupby, aggfunc=aggfuncs) + if len(self._groupby) > 1: + df_cont = df_cont.unstack([i+1 for i in range(len(self._groupby))]) else: # if no groupby, just add single group column df_cont = cont_data.apply(aggfuncs).T @@ -419,11 +421,8 @@ def _create_cat_describe(self,data): """ group_dict = {} - for g in self._groupbylvls: - if self._groupby: - d_slice = data.loc[data[self._groupby] == g, self._categorical] - else: - d_slice = data[self._categorical].copy() + for g, gdata in self._groups.items(): + d_slice = data.loc[gdata, self._categorical].copy() # create a dataframe with freq, proportion df = d_slice.copy() @@ -456,9 +455,9 @@ def _create_cat_describe(self,data): group_dict[g] = df df_cat = pd.concat(group_dict,axis=1) - # ensure the groups are the 2nd level of the column index + # ensure the groups are the final levels of the column index if df_cat.columns.nlevels>1: - df_cat = df_cat.swaplevel(0, 1, axis=1).sort_index(axis=1,level=0) + df_cat = df_cat.swaplevel(0, -1, axis=1).sort_index(axis=1,level=0) return df_cat @@ -495,8 +494,8 @@ def _create_significance_table(self,data): if is_continuous: catlevels = None grouped_data = [] - for s in self._groupbylvls: - lvl_data = 
data.loc[data[self._groupby]==s, v] + for g, gdata in self._groups.items(): + lvl_data = data.loc[gdata, v] # coerce to numeric and drop non-numeric data lvl_data = lvl_data.apply(pd.to_numeric, errors='coerce').dropna() # append to overall group data @@ -505,7 +504,7 @@ def _create_significance_table(self,data): # if categorical, create contingency table elif is_categorical: catlevels = sorted(data[v].astype('category').cat.categories) - grouped_data = pd.crosstab(data[self._groupby].rename('_groupby_var_'),data[v]) + grouped_data = pd.crosstab([data[g] for g in self._groupby],data[v]) min_observed = grouped_data.sum(axis=1).min() # minimum number of observations across all levels @@ -516,6 +515,7 @@ def _create_significance_table(self,data): grouped_data,is_continuous,is_categorical, is_normal,min_observed,catlevels) + if len(self._groupby) > 1: df.columns = pd.MultiIndex.from_product([df.columns if i == 0 else [''] for i in range(len(self._groupby))]) return df def _p_test(self,v,grouped_data,is_continuous,is_categorical, @@ -600,11 +600,14 @@ def _create_cont_table(self,data): table.columns = table.columns.droplevel(level=0) # add a column of null counts as 1-count() from previous function + # isnull needs to be its own column nulltable = data[self._continuous].isnull().sum().to_frame(name='isnull') + if len(self._groupby) > 1: nulltable.columns = pd.MultiIndex.from_product([['isnull'] if i == 0 else [''] for i in range(len(self._groupby))]) try: table = table.join(nulltable) except TypeError: # if columns form a CategoricalIndex, need to convert to string first - table.columns = table.columns.astype(str) + if len(self._groupby) > 1: table.columns = pd.MultiIndex.from_tuples([tuple(str(col_value) for col_value in col) for col in table.columns]) + else: table.columns = table.columns.astype(str) table = table.join(nulltable) # add an empty level column, for joining with cat table @@ -632,10 +635,12 @@ def _create_cat_table(self,data): # add the total count of 
null values across all levels isnull = data[self._categorical].isnull().sum().to_frame(name='isnull') isnull.index.rename('variable', inplace=True) + if len(self._groupby) > 1: isnull.columns = pd.MultiIndex.from_product([['isnull'] if i == 0 else [''] for i in range(len(self._groupby))]) try: table = table.join(isnull) except TypeError: # if columns form a CategoricalIndex, need to convert to string first - table.columns = table.columns.astype(str) + if len(self._groupby) > 1: table.columns = pd.MultiIndex.from_tuples([tuple(str(col_value) for col_value in col) for col in table.columns]) + else: table.columns = table.columns.astype(str) table = table.join(isnull) # add pval column @@ -699,13 +704,15 @@ def _create_tableone(self,data): n_row = pd.DataFrame(columns = ['variable','level','isnull']) n_row.set_index(['variable','level'], inplace=True) n_row.loc['n', ''] = None + if len(self._groupby) > 1: n_row.columns = pd.MultiIndex.from_tuples( + [tuple('isnull' if i == 0 else '' for i in range(len(self._groupby))), tuple('' for i in range(len(self._groupby)))], names=table.columns.names) table = pd.concat([n_row,table],sort=False) - if self._groupbylvls == ['overall']: + if not self._groupby: table.loc['n','overall'] = len(data.index) else: - for g in self._groupbylvls: - ct = data[self._groupby][data[self._groupby]==g].count() + for g, gdata in self._groups.items(): + ct = len(gdata) table.loc['n',g] = ct # only display data in first level row @@ -716,28 +723,32 @@ def _create_tableone(self,data): if col in table.columns.values: dupe_columns.append(col) + if len(self._groupby) > 1: dupe_columns = [tuple(c if i == 0 else '' for i in range(len(self._groupby))) for c in dupe_columns] table[dupe_columns] = table[dupe_columns].mask(dupe_mask).fillna('') # remove empty column added above - table.drop([''], axis=1, inplace=True) + if len(self._groupby) > 1: table.drop(tuple('' for i in range(len(self._groupby))), axis=1, inplace=True) + else: table.drop('', axis=1, 
inplace=True) # remove isnull column if not needed if not self._isnull: - table.drop('isnull',axis=1,inplace=True) + if len(self._groupby) > 1: table.drop(tuple('isnull' if i == 0 else '' for i in range(len(self._groupby))),axis=1,inplace=True) + else: table.drop('isnull',axis=1,inplace=True) # replace nans with empty strings table.fillna('',inplace=True) # add column index - if not self._groupbylvls == ['overall']: + if self._groupby: # rename groupby variable if requested - c = self._groupby - if self._alt_labels: - if self._groupby in self._alt_labels: - c = self._alt_labels[self._groupby] + if self._alt_labels: c = ', '.join([self._alt_labels.get(g, g) for g in self._groupby]) + else: c = ', '.join(self._groupby) c = 'Grouped by {}'.format(c) - table.columns = pd.MultiIndex.from_product([[c], table.columns]) + if len(self._groupby) > 1: + table.columns = pd.MultiIndex.from_tuples(tuple([c])+table.columns[i] for i in range(len(table.columns))) + else: + table.columns = pd.MultiIndex.from_product([[c], table.columns]) # display alternative labels if assigned table.rename(index=self._create_row_labels(), inplace=True, level=0) diff --git a/test_tableone.py b/test_tableone.py index 5041d3c..0330ad5 100644 --- a/test_tableone.py +++ b/test_tableone.py @@ -525,3 +525,78 @@ def test_check_null_counts_are_correct_pn(self): # check each null count is correct col = isnull.index[i][0] assert self.data_pn[col].isnull().sum() == v + + @with_setup(setup, teardown) + def test_multilevel_groupby(self): + """ + Test multilevel groupby produces expected results + """ + columns = ['Age', 'Height', 'Weight', 'ICU'] + categorical = ['ICU'] + + table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent']) + assert table.tableone.columns[0][0] == 'Grouped by death, MechVent' + table.tableone.columns = table.tableone.columns.droplevel(0) + assert len(table.tableone.columns) == 5 + for i, correct_col in enumerate([('isnull', ''), (0, 0), (0, 
1), (1, 0), (1, 1)]): + assert table.tableone.columns[i] == correct_col + assert len(table.tableone.index) == 8 + rows = [('n', ''), ('Age', ''), ('Height', ''), ('Weight', ''), ('ICU', 'CCU'), ('ICU', 'CSRU'), ('ICU', 'MICU'), ('ICU', 'SICU')] + for i, correct_row in enumerate(rows): + assert table.tableone.index[i] == correct_row + correct_value = { + ('n', ''): ['', 468, 396, 72, 64], + ('Age', ''): [0, '65.29 (17.94)', '62.47 (16.65)', '71.06 (13.90)', '72.42 (14.21)'], + ('Height', ''): [475, '171.55 (31.78)', '169.24 (11.06)', '167.36 (11.32)', '169.86 (11.34)'], + ('Weight', ''): [302, '81.03 (22.28)', '85.02 (24.67)', '83.89 (28.35)', '80.44 (21.66)'], + ('ICU', 'CCU'): [0, '110 (23.5)', '11 (15.28)', '27 (6.82)', '14 (21.88)'], + ('ICU', 'CSRU'): ['', '50 (10.68)', '3 (4.17)', '144 (36.36)', '5 (7.81)'], + ('ICU', 'MICU'): ['', '205 (43.8)', '47 (65.28)', '113 (28.54)', '15 (23.44)'], + ('ICU', 'SICU'): ['', '103 (22.01)', '11 (15.28)', '112 (28.28)', '30 (46.88)'] + } + for row in rows: + assert list(table.tableone.loc[row]) == correct_value[row] + + @with_setup(setup, teardown) + def test_multilevel_groupby_pval(self): + """ + Test multilevel groupby works when p-values are requested + """ + columns = ['Age', 'Height', 'Weight', 'ICU'] + categorical = ['ICU'] + + table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], pval=True) + table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], pval=True, pval_adjust='bonferroni') + table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], pval=True, nonnormal=['Age']) + assert table.tableone.loc['Weight', ('Grouped by death, MechVent', 'pval', '')][0] == '0.187' + + @with_setup(setup, teardown) + def test_multilevel_groupby_noisnull(self): + """ + Test multilevel groupby runs without error when isnull option is False + """ + columns = ['Age', 'Height', 'Weight', 'ICU'] + 
categorical = ['ICU'] + + table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], isnull=False) + + @with_setup(setup, teardown) + def test_multilevel_groupby_sort(self): + """ + Test multilevel groupby runs without error when sort option is True + """ + columns = ['Age', 'Height', 'Weight', 'ICU'] + categorical = ['ICU'] + + table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], sort=True) + + @with_setup(setup, teardown) + def test_multilevel_groupby_limit(self): + """ + Test multilevel groupby runs correctly when limit option is set + """ + columns = ['Age', 'Height', 'Weight', 'ICU'] + categorical = ['ICU'] + + table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], limit=2) + assert list(table.tableone.loc['ICU'].index) == ['MICU', 'SICU'] From 9b71fb5f0cc99886b418baf11982fe7a7e153535 Mon Sep 17 00:00:00 2001 From: jtleider Date: Sat, 4 Aug 2018 16:10:54 -0500 Subject: [PATCH 2/4] Add test that groupby with categorical variable runs without error --- test_tableone.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test_tableone.py b/test_tableone.py index 0330ad5..ce1aef1 100644 --- a/test_tableone.py +++ b/test_tableone.py @@ -600,3 +600,17 @@ def test_multilevel_groupby_limit(self): table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], limit=2) assert list(table.tableone.loc['ICU'].index) == ['MICU', 'SICU'] + + @with_setup(setup, teardown) + def test_groupby_categorical(self): + """ + Test groupby runs without error with categorical groupby variable + """ + columns = ['Age', 'Height', 'Weight', 'ICU'] + categorical = ['ICU'] + + pn = self.data_pn.copy() + pn['death'] = pn['death'].astype('category') + table = TableOne(pn, columns=columns, categorical=categorical, groupby=['death']) + table = TableOne(pn, columns=columns, categorical=categorical, 
groupby=['death', 'MechVent']) + From c4ef642679ecdadac4bc81d4e1b181598d2e78f7 Mon Sep 17 00:00:00 2001 From: jtleider Date: Sat, 4 Aug 2018 16:48:15 -0500 Subject: [PATCH 3/4] Fix issue where descriptives were misaligned with categorical groupby variable --- tableone.py | 16 ++++------------ test_tableone.py | 4 +++- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/tableone.py b/tableone.py index aa0a252..2a7a198 100644 --- a/tableone.py +++ b/tableone.py @@ -131,6 +131,8 @@ def __init__(self, data, columns=None, categorical=None, groupby=None, # output column names that cannot be contained in a groupby self._reserved_columns = ['isnull', 'pval', 'ptest', 'pval (adjusted)'] if self._groupby: + for groupbyvar in groupby: + data[groupbyvar] = data[groupbyvar].astype(str) # Treat groupby variables as string to avoid problems with categorical groupby self._groups = data.groupby(groupby).groups # check that the group levels do not include reserved words for level in data.groupby(self._groupby[0]).groups: @@ -603,12 +605,7 @@ def _create_cont_table(self,data): # isnull needs to be its own column nulltable = data[self._continuous].isnull().sum().to_frame(name='isnull') if len(self._groupby) > 1: nulltable.columns = pd.MultiIndex.from_product([['isnull'] if i == 0 else [''] for i in range(len(self._groupby))]) - try: - table = table.join(nulltable) - except TypeError: # if columns form a CategoricalIndex, need to convert to string first - if len(self._groupby) > 1: table.columns = pd.MultiIndex.from_tuples([tuple(str(col_value) for col_value in col) for col in table.columns]) - else: table.columns = table.columns.astype(str) - table = table.join(nulltable) + table = table.join(nulltable) # add an empty level column, for joining with cat table table['level'] = '' @@ -636,12 +633,7 @@ def _create_cat_table(self,data): isnull = data[self._categorical].isnull().sum().to_frame(name='isnull') isnull.index.rename('variable', inplace=True) if len(self._groupby) > 1: 
isnull.columns = pd.MultiIndex.from_product([['isnull'] if i == 0 else [''] for i in range(len(self._groupby))]) - try: - table = table.join(isnull) - except TypeError: # if columns form a CategoricalIndex, need to convert to string first - if len(self._groupby) > 1: table.columns = pd.MultiIndex.from_tuples([tuple(str(col_value) for col_value in col) for col in table.columns]) - else: table.columns = table.columns.astype(str) - table = table.join(isnull) + table = table.join(isnull) # add pval column if self._pval and self._pval_adjust: diff --git a/test_tableone.py b/test_tableone.py index ce1aef1..f40d011 100644 --- a/test_tableone.py +++ b/test_tableone.py @@ -538,7 +538,7 @@ def test_multilevel_groupby(self): assert table.tableone.columns[0][0] == 'Grouped by death, MechVent' table.tableone.columns = table.tableone.columns.droplevel(0) assert len(table.tableone.columns) == 5 - for i, correct_col in enumerate([('isnull', ''), (0, 0), (0, 1), (1, 0), (1, 1)]): + for i, correct_col in enumerate([('isnull', ''), ('0', '0'), ('0', '1'), ('1', '0'), ('1', '1')]): assert table.tableone.columns[i] == correct_col assert len(table.tableone.index) == 8 rows = [('n', ''), ('Age', ''), ('Height', ''), ('Weight', ''), ('ICU', 'CCU'), ('ICU', 'CSRU'), ('ICU', 'MICU'), ('ICU', 'SICU')] @@ -612,5 +612,7 @@ def test_groupby_categorical(self): pn = self.data_pn.copy() pn['death'] = pn['death'].astype('category') table = TableOne(pn, columns=columns, categorical=categorical, groupby=['death']) + assert len(table.tableone.columns) == 3 table = TableOne(pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent']) + assert len(table.tableone.columns) == 5 From 6de33b5342ffd5d529bfaf7a3de3a38afe1e3ac4 Mon Sep 17 00:00:00 2001 From: jtleider Date: Sat, 4 Aug 2018 18:18:22 -0500 Subject: [PATCH 4/4] Fix bug that was misaligning table with multiple groupby levels --- tableone.py | 2 +- test_tableone.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff 
--git a/tableone.py b/tableone.py index 2a7a198..060688f 100644 --- a/tableone.py +++ b/tableone.py @@ -459,7 +459,7 @@ def _create_cat_describe(self,data): df_cat = pd.concat(group_dict,axis=1) # ensure the groups are the final levels of the column index if df_cat.columns.nlevels>1: - df_cat = df_cat.swaplevel(0, -1, axis=1).sort_index(axis=1,level=0) + df_cat = df_cat.reorder_levels([df_cat.columns.nlevels-1]+[i for i in range(df_cat.columns.nlevels-1)], axis=1).sort_index(axis=1,level=0) return df_cat diff --git a/test_tableone.py b/test_tableone.py index f40d011..fc790a4 100644 --- a/test_tableone.py +++ b/test_tableone.py @@ -549,10 +549,10 @@ def test_multilevel_groupby(self): ('Age', ''): [0, '65.29 (17.94)', '62.47 (16.65)', '71.06 (13.90)', '72.42 (14.21)'], ('Height', ''): [475, '171.55 (31.78)', '169.24 (11.06)', '167.36 (11.32)', '169.86 (11.34)'], ('Weight', ''): [302, '81.03 (22.28)', '85.02 (24.67)', '83.89 (28.35)', '80.44 (21.66)'], - ('ICU', 'CCU'): [0, '110 (23.5)', '11 (15.28)', '27 (6.82)', '14 (21.88)'], - ('ICU', 'CSRU'): ['', '50 (10.68)', '3 (4.17)', '144 (36.36)', '5 (7.81)'], - ('ICU', 'MICU'): ['', '205 (43.8)', '47 (65.28)', '113 (28.54)', '15 (23.44)'], - ('ICU', 'SICU'): ['', '103 (22.01)', '11 (15.28)', '112 (28.28)', '30 (46.88)'] + ('ICU', 'CCU'): [0, '110 (23.5)', '27 (6.82)', '11 (15.28)', '14 (21.88)'], + ('ICU', 'CSRU'): ['', '50 (10.68)', '144 (36.36)', '3 (4.17)', '5 (7.81)'], + ('ICU', 'MICU'): ['', '205 (43.8)', '113 (28.54)', '47 (65.28)', '15 (23.44)'], + ('ICU', 'SICU'): ['', '103 (22.01)', '112 (28.28)', '11 (15.28)', '30 (46.88)'] } for row in rows: assert list(table.tableone.loc[row]) == correct_value[row]