diff --git a/src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py b/src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py index 19228d21c..455ff2e94 100644 --- a/src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py +++ b/src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py @@ -126,11 +126,17 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]: ) except IndexError: parent_type = 0 - if primary_energy > 0: primary_fraction = EonEntrance / primary_energy else: primary_fraction = -1 + + if visible_length != -1: + primary = frame[self.mctree].get_primary(HEParticle.id) + primary_is_nu, primary_type = primary.is_neutrino, primary.type + else: + primary_type = 0 + primary_is_nu = False output.update( { "e_fraction_" + self._extractor_name: primary_fraction, @@ -153,6 +159,8 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]: "particle_type_" + self._extractor_name: HEParticle.type, "containment_" + self._extractor_name: containment, "parent_type_" + self._extractor_name: parent_type, + "primary_type_" + self._extractor_name: primary_type, + "primary_is_nu_" + self._extractor_name: primary_is_nu, } ) @@ -182,7 +190,9 @@ def get_tracks( primaries = [self.check_primary_energy(frame, p) for p in primaries] MMCTrackList = frame[self.mmctracklist] - if self.daughters: + if self.daughters & ( + not self._is_corsika + ): # expensive operation unecessary for CORSIKA temp_MMCTrackList = [] for track in MMCTrackList: for p in primaries: @@ -192,6 +202,26 @@ def get_tracks( temp_MMCTrackList.append(track) break MMCTrackList = simclasses.I3MMCTrackList(temp_MMCTrackList) + elif self._is_corsika & self.daughters: + MMCTrackList_filtered = [] + for track in MMCTrackList: + try: + if ( + frame[self.mctree].get_primary(track.GetI3Particle()) + in primaries + ): + MMCTrackList_filtered.append(track) + except RuntimeError as e: + if "particle not found" in str(e): + # get event header + self.warning( + f"Particle {track.GetI3Particle().id} not found in MCTree." + f" Skipping track in event {frame['I3EventHeader']}" + ) + else: + raise e # re-raise unexpected errors + + MMCTrackList = simclasses.I3MMCTrackList(MMCTrackList_filtered) MuonGun_tracks = np.array( MuonGun.Track.harvest(frame[self.mctree], MMCTrackList) @@ -347,14 +377,6 @@ def highest_energy_track( if tmp_EonEntrance > EonEntrance: particle = track_particle - closest_pos = np.array( - [ - track.GetXc(), - track.GetYc(), - track.GetZc(), - ] - ) - EonEntrance = tmp_EonEntrance visible_length = intersections.second - max( @@ -391,6 +413,13 @@ def highest_energy_track( ) particle.time = track.GetTi() else: + closest_pos = np.array( + [ + track.GetXc(), + track.GetYc(), + track.GetZc(), + ] + ) # If the track is stopping or throughgoing, # pos is point closest to detector center. distance = np.sqrt((closest_pos**2).sum()) @@ -718,7 +747,7 @@ def highest_energy_bundle( lengths = lengths[length_mask] containment = GN_containment_types.stopping_bundle.value - closest_pos = [] + highest_e = 0 for track, MGtrack in zip(MMCTrackList, MuonGun_tracks): intersections = self.hull.surface.intersection( MGtrack.pos, MGtrack.dir @@ -751,30 +780,24 @@ def highest_energy_bundle( raise # re-raise unexpected errors EonEntrance += track_energy - - closest_pos.append( - np.array( + if track_energy > highest_e: + highest_e = track_energy + closest_pos = np.array( [ track.GetXc(), track.GetYc(), track.GetZc(), ] ) - * track_energy - ) - if closest_time is None: - closest_time = track.GetTc() - elif closest_time < track.GetTc(): + closest_time = track.GetTc() - if intersections.second > 0: - visible_length = max( - visible_length, intersections.second - intersections.first - ) - if MGtrack.length > intersections.second: - containment = ( - GN_containment_types.throughgoing_bundle.value - ) + if intersections.second > 0: + visible_length = intersections.second - intersections.first + if MGtrack.length > intersections.second: + containment = ( + GN_containment_types.throughgoing_bundle.value + ) # If no intersection.second is every positive # the visible_length can still be negative here @@ -793,8 +816,6 @@ def highest_energy_bundle( visible_length >= 0 ), f"Visible length is negative for particle {frame['I3EventHeader']}" - closest_pos = np.sum(closest_pos, axis=0) / EonEntrance - bundle.pos = dataclasses.I3Position( closest_pos[0], closest_pos[1], closest_pos[2] )