Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ func buildAPIDependencies(

roleService := role.NewService(roleRepository, relationService, permissionService, auditRecordRepository, cfg.App.PAT.DeniedPermissionsSet())
policyService := policy.NewService(policyPGRepository, relationService, roleService)
userService := user.NewService(userRepository, relationService, policyService, roleService)
userService := user.NewService(userRepository, relationService, policyService, roleService, sessionService)
patValidator := userpat.NewValidator(logger, userPATRepo, cfg.App.PAT)
authnService := authenticate.NewService(logger, cfg.App.Authentication,
postgres.NewFlowRepository(logger, dbc), mailDialer, tokenService, sessionService, userService, serviceUserService, webAuthConfig, patValidator)
Expand Down
18 changes: 16 additions & 2 deletions core/authenticate/session/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ package session
import (
"context"
"errors"
"log/slog"
"time"

"github.com/raystack/frontier/pkg/server/consts"

"log/slog"

"github.com/google/uuid"
"github.com/robfig/cron/v3"
"google.golang.org/grpc/metadata"
Expand Down Expand Up @@ -81,6 +80,21 @@ func (s Service) Delete(ctx context.Context, sessionID uuid.UUID) error {
return s.repo.Delete(ctx, sessionID)
}

// DeleteByUserID soft-deletes all active sessions belonging to a user.
// Iterates over the user's active sessions and revokes each via Delete.
func (s Service) DeleteByUserID(ctx context.Context, userID string) error {
sessions, err := s.repo.List(ctx, userID)
if err != nil {
return err
}
for _, sess := range sessions {
if err := s.Delete(ctx, sess.ID); err != nil {
return err
}
}
return nil
}

func (s Service) Get(ctx context.Context, sessionID uuid.UUID) (*Session, error) {
return s.repo.Get(ctx, sessionID)
}
Expand Down
56 changes: 56 additions & 0 deletions core/authenticate/session/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,62 @@ func TestService_Delete(t *testing.T) {
})
}

func TestService_DeleteByUserID(t *testing.T) {
userID := uuid.New().String()

t.Run("revokes each active session for the user", func(t *testing.T) {
mockRepository := mocks.NewRepository(t)
svc := session.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockRepository, 24*time.Hour)
sess1 := &session.Session{ID: uuid.New(), UserID: userID}
sess2 := &session.Session{ID: uuid.New(), UserID: userID}

mockRepository.On("List", mock.Anything, userID).Return([]*session.Session{sess1, sess2}, nil)
mockRepository.On("Delete", mock.Anything, sess1.ID).Return(nil)
mockRepository.On("Delete", mock.Anything, sess2.ID).Return(nil)

err := svc.DeleteByUserID(context.Background(), userID)

assert.Nil(t, err)
})

t.Run("returns nil when user has no active sessions", func(t *testing.T) {
mockRepository := mocks.NewRepository(t)
svc := session.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockRepository, 24*time.Hour)

mockRepository.On("List", mock.Anything, userID).Return([]*session.Session{}, nil)

err := svc.DeleteByUserID(context.Background(), userID)

assert.Nil(t, err)
})

t.Run("propagates list errors", func(t *testing.T) {
mockRepository := mocks.NewRepository(t)
svc := session.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockRepository, 24*time.Hour)

mockRepository.On("List", mock.Anything, userID).Return(nil, errors.New("db down"))

err := svc.DeleteByUserID(context.Background(), userID)

assert.ErrorContains(t, err, "db down")
})

t.Run("stops and returns error when an individual delete fails", func(t *testing.T) {
mockRepository := mocks.NewRepository(t)
svc := session.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockRepository, 24*time.Hour)
sess1 := &session.Session{ID: uuid.New(), UserID: userID}
sess2 := &session.Session{ID: uuid.New(), UserID: userID}

mockRepository.On("List", mock.Anything, userID).Return([]*session.Session{sess1, sess2}, nil)
mockRepository.On("Delete", mock.Anything, sess1.ID).Return(errors.New("revoke failed"))

err := svc.DeleteByUserID(context.Background(), userID)

assert.ErrorContains(t, err, "revoke failed")
mockRepository.AssertNotCalled(t, "Delete", mock.Anything, sess2.ID)
})
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

func TestService_ExtractFromContext(t *testing.T) {
t.Run("should be able to extract session from context if it is present", func(t *testing.T) {
mockRepository := mocks.NewRepository(t)
Expand Down
83 changes: 83 additions & 0 deletions core/user/mocks/session_service.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 12 additions & 2 deletions core/user/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,28 @@ type RoleService interface {
List(ctx context.Context, f role.Filter) ([]role.Role, error)
}

type SessionService interface {
DeleteByUserID(ctx context.Context, userID string) error
}

type Service struct {
repository Repository
relationService RelationService
policyService PolicyService
roleService RoleService
sessionService SessionService
Now func() time.Time
}

func NewService(repository Repository, relationRepo RelationService,
policyService PolicyService, roleService RoleService) *Service {
policyService PolicyService, roleService RoleService,
sessionService SessionService) *Service {
return &Service{
repository: repository,
relationService: relationRepo,
policyService: policyService,
roleService: roleService,
sessionService: sessionService,
Now: func() time.Time {
return time.Now().UTC()
},
Expand Down Expand Up @@ -139,7 +146,10 @@ func (s Service) Disable(ctx context.Context, id string) error {
if !utils.IsValidUUID(id) {
return ErrInvalidID
}
return s.repository.SetState(ctx, id, Disabled)
if err := s.repository.SetState(ctx, id, Disabled); err != nil {
return err
}
return s.sessionService.DeleteByUserID(ctx, id)
Comment thread
rohilsurana marked this conversation as resolved.
}

// Delete by user uuid
Expand Down
Loading
Loading