From 79a1934713f279270fc2b400d391ccf7fc6b3ac4 Mon Sep 17 00:00:00 2001 From: 2598142880 <140939964+2598142880@users.noreply.github.com> Date: Mon, 10 Feb 2025 03:19:27 +0800 Subject: [PATCH 1/3] Update accesshasher.go --- app/usr/internal/sto/accesshasher.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/app/usr/internal/sto/accesshasher.go b/app/usr/internal/sto/accesshasher.go index 4c0b36c..287577e 100644 --- a/app/usr/internal/sto/accesshasher.go +++ b/app/usr/internal/sto/accesshasher.go @@ -17,8 +17,8 @@ func NewAccessHasher(kv storage.KV) *AccessHasher { return &AccessHasher{kv: kv} } -func (h *AccessHasher) SetChannelAccessHash(userID, channelID, accessHash int64) error { - data, err := h.kv.Get(context.TODO(), key.ChannelAccessHash(userID)) +func (h *AccessHasher) SetChannelAccessHash(ctx context.Context, userID, channelID, accessHash int64) error { + data, err := h.kv.Get(ctx, key.ChannelAccessHash(userID)) if err != nil && !errors.Is(err, kv.ErrNotFound) { return err } @@ -39,11 +39,11 @@ func (h *AccessHasher) SetChannelAccessHash(userID, channelID, accessHash int64) return err } - return h.kv.Set(context.TODO(), key.ChannelAccessHash(userID), string(b)) + return h.kv.Set(ctx, key.ChannelAccessHash(userID), string(b)) } -func (h *AccessHasher) GetChannelAccessHash(userID, channelID int64) (int64, bool, error) { - data, err := h.kv.Get(context.TODO(), key.ChannelAccessHash(userID)) +func (h *AccessHasher) GetChannelAccessHash(ctx context.Context, userID, channelID int64) (int64, bool, error) { + data, err := h.kv.Get(ctx, key.ChannelAccessHash(userID)) if err != nil { if errors.Is(err, kv.ErrNotFound) { return 0, false, nil From 7c4e8c8523c0d9b24c783181539890b17442ecf0 Mon Sep 17 00:00:00 2001 From: 2598142880 <140939964+2598142880@users.noreply.github.com> Date: Mon, 10 Feb 2025 03:19:38 +0800 Subject: [PATCH 2/3] Update state.go --- app/usr/internal/sto/state.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/app/usr/internal/sto/state.go b/app/usr/internal/sto/state.go index 3ddb7ba..07da2fc 100644 --- a/app/usr/internal/sto/state.go +++ b/app/usr/internal/sto/state.go @@ -36,7 +36,7 @@ func (s *State) Set(key string, v interface{}) error { return s.kv.Set(context.TODO(), key, string(data)) } -func (s *State) GetState(userID int64) (updates.State, bool, error) { +func (s *State) GetState(ctx context.Context, userID int64) (updates.State, bool, error) { state := updates.State{} if err := s.Get(key.State(userID), &state); err != nil { @@ -49,7 +49,7 @@ func (s *State) GetState(userID int64) (updates.State, bool, error) { return state, true, nil } -func (s *State) SetState(userID int64, state updates.State) error { +func (s *State) SetState(ctx context.Context, userID int64, state updates.State) error { if err := s.Set(key.State(userID), state); err != nil { return err } @@ -57,7 +57,7 @@ func (s *State) SetState(userID int64, state updates.State) error { return s.Set(key.StateChannel(userID), struct{}{}) } -func (s *State) SetPts(userID int64, pts int) error { +func (s *State) SetPts(ctx context.Context, userID int64, pts int) error { state, k := updates.State{}, key.State(userID) if err := s.Get(k, &state); err != nil { @@ -67,7 +67,7 @@ func (s *State) SetPts(userID int64, pts int) error { return s.Set(k, state) } -func (s *State) SetQts(userID int64, qts int) error { +func (s *State) SetQts(ctx context.Context, userID int64, qts int) error { state, k := updates.State{}, key.State(userID) if err := s.Get(k, &state); err != nil { @@ -77,7 +77,7 @@ func (s *State) SetQts(userID int64, qts int) error { return s.Set(k, state) } -func (s *State) SetDate(userID int64, date int) error { +func (s *State) SetDate(ctx context.Context, userID int64, date int) error { state, k := updates.State{}, key.State(userID) if err := s.Get(k, &state); err != nil { @@ -87,7 +87,7 @@ func (s *State) SetDate(userID int64, date int) error { return s.Set(k, state) } -func (s *State) SetSeq(userID int64, seq int) error { +func (s *State) SetSeq(ctx context.Context, userID int64, seq int) error { state, k := updates.State{}, key.State(userID) if err := s.Get(k, &state); err != nil { @@ -97,7 +97,7 @@ func (s *State) SetSeq(userID int64, seq int) error { return s.Set(k, state) } -func (s *State) SetDateSeq(userID int64, date, seq int) error { +func (s *State) SetDateSeq(ctx context.Context, userID int64, date, seq int) error { state, k := updates.State{}, key.State(userID) if err := s.Get(k, &state); err != nil { @@ -108,7 +108,7 @@ func (s *State) SetDateSeq(userID int64, date, seq int) error { return s.Set(k, state) } -func (s *State) GetChannelPts(userID, channelID int64) (int, bool, error) { +func (s *State) GetChannelPts(ctx context.Context, userID, channelID int64) (int, bool, error) { c := make(map[int64]int) if err := s.Get(key.StateChannel(userID), &c); err != nil { @@ -126,7 +126,7 @@ func (s *State) GetChannelPts(userID, channelID int64) (int, bool, error) { return pts, true, nil } -func (s *State) SetChannelPts(userID, channelID int64, pts int) error { +func (s *State) SetChannelPts(ctx context.Context, userID, channelID int64, pts int) error { c, k := make(map[int64]int), key.StateChannel(userID) if err := s.Get(k, &c); err != nil { @@ -136,7 +136,7 @@ func (s *State) SetChannelPts(userID, channelID int64, pts int) error { return s.Set(k, c) } -func (s *State) ForEachChannels(userID int64, f func(channelID int64, pts int) error) error { +func (s *State) ForEachChannels(ctx context.Context, userID int64, f func(context.Context, int64, int) error) error { c := make(map[int64]int) if err := s.Get(key.StateChannel(userID), &c); err != nil { @@ -144,7 +144,7 @@ func (s *State) ForEachChannels(userID int64, f func(channelID int64, pts int) e } for channelID, pts := range c { - if err := f(channelID, pts); err != nil { + if err := f(ctx, channelID, pts); err != nil { return err } } From 3c289344e93c115ade67956de85b9574b5c1deff Mon Sep 17 00:00:00 2001 From: 2598142880 <140939964+2598142880@users.noreply.github.com> Date: Mon, 10 Feb 2025 03:19:54 +0800 Subject: [PATCH 3/3] Update run.go --- app/usr/run/run.go | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/app/usr/run/run.go b/app/usr/run/run.go index 9edc983..3b88200 100644 --- a/app/usr/run/run.go +++ b/app/usr/run/run.go @@ -88,11 +88,14 @@ func Run(ctx context.Context, cfg string) error { dispatcher := tg.NewUpdateDispatcher() handleUsr(&dispatcher) + state := sto.NewState(kv) + hasher := sto.NewAccessHasher(kv) + gaps := updates.New(updates.Config{ Handler: dispatcher, Logger: slog.Named("updates").Desugar(), - Storage: sto.NewState(kv), - AccessHasher: sto.NewAccessHasher(kv), + Storage: state, + AccessHasher: hasher, OnChannelTooLong: func(channelID int64) { slog.Errorw("channel is too long", "channelID", channelID) }, @@ -133,16 +136,7 @@ func Run(ctx context.Context, cfg string) error { go bot.Start() defer bot.Stop() - // notify update manager about authentication. - // isBot set to `true` to avoid too long updates diff and can't fetch old messages. - // forgot set to `false` to avoid replace local pts by remote pts. - if err := gaps.Auth(ctx, c.API(), status.User.ID, true, false); err != nil { - return err - } - defer func() { _ = gaps.Logout() }() - - color.Blue("Auth successfully! User: %s", status.User.Username) - + // gaps will automatically handle authentication and state management <-ctx.Done() return ctx.Err() })