Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions app/usr/internal/sto/accesshasher.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ func NewAccessHasher(kv storage.KV) *AccessHasher {
return &AccessHasher{kv: kv}
}

func (h *AccessHasher) SetChannelAccessHash(userID, channelID, accessHash int64) error {
data, err := h.kv.Get(context.TODO(), key.ChannelAccessHash(userID))
func (h *AccessHasher) SetChannelAccessHash(ctx context.Context, userID, channelID, accessHash int64) error {
data, err := h.kv.Get(ctx, key.ChannelAccessHash(userID))
if err != nil && !errors.Is(err, kv.ErrNotFound) {
return err
}
Expand All @@ -39,11 +39,11 @@ func (h *AccessHasher) SetChannelAccessHash(userID, channelID, accessHash int64)
return err
}

return h.kv.Set(context.TODO(), key.ChannelAccessHash(userID), string(b))
return h.kv.Set(ctx, key.ChannelAccessHash(userID), string(b))
}

func (h *AccessHasher) GetChannelAccessHash(userID, channelID int64) (int64, bool, error) {
data, err := h.kv.Get(context.TODO(), key.ChannelAccessHash(userID))
func (h *AccessHasher) GetChannelAccessHash(ctx context.Context, userID, channelID int64) (int64, bool, error) {
data, err := h.kv.Get(ctx, key.ChannelAccessHash(userID))
if err != nil {
if errors.Is(err, kv.ErrNotFound) {
return 0, false, nil
Expand Down
22 changes: 11 additions & 11 deletions app/usr/internal/sto/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (s *State) Set(key string, v interface{}) error {
return s.kv.Set(context.TODO(), key, string(data))
}

func (s *State) GetState(userID int64) (updates.State, bool, error) {
func (s *State) GetState(ctx context.Context, userID int64) (updates.State, bool, error) {
state := updates.State{}

if err := s.Get(key.State(userID), &state); err != nil {
Expand All @@ -49,15 +49,15 @@ func (s *State) GetState(userID int64) (updates.State, bool, error) {
return state, true, nil
}

func (s *State) SetState(userID int64, state updates.State) error {
func (s *State) SetState(ctx context.Context, userID int64, state updates.State) error {
if err := s.Set(key.State(userID), state); err != nil {
return err
}

return s.Set(key.StateChannel(userID), struct{}{})
}

func (s *State) SetPts(userID int64, pts int) error {
func (s *State) SetPts(ctx context.Context, userID int64, pts int) error {
state, k := updates.State{}, key.State(userID)

if err := s.Get(k, &state); err != nil {
Expand All @@ -67,7 +67,7 @@ func (s *State) SetPts(userID int64, pts int) error {
return s.Set(k, state)
}

func (s *State) SetQts(userID int64, qts int) error {
func (s *State) SetQts(ctx context.Context, userID int64, qts int) error {
state, k := updates.State{}, key.State(userID)

if err := s.Get(k, &state); err != nil {
Expand All @@ -77,7 +77,7 @@ func (s *State) SetQts(userID int64, qts int) error {
return s.Set(k, state)
}

func (s *State) SetDate(userID int64, date int) error {
func (s *State) SetDate(ctx context.Context, userID int64, date int) error {
state, k := updates.State{}, key.State(userID)

if err := s.Get(k, &state); err != nil {
Expand All @@ -87,7 +87,7 @@ func (s *State) SetDate(userID int64, date int) error {
return s.Set(k, state)
}

func (s *State) SetSeq(userID int64, seq int) error {
func (s *State) SetSeq(ctx context.Context, userID int64, seq int) error {
state, k := updates.State{}, key.State(userID)

if err := s.Get(k, &state); err != nil {
Expand All @@ -97,7 +97,7 @@ func (s *State) SetSeq(userID int64, seq int) error {
return s.Set(k, state)
}

func (s *State) SetDateSeq(userID int64, date, seq int) error {
func (s *State) SetDateSeq(ctx context.Context, userID int64, date, seq int) error {
state, k := updates.State{}, key.State(userID)

if err := s.Get(k, &state); err != nil {
Expand All @@ -108,7 +108,7 @@ func (s *State) SetDateSeq(userID int64, date, seq int) error {
return s.Set(k, state)
}

func (s *State) GetChannelPts(userID, channelID int64) (int, bool, error) {
func (s *State) GetChannelPts(ctx context.Context, userID, channelID int64) (int, bool, error) {
c := make(map[int64]int)

if err := s.Get(key.StateChannel(userID), &c); err != nil {
Expand All @@ -126,7 +126,7 @@ func (s *State) GetChannelPts(userID, channelID int64) (int, bool, error) {
return pts, true, nil
}

func (s *State) SetChannelPts(userID, channelID int64, pts int) error {
func (s *State) SetChannelPts(ctx context.Context, userID, channelID int64, pts int) error {
c, k := make(map[int64]int), key.StateChannel(userID)

if err := s.Get(k, &c); err != nil {
Expand All @@ -136,15 +136,15 @@ func (s *State) SetChannelPts(userID, channelID int64, pts int) error {
return s.Set(k, c)
}

func (s *State) ForEachChannels(userID int64, f func(channelID int64, pts int) error) error {
func (s *State) ForEachChannels(ctx context.Context, userID int64, f func(context.Context, int64, int) error) error {
c := make(map[int64]int)

if err := s.Get(key.StateChannel(userID), &c); err != nil {
return err
}

for channelID, pts := range c {
if err := f(channelID, pts); err != nil {
if err := f(ctx, channelID, pts); err != nil {
return err
}
}
Expand Down
18 changes: 6 additions & 12 deletions app/usr/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,14 @@ func Run(ctx context.Context, cfg string) error {
dispatcher := tg.NewUpdateDispatcher()
handleUsr(&dispatcher)

state := sto.NewState(kv)
hasher := sto.NewAccessHasher(kv)

gaps := updates.New(updates.Config{
Handler: dispatcher,
Logger: slog.Named("updates").Desugar(),
Storage: sto.NewState(kv),
AccessHasher: sto.NewAccessHasher(kv),
Storage: state,
AccessHasher: hasher,
OnChannelTooLong: func(channelID int64) {
slog.Errorw("channel is too long", "channelID", channelID)
},
Expand Down Expand Up @@ -133,16 +136,7 @@ func Run(ctx context.Context, cfg string) error {
go bot.Start()
defer bot.Stop()

// notify update manager about authentication.
// isBot set to `true` to avoid too long updates diff and can't fetch old messages.
// forgot set to `false` to avoid replace local pts by remote pts.
if err := gaps.Auth(ctx, c.API(), status.User.ID, true, false); err != nil {
return err
}
defer func() { _ = gaps.Logout() }()

color.Blue("Auth successfully! User: %s", status.User.Username)

// gaps will automatically handle authentication and state management
<-ctx.Done()
return ctx.Err()
})
Expand Down