diff --git a/tableone.py b/tableone.py index 9b75a03..060688f 100644 --- a/tableone.py +++ b/tableone.py @@ -33,8 +33,8 @@ class TableOne(object): List of columns in the dataset to be included in the final table. categorical : list, optional List of columns that contain categorical variables. - groupby : str, optional - Optional column for stratifying the final table (default: None). + groupby : list, optional + Optional columns for stratifying the final table (default: None). nonnormal : list, optional List of columns that contain non-normal variables (default: None). pval : bool, optional @@ -83,9 +83,9 @@ def __init__(self, data, columns=None, categorical=None, groupby=None, # check input arguments if not groupby: - groupby = '' - elif groupby and type(groupby) == list: - groupby = groupby[0] + groupby = [] + elif groupby and isinstance(groupby, str): + groupby = [groupby] if not nonnormal: nonnormal=[] @@ -115,7 +115,7 @@ def __init__(self, data, columns=None, categorical=None, groupby=None, self._columns = list(columns) self._isnull = isnull - self._continuous = [c for c in columns if c not in categorical + [groupby]] + self._continuous = [c for c in columns if c not in categorical + groupby] self._categorical = categorical self._nonnormal = nonnormal self._pval = pval @@ -131,13 +131,15 @@ def __init__(self, data, columns=None, categorical=None, groupby=None, # output column names that cannot be contained in a groupby self._reserved_columns = ['isnull', 'pval', 'ptest', 'pval (adjusted)'] if self._groupby: - self._groupbylvls = sorted(data.groupby(groupby).groups.keys()) + for groupbyvar in groupby: + data[groupbyvar] = data[groupbyvar].astype(str) # Treat groupby variables as string to avoid problems with categorical groupby + self._groups = data.groupby(groupby).groups # check that the group levels do not include reserved words - for level in self._groupbylvls: + for level in data.groupby(self._groupby[0]).groups: if level in self._reserved_columns: raise 
InputError('Group level contained "{}", a reserved keyword for tableone.'.format(level)) else: - self._groupbylvls = ['overall'] + self._groups = {'overall': data.index} # forgive me jraffa if self._pval: @@ -380,13 +382,15 @@ def _create_cont_describe(self,data): if self._groupby: # add the groupby column back - cont_data = cont_data.merge(data[[self._groupby]], + cont_data = cont_data.merge(data[self._groupby], left_index=True, right_index=True) # group and aggregate data df_cont = pd.pivot_table(cont_data, - columns=[self._groupby], + columns=self._groupby, aggfunc=aggfuncs) + if len(self._groupby) > 1: + df_cont = df_cont.unstack([i+1 for i in range(len(self._groupby))]) else: # if no groupby, just add single group column df_cont = cont_data.apply(aggfuncs).T @@ -419,11 +423,8 @@ def _create_cat_describe(self,data): """ group_dict = {} - for g in self._groupbylvls: - if self._groupby: - d_slice = data.loc[data[self._groupby] == g, self._categorical] - else: - d_slice = data[self._categorical].copy() + for g, gdata in self._groups.items(): + d_slice = data.loc[gdata, self._categorical].copy() # create a dataframe with freq, proportion df = d_slice.copy() @@ -456,9 +457,9 @@ def _create_cat_describe(self,data): group_dict[g] = df df_cat = pd.concat(group_dict,axis=1) - # ensure the groups are the 2nd level of the column index + # ensure the groups are the final levels of the column index if df_cat.columns.nlevels>1: - df_cat = df_cat.swaplevel(0, 1, axis=1).sort_index(axis=1,level=0) + df_cat = df_cat.reorder_levels([df_cat.columns.nlevels-1]+[i for i in range(df_cat.columns.nlevels-1)], axis=1).sort_index(axis=1,level=0) return df_cat @@ -495,8 +496,8 @@ def _create_significance_table(self,data): if is_continuous: catlevels = None grouped_data = [] - for s in self._groupbylvls: - lvl_data = data.loc[data[self._groupby]==s, v] + for g, gdata in self._groups.items(): + lvl_data = data.loc[gdata, v] # coerce to numeric and drop non-numeric data lvl_data = 
lvl_data.apply(pd.to_numeric, errors='coerce').dropna() # append to overall group data @@ -505,7 +506,7 @@ def _create_significance_table(self,data): # if categorical, create contingency table elif is_categorical: catlevels = sorted(data[v].astype('category').cat.categories) - grouped_data = pd.crosstab(data[self._groupby].rename('_groupby_var_'),data[v]) + grouped_data = pd.crosstab([data[g] for g in self._groupby],data[v]) min_observed = grouped_data.sum(axis=1).min() # minimum number of observations across all levels @@ -516,6 +517,7 @@ def _create_significance_table(self,data): grouped_data,is_continuous,is_categorical, is_normal,min_observed,catlevels) + if len(self._groupby) > 1: df.columns = pd.MultiIndex.from_product([df.columns if i == 0 else [''] for i in range(len(self._groupby))]) return df def _p_test(self,v,grouped_data,is_continuous,is_categorical, @@ -600,12 +602,10 @@ def _create_cont_table(self,data): table.columns = table.columns.droplevel(level=0) # add a column of null counts as 1-count() from previous function + # isnull needs to be its own column nulltable = data[self._continuous].isnull().sum().to_frame(name='isnull') - try: - table = table.join(nulltable) - except TypeError: # if columns form a CategoricalIndex, need to convert to string first - table.columns = table.columns.astype(str) - table = table.join(nulltable) + if len(self._groupby) > 1: nulltable.columns = pd.MultiIndex.from_product([['isnull'] if i == 0 else [''] for i in range(len(self._groupby))]) + table = table.join(nulltable) # add an empty level column, for joining with cat table table['level'] = '' @@ -632,11 +632,8 @@ def _create_cat_table(self,data): # add the total count of null values across all levels isnull = data[self._categorical].isnull().sum().to_frame(name='isnull') isnull.index.rename('variable', inplace=True) - try: - table = table.join(isnull) - except TypeError: # if columns form a CategoricalIndex, need to convert to string first - table.columns = 
table.columns.astype(str) - table = table.join(isnull) + if len(self._groupby) > 1: isnull.columns = pd.MultiIndex.from_product([['isnull'] if i == 0 else [''] for i in range(len(self._groupby))]) + table = table.join(isnull) # add pval column if self._pval and self._pval_adjust: @@ -699,13 +696,15 @@ def _create_tableone(self,data): n_row = pd.DataFrame(columns = ['variable','level','isnull']) n_row.set_index(['variable','level'], inplace=True) n_row.loc['n', ''] = None + if len(self._groupby) > 1: n_row.columns = pd.MultiIndex.from_tuples( + [tuple('isnull' if i == 0 else '' for i in range(len(self._groupby))), tuple('' for i in range(len(self._groupby)))], names=table.columns.names) table = pd.concat([n_row,table],sort=False) - if self._groupbylvls == ['overall']: + if not self._groupby: table.loc['n','overall'] = len(data.index) else: - for g in self._groupbylvls: - ct = data[self._groupby][data[self._groupby]==g].count() + for g, gdata in self._groups.items(): + ct = len(gdata) table.loc['n',g] = ct # only display data in first level row @@ -716,28 +715,32 @@ def _create_tableone(self,data): if col in table.columns.values: dupe_columns.append(col) + if len(self._groupby) > 1: dupe_columns = [tuple(c if i == 0 else '' for i in range(len(self._groupby))) for c in dupe_columns] table[dupe_columns] = table[dupe_columns].mask(dupe_mask).fillna('') # remove empty column added above - table.drop([''], axis=1, inplace=True) + if len(self._groupby) > 1: table.drop(tuple('' for i in range(len(self._groupby))), axis=1, inplace=True) + else: table.drop('', axis=1, inplace=True) # remove isnull column if not needed if not self._isnull: - table.drop('isnull',axis=1,inplace=True) + if len(self._groupby) > 1: table.drop(tuple('isnull' if i == 0 else '' for i in range(len(self._groupby))),axis=1,inplace=True) + else: table.drop('isnull',axis=1,inplace=True) # replace nans with empty strings table.fillna('',inplace=True) # add column index - if not self._groupbylvls == 
['overall']: + if self._groupby: # rename groupby variable if requested - c = self._groupby - if self._alt_labels: - if self._groupby in self._alt_labels: - c = self._alt_labels[self._groupby] + if self._alt_labels: c = ', '.join([self._alt_labels.get(g, g) for g in self._groupby]) + else: c = ', '.join(self._groupby) c = 'Grouped by {}'.format(c) - table.columns = pd.MultiIndex.from_product([[c], table.columns]) + if len(self._groupby) > 1: + table.columns = pd.MultiIndex.from_tuples(tuple([c])+table.columns[i] for i in range(len(table.columns))) + else: + table.columns = pd.MultiIndex.from_product([[c], table.columns]) # display alternative labels if assigned table.rename(index=self._create_row_labels(), inplace=True, level=0) diff --git a/test_tableone.py b/test_tableone.py index 5041d3c..fc790a4 100644 --- a/test_tableone.py +++ b/test_tableone.py @@ -525,3 +525,94 @@ def test_check_null_counts_are_correct_pn(self): # check each null count is correct col = isnull.index[i][0] assert self.data_pn[col].isnull().sum() == v + + @with_setup(setup, teardown) + def test_multilevel_groupby(self): + """ + Test multilevel groupby produces expected results + """ + columns = ['Age', 'Height', 'Weight', 'ICU'] + categorical = ['ICU'] + + table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent']) + assert table.tableone.columns[0][0] == 'Grouped by death, MechVent' + table.tableone.columns = table.tableone.columns.droplevel(0) + assert len(table.tableone.columns) == 5 + for i, correct_col in enumerate([('isnull', ''), ('0', '0'), ('0', '1'), ('1', '0'), ('1', '1')]): + assert table.tableone.columns[i] == correct_col + assert len(table.tableone.index) == 8 + rows = [('n', ''), ('Age', ''), ('Height', ''), ('Weight', ''), ('ICU', 'CCU'), ('ICU', 'CSRU'), ('ICU', 'MICU'), ('ICU', 'SICU')] + for i, correct_row in enumerate(rows): + assert table.tableone.index[i] == correct_row + correct_value = { + ('n', ''): ['', 468, 396, 72, 64], + 
('Age', ''): [0, '65.29 (17.94)', '62.47 (16.65)', '71.06 (13.90)', '72.42 (14.21)'], + ('Height', ''): [475, '171.55 (31.78)', '169.24 (11.06)', '167.36 (11.32)', '169.86 (11.34)'], + ('Weight', ''): [302, '81.03 (22.28)', '85.02 (24.67)', '83.89 (28.35)', '80.44 (21.66)'], + ('ICU', 'CCU'): [0, '110 (23.5)', '27 (6.82)', '11 (15.28)', '14 (21.88)'], + ('ICU', 'CSRU'): ['', '50 (10.68)', '144 (36.36)', '3 (4.17)', '5 (7.81)'], + ('ICU', 'MICU'): ['', '205 (43.8)', '113 (28.54)', '47 (65.28)', '15 (23.44)'], + ('ICU', 'SICU'): ['', '103 (22.01)', '112 (28.28)', '11 (15.28)', '30 (46.88)'] + } + for row in rows: + assert list(table.tableone.loc[row]) == correct_value[row] + + @with_setup(setup, teardown) + def test_multilevel_groupby_pval(self): + """ + Test multilevel groupby works when p-values are requested + """ + columns = ['Age', 'Height', 'Weight', 'ICU'] + categorical = ['ICU'] + + table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], pval=True) + table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], pval=True, pval_adjust='bonferroni') + table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], pval=True, nonnormal=['Age']) + assert table.tableone.loc['Weight', ('Grouped by death, MechVent', 'pval', '')][0] == '0.187' + + @with_setup(setup, teardown) + def test_multilevel_groupby_noisnull(self): + """ + Test multilevel groupby runs without error when isnull option is False + """ + columns = ['Age', 'Height', 'Weight', 'ICU'] + categorical = ['ICU'] + + table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], isnull=False) + + @with_setup(setup, teardown) + def test_multilevel_groupby_sort(self): + """ + Test multilevel groupby runs without error when sort option is True + """ + columns = ['Age', 'Height', 'Weight', 'ICU'] + categorical = ['ICU'] + + table = 
TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], sort=True) + + @with_setup(setup, teardown) + def test_multilevel_groupby_limit(self): + """ + Test multilevel groupby runs correctly when limit option is set + """ + columns = ['Age', 'Height', 'Weight', 'ICU'] + categorical = ['ICU'] + + table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], limit=2) + assert list(table.tableone.loc['ICU'].index) == ['MICU', 'SICU'] + + @with_setup(setup, teardown) + def test_groupby_categorical(self): + """ + Test groupby with a categorical groupby variable runs without error and yields the expected number of columns + """ + columns = ['Age', 'Height', 'Weight', 'ICU'] + categorical = ['ICU'] + + pn = self.data_pn.copy() + pn['death'] = pn['death'].astype('category') + table = TableOne(pn, columns=columns, categorical=categorical, groupby=['death']) + assert len(table.tableone.columns) == 3 + table = TableOne(pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent']) + assert len(table.tableone.columns) == 5 +