diff --git a/q2_longitudinal/_longitudinal.py b/q2_longitudinal/_longitudinal.py
index 5c756aa..cc4b52e 100644
--- a/q2_longitudinal/_longitudinal.py
+++ b/q2_longitudinal/_longitudinal.py
@@ -26,6 +26,7 @@
     _regplot_subplots_from_dataframe, _load_metadata,
     _validate_input_values, _validate_input_columns, _nmit,
     _validate_is_numeric_column, _maz_score,
+    _first_distances_and_distance_to_baseline,
     _first_differences, _importance_filtering, _summarize_feature_stats,
     _convert_nan_to_none, _parse_formula, _visualize_anova)
@@ -610,10 +611,10 @@ def first_distances(distance_matrix: skbio.DistanceMatrix,
     _validate_input_columns(
         metadata, individual_id_column, None, state_column, None)
 
-    return _first_differences(
-        metadata, state_column, individual_id_column, metric=None,
-        replicate_handling=replicate_handling, baseline=baseline,
-        distance_matrix=distance_matrix)
+    return _first_distances_and_distance_to_baseline(
+        metadata, state_column, individual_id_column,
+        distance_matrix=distance_matrix,
+        replicate_handling=replicate_handling, baseline=baseline)
 
 
 def feature_volatility(ctx,
diff --git a/q2_longitudinal/_utilities.py b/q2_longitudinal/_utilities.py
index 041755b..f22416c 100644
--- a/q2_longitudinal/_utilities.py
+++ b/q2_longitudinal/_utilities.py
@@ -673,6 +673,174 @@ def _nmit(table, sample_md, individual_id_column, corr_method="kendall",
     return _dist
 
 
+def _vectorized_first_distances(
+        distance_matrix: pd.DataFrame, metadata: pd.DataFrame,
+        state_column: str, individual_id_column: str) -> pd.Series:
+
+    # sort sample names in descending order so distances are labeled
+    # according to sample_i+1, not sample_i
+    metadata.sort_values(by=[individual_id_column, state_column],
+                         ascending=False, inplace=True)
+
+    # when samples are sorted by subject and state, the first distances are
+    # found in the first superdiagonal (offset 1)
+    distance_matrix = distance_matrix.loc[metadata.index, metadata.index]
+    distances = np.diag(distance_matrix.values, 1)
+
+    # drop the last sample identifier: with N samples there are only N-1
+    # consecutive distances
+    output = pd.DataFrame(index=metadata.index[:-1], columns=['Distance'])
+    output['Distance'] = distances
+
+    # The next few columns are helpful for filtering the results; see below
+    output['subject_1'] = metadata.loc[
+        output.index, individual_id_column].values
+    output['subject_2'] = metadata.loc[
+        metadata.index[1:], individual_id_column].values
+    output['state_1'] = metadata.loc[
+        output.index, state_column].values
+    output['state_2'] = metadata.loc[
+        metadata.index[1:], state_column].values
+
+    states = metadata[state_column].unique()
+    states.sort()
+    states = np.flip(states)
+    pairs = [(states[i], states[i+1]) for i in range(len(states) - 1)]
+
+    def is_keepable_pair(row, reference):
+        return (row['state_1'], row['state_2']) in reference
+
+    # Using the diagonal of the matrix will include some meaningless
+    # distances, for example the distance between the last sample of
+    # subject_j and the first sample of subject_j+1. We only keep the
+    # distances between samples with the same subject identifier.
+    #
+    # Also, only pairs of consecutive states are kept
+    output = output[
+        (output['subject_1'] == output['subject_2']) &
+        output.apply(is_keepable_pair, axis=1, reference=pairs)
+    ]['Distance']
+
+    output.index.name = '#SampleID'
+    output = output.iloc[::-1]
+    return output
+
+
+def _vectorized_distance_to_baseline(
+        distance_matrix: pd.DataFrame, metadata: pd.DataFrame,
+        state_column: str, individual_id_column: str) -> pd.Series:
+
+    # with the sorted table, we add an extra column with indices so we can
+    # make lookups by index, not label, which should be faster
+    metadata.sort_values(by=[individual_id_column, state_column],
+                         ascending=True, inplace=True)
+    index_name = _generate_column_name(metadata)
+    metadata[index_name] = np.arange(len(metadata))
+
+    # sort columns based on the metadata, and defer row sorting to the join
+    # operation
+    distance_matrix = distance_matrix[metadata.index]
+    metadata = metadata.join(distance_matrix, how='left')
+
+    # there are four "utility" columns that need to be ignored:
+    # individual, state, combo, and index
+    def column_getter(frame):
+        loc = slice(4 + frame[index_name][0], 4 + frame[index_name][-1] + 1)
+        # take the first row because the baseline state was replaced with
+        # -Inf and therefore sorts first; drop the baseline-to-baseline
+        # distance itself
+        return frame.iloc[0, loc].drop(frame.index[0])
+
+    output = metadata.groupby(individual_id_column).apply(column_getter)
+
+    # When the output of groupby is a single series, i.e. when there's only
+    # one individual, the return type is different:
+    # https://stackoverflow.com/q/37715246/379593
+    if not isinstance(output, pd.Series):
+        output = output.stack()
+    output = output.reset_index(name='Distance')
+
+    # first column represents the subject identifier
+    # second column is the sample identifier
+    # third column is the measured distance
+    return pd.Series(
+        index=pd.Index(output.iloc[:, 1].values, name='#SampleID'),
+        data=output.iloc[:, 2].values, name='Distance')
+
+
+def _first_distances_and_distance_to_baseline(metadata, state_column,
+                                              individual_id_column,
+                                              distance_matrix,
+                                              replicate_handling='error',
+                                              baseline=None):
+    distance_matrix = distance_matrix.to_data_frame()
+
+    metadata = metadata[[state_column, individual_id_column]]
+    metadata = metadata.loc[distance_matrix.index]
+
+    # let's force states to be numeric
+    _validate_is_numeric_column(metadata, state_column)
+
+    # combine individual and state to vectorize checks
+    combo_column = _generate_column_name(metadata)
+    metadata[combo_column] = (metadata[state_column].astype(float).astype(str)
+                              + metadata[individual_id_column].astype(str))
+    if replicate_handling == 'drop':
+        duplicated = metadata.duplicated(subset=combo_column, keep=False)
+    else:
+        # this way of finding duplicates is relevant for "random" too
+        duplicated = metadata.duplicated(subset=combo_column)
+
+        if replicate_handling == 'error' and duplicated.any():
+            def summarizer(group, id_column=None):
+                out = group[id_column].iloc[0]
+                out = '%s: %s' % (
+                    out, ', '.join(group[state_column].astype(str)))
+                return out
+
+            messages = metadata.groupby(
+                individual_id_column
+            ).apply(summarizer, id_column=individual_id_column).tolist()
+            messages = "\n".join(messages)
+
+            raise ValueError('There are repeated states for the '
+                             'following individuals.\n'
+                             f'{messages}')
+
+    # depending on the strategy selected above, filter the table
+    metadata = metadata[~duplicated].copy()
+
+    # if calculating distances to baseline, validate baseline as a valid
+    # state
+    if baseline is not None:
+        # convert baseline to the column's type
+        baseline = metadata[state_column].dtype.type(baseline)
+
+        # check that baseline is present for all individuals
+        have_baseline = metadata.groupby(
+            individual_id_column).apply(
+                lambda group: baseline in group[state_column].values)
+        if not have_baseline.all():
+            missing = ', '.join([str(i) for i in
+                                 have_baseline[~have_baseline].index])
+            raise ValueError('baseline must be a valid state. The following '
+                             'individuals are missing a baseline state value '
+                             f'of "{baseline}": {missing}')
+
+        # use -np.inf so the baseline always sorts first within each subject
+        metadata.replace({baseline: -np.inf}, inplace=True)
+
+        output = _vectorized_distance_to_baseline(
+            distance_matrix, metadata, state_column, individual_id_column)
+    else:
+        output = _vectorized_first_distances(
+            distance_matrix, metadata, state_column, individual_id_column)
+
+    output.dropna(inplace=True)
+
+    if output.empty:
+        raise RuntimeError(
+            'Output is empty. Either no paired samples were detected in the '
+            'inputs or replicate samples were dropped. Check input files, '
+            'parameters, and replicate_handling settings.')
+
+    return output
+
+
 def _first_differences(metadata, state_column, individual_id_column, metric,
                        replicate_handling='error', baseline=None,
                        distance_matrix=None):
diff --git a/q2_longitudinal/tests/test_longitudinal.py b/q2_longitudinal/tests/test_longitudinal.py
index e70b471..9d02b5b 100644
--- a/q2_longitudinal/tests/test_longitudinal.py
+++ b/q2_longitudinal/tests/test_longitudinal.py
@@ -882,7 +882,7 @@ def test_first_distances_ecam(self):
             distance_matrix=self.md_ecam_dm, metadata=self.md_ecam_fp,
             state_column='month', individual_id_column='studyid',
             replicate_handling='drop')
-        pdt.assert_series_equal(obs, exp)
+        pdt.assert_series_equal(obs[exp.index], exp)
 
     def test_validate_metadata_is_superset_df(self):
         with self.assertRaisesRegex(ValueError, "Missing samples in metadata"):
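For reviewers, below is a minimal, self-contained sketch of the two tricks the new helpers rely on. It is not part of the patch; the sample names, subjects, and coordinates are invented for illustration only.

```python
import numpy as np
import pandas as pd

# Toy data: two subjects ("A", "B"), each sampled at states 1..3.
samples = ['A1', 'A2', 'A3', 'B1', 'B2', 'B3']
rng = np.random.default_rng(0)
points = rng.random((len(samples), 2))
dm = pd.DataFrame(np.linalg.norm(points[:, None] - points[None, :], axis=-1),
                  index=samples, columns=samples)
md = pd.DataFrame({'subject': ['A', 'A', 'A', 'B', 'B', 'B'],
                   'state': [1, 2, 3, 1, 2, 3]}, index=samples)

# Trick 1 (_vectorized_first_distances): after sorting descending by
# (subject, state) and reordering the matrix to match, every pair of
# consecutive samples sits on the first superdiagonal, labeled by the
# later sample of each pair.
md = md.sort_values(by=['subject', 'state'], ascending=False)
dm = dm.loc[md.index, md.index]
first = pd.Series(np.diag(dm.values, 1), index=md.index[:-1])

# Entries pairing the last sample of one subject with the first sample of
# the next subject (here: B1 vs A3) are meaningless; the subject_1/subject_2
# and state-pair checks in the patch filter them out.
same_subject = md['subject'].values[:-1] == md['subject'].values[1:]
print(first[same_subject])

# Trick 2 (_vectorized_distance_to_baseline): mapping the baseline state
# (1 in this toy data) to -np.inf makes an ascending sort put the baseline
# sample first within each subject, so all distances to baseline live in
# the first row of that subject's block of the reordered matrix.
md2 = md.replace({1: -np.inf}).sort_values(by=['subject', 'state'])
dm2 = dm.loc[md2.index, md2.index]
for subject, group in md2.groupby('subject'):
    print(subject, dict(dm2.loc[group.index[0], group.index[1:]]))
```

As far as I can tell, replacing the per-pair lookups done by `_first_differences` with these bulk diagonal/row reads on the reordered matrix is where the speedup comes from; the test change (`obs[exp.index]`) only reorders the observed series before comparison, since the vectorized path can return rows in a different order.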