Skip to content
131 changes: 92 additions & 39 deletions lightguide/blast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from copy import deepcopy
from datetime import datetime, timedelta, timezone
from functools import wraps
import math
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -22,10 +21,11 @@
import numpy as np
from matplotlib import colors, dates
from matplotlib.colors import Colormap
from pyrocko import io, marker
from pyrocko import io, pile, obspy_compat, trace
from pyrocko.trace import Trace
from scipy import signal
import re
import math

from lightguide.utils import PathStr
from lightguide.models.picks import *
Expand Down Expand Up @@ -62,6 +62,7 @@ class Blast:

start_channel: int
channel_spacing: float
channel_list: np.ndarray

def __init__(
self,
Expand All @@ -70,6 +71,7 @@ def __init__(
sampling_rate: float,
start_channel: int = 0,
channel_spacing: float = 0.0,
channel_list: list = [],
unit: MeasurementUnit = "strain rate",
) -> None:
"""Create a new blast from NumPy array.
Expand All @@ -96,6 +98,10 @@ def __init__(

self.start_channel = start_channel
self.channel_spacing = channel_spacing
self.channel_list = channel_list

if len(self.channel_list) == 0:
self.channel_list = np.arange(start_channel, len(data), 1)

self.processing_flow = []

Expand All @@ -117,7 +123,7 @@ def n_channels(self) -> int:
@property
def end_channel(self) -> int:
"""End Channel."""
return self.start_channel + self.n_channels
return self.channel_list[-1]

@property
def n_samples(self) -> int:
Expand All @@ -129,18 +135,67 @@ def duration(self) -> float:
"""Duration in seconds."""
return self.n_samples * self.delta_t

def reduce_channels(self, n: int) -> None:
"""Returns sparsed blast containing only every n-th channel"""
self.data = self.data[:-1:n, :]
print(self.data.shape[:])
self.channel_spacing = self.channel_spacing * n
self.channel_list = self.channel_list[:-1:n]

def exlude_channel(self, channel) -> None:
"""Deletes selected channel, in-place.
Args:
channel (int): number of channel to be removed.
"""
idx = self.get_channel_index(channel, strict=True)
if idx == None:
print(f"#{channel} not in list.")
return self
Comment thread
juleluj marked this conversation as resolved.
Outdated
self.data = np.delete(self.data, idx, 0)
self.channel_list = np.delete(self.channel_list, idx)

def exlude_channels(self, channels) -> None:
"""Deletes channels given in list from blast, in-place."""
for channel in channels:
self.exlude_channel(channel=channel)

def get_channel_name(self, channel_index: int) -> int:
"""Gets name of channel from it's index as given in channel_list.
Args:
channel_index (int): index of channel of interest.
Returns:
int: Channel name.
"""
return self.channel_list[channel_index]

def get_channel_index(self, channel: int, strict=False) -> int:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Be always strict. Or is there advantage in raising an exception only sometimes?

"""Finds index of a given channel or channel closest to it and returns it.
Args:
channel (int): Channel name.
strict (bool): if False, return channel closest to channel
Returns:
int: channel index
"""
channels = self.channel_list
idx = (np.abs(channels - channel)).argmin()
# idx = np.searchsorted(channels, channel, side="left") # maybe faster??
if channels[idx] == channel:
return idx
elif strict == False:
print(f"#{channel} not in channel list. #{channels[idx]} is used instead.")
return idx
return None

def get_trace(self, channel: int) -> np.ndarray:
"""Get data from a singular channel.

Args:
channel (int): Channel number.
channel (int): Channel name.

Returns:
np.ndarray: 1D Trace.
"""
if not self.start_channel <= channel < self.end_channel:
raise ValueError(f"Channel {channel} is out of bounds")
return self.data[channel - self.start_channel]
return self.data[self.get_channel_index(channel)]

def _time_to_sample(self, time: datetime) -> int:
"""Get sample index for a time.
Expand Down Expand Up @@ -328,39 +383,14 @@ def afk_filter(
normalize_power=normalize_power,
)

def average_traces(self, no_of_traces, reduce_channels=False):
"""Average over number of neighbouring traces.
def average_traces(self, no_of_traces) -> Blast:
"""Average over number of neighbouring traces, in place.
Args:
no_of_traces (int): number of channels to be used for averaging
reduce_channels (bool): if True: returns list of averaged traces with no overlap (i.e. if no_of_traces=10, reurns traces 5,15,25...)
"""
blast = self.copy()
avs = []
d = deque(maxlen=no_of_traces)
for tr in blast.as_traces():
d.append(tr.ydata)
av = np.sum(d, axis=0) / len(d)
avs.append(av)
avs = np.array(avs)

if reduce_channels:
avs = avs[no_of_traces:-1:no_of_traces, :] # select
traces = []
channel = math.ceil(no_of_traces / 2)
for av in avs:
traces.append(
Trace(
ydata=av,
tmin=self.start_time.timestamp(),
deltat=self.delta_t,
station=f"{channel:05d}",
)
)
channel += no_of_traces
return traces

blast.data = avs
return blast
kernel = np.ones(shape=(no_of_traces, 1)) / no_of_traces
avs = signal.fftconvolve(self.data, kernel, mode="valid")
self.data = avs

def follow_phase(
self,
Expand Down Expand Up @@ -514,7 +544,10 @@ def trim_channels(self, begin: int = 0, end: int = -1) -> Blast:
Blast: Trimmed Blast.
"""
blast = self.copy()
begin = blast.get_channel_index(begin, strict=False)
end = blast.get_channel_index(end, strict=False)
blast.start_channel += begin
blast.channel_list = blast.channel_list[begin:end]
blast.data = blast.data[begin:end]
return blast

Expand Down Expand Up @@ -736,6 +769,20 @@ def as_traces(self) -> list[Trace]:
)
return traces

def to_obspy_stream(self):
"""Converts blast to an obspy stream

Returns:
Obspy stream containing traces of blast.
"""
p = pile.Pile()
p.add(self.as_traces())
return obspy_compat.to_obspy_stream(p)

def snuffle(self, **kwargs) -> None:
"""Show traces of blast in a snuffler window."""
trace.snuffle(self.as_traces(), **kwargs)

@classmethod
def from_pyrocko(cls, traces: list[Trace], channel_spacing: float = 4.0) -> Blast:
"""Create Blast from a list of Pyrocko traces.
Expand All @@ -755,6 +802,7 @@ def from_pyrocko(cls, traces: list[Trace], channel_spacing: float = 4.0) -> Blas
raise ValueError("Empty list of traces")

traces = sorted(traces, key=lambda tr: int(re.sub(r"\D", "", tr.station)))
channel_list = np.array([int(re.sub(r"\D", "", tr.station)) for tr in traces])
ntraces = len(traces)

tmin = set()
Expand Down Expand Up @@ -785,8 +833,10 @@ def from_pyrocko(cls, traces: list[Trace], channel_spacing: float = 4.0) -> Blas
data=data,
start_time=datetime.fromtimestamp(tmin.pop(), tz=timezone.utc),
sampling_rate=int(1.0 / delta_t.pop()),
start_channel=min(int(re.sub(r"\D", "", tr.station)) for tr in traces),
# start_channel=min(int(re.sub(r"\D", "", tr.station)) for tr in traces),
start_channel=channel_list[0],
channel_spacing=channel_spacing,
channel_list=channel_list,
)

@classmethod
Expand All @@ -804,7 +854,10 @@ def from_miniseed(cls, file: PathStr, channel_spacing: float = 4.0) -> Blast:
from pyrocko import io

traces = io.load(str(file), format="mseed")
return cls.from_pyrocko(traces, channel_spacing=channel_spacing)
return cls.from_pyrocko(
traces,
channel_spacing=channel_spacing,
)


TFun = TypeVar("TFun", bound=Callable[..., Any])
Expand Down