diff --git a/src/pynwb/io/core.py b/src/pynwb/io/core.py index 90f06659e..59ba72fb5 100644 --- a/src/pynwb/io/core.py +++ b/src/pynwb/io/core.py @@ -42,9 +42,9 @@ def get_attr_value(self, **kwargs): ''' Get the value of the attribute corresponding to this spec from the given container ''' spec, container, manager = getargs('spec', 'container', 'manager', kwargs) - # handle custom mapping of container Units.waveform_rate -> spec Units.waveform_mean.sampling_rate + # handle custom mapping of Units waveform metadata onto waveform-bearing columns if isinstance(container.parent, Units): - if container.name == 'waveform_mean' or container.name == 'waveform_sd': + if container.name in ('waveform_mean', 'waveform_sd', 'waveforms'): if spec.name == 'sampling_rate': return container.parent.waveform_rate if spec.name == 'unit': diff --git a/src/pynwb/io/misc.py b/src/pynwb/io/misc.py index 20ec69261..f5d39b5ec 100644 --- a/src/pynwb/io/misc.py +++ b/src/pynwb/io/misc.py @@ -22,19 +22,17 @@ def waveform_unit_carg(self, builder, manager): return self._get_waveform_stat(builder, 'unit') def _get_waveform_stat(self, builder, attribute): - if 'waveform_mean' not in builder and 'waveform_sd' not in builder: + waveform_columns = ('waveform_mean', 'waveform_sd', 'waveforms') + stats = [builder[column].attributes.get(attribute) for column in waveform_columns if column in builder] + if not stats: return None - mean_stat = None - sd_stat = None - if 'waveform_mean' in builder: - mean_stat = builder['waveform_mean'].attributes.get(attribute) - if 'waveform_sd' in builder: - sd_stat = builder['waveform_sd'].attributes.get(attribute) - if mean_stat is not None and sd_stat is not None: - if mean_stat != sd_stat: - # throw warning - pass - return mean_stat + populated_stats = [stat for stat in stats if stat is not None] + if len(set(populated_stats)) > 1: + # throw warning + pass + if populated_stats: + return populated_stats[0] + return None @DynamicTableMap.object_attr("electrodes") def electrodes_column(self, container, manager): diff --git a/tests/integration/hdf5/test_misc.py b/tests/integration/hdf5/test_misc.py index ab9744c16..b07066a62 100644 --- a/tests/integration/hdf5/test_misc.py +++ b/tests/integration/hdf5/test_misc.py @@ -1,3 +1,4 @@ +import h5py import numpy as np from hdmf.common import VectorData, DynamicTableRegion @@ -70,6 +71,58 @@ def test_get_obs_intervals(self): np.testing.assert_array_equal(ut['obs_intervals'][:], [[[0., 1.], [2., 3.]], [[2., 5.], [6., 7.]]]) +class TestUnitsWaveformsOnlyIO(AcquisitionH5IOMixin, TestCase): + """Test roundtripping waveform metadata when only waveforms are present.""" + + def setUpContainer(self): + ut = Units(name='UnitsWaveformsOnlyTest', description='a simple table for testing Units waveforms') + ut.add_unit( + spike_times=[0., 1., 2.], + waveforms=[ + [ + [1, 2, 3], + [1, 2, 3], + [1, 2, 3] + ], [ + [1, 2, 3], + [1, 2, 3], + [1, 2, 3] + ] + ] + ) + ut.add_unit( + spike_times=[3., 4., 5.], + waveforms=np.array([ + [ + [1, 2, 3], + [1, 2, 3], + [1, 2, 3] + ], [ + [1, 2, 3], + [1, 2, 3], + [1, 2, 3] + ] + ]) + ) + ut.waveform_rate = 40000. + return ut + + def test_waveform_metadata_roundtrip(self): + ut = self.roundtripContainer() + self.assertEqual(ut.waveform_rate, 40000.) + self.assertEqual(ut.waveform_unit, 'volts') + + def test_waveforms_attributes_written(self): + self.roundtripContainer() + with h5py.File(self.filename, 'r') as infile: + waveforms = infile['acquisition'][self.container.name]['waveforms'] + self.assertEqual(waveforms.attrs['sampling_rate'], 40000.) + unit = waveforms.attrs['unit'] + if isinstance(unit, bytes): + unit = unit.decode('utf-8') + self.assertEqual(unit, 'volts') + + class TestUnitsFileIO(NWBH5IOMixin, TestCase): def setUpContainer(self):