diff --git a/cmd/serve.go b/cmd/serve.go index e51ad7ae2..61dd1a2c8 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -432,10 +432,11 @@ func buildAPIDependencies( authnService, serviceUserService, groupService, roleService) membershipService := membership.NewService(logger, policyService, relationService, roleService, organizationService, userService, projectService, groupService, serviceUserService, auditRecordRepository) - // Setter injection: org → membership is circular (membership needs org for validation, - // org needs membership for Create/AdminCreate). Break the cycle with a post-init setter. + // Setter injection: org/group → membership is circular (membership needs them + // for validation; they need membership for Create). Break the cycle post-init. organizationService.SetMembershipService(membershipService) serviceUserService.SetMembershipService(membershipService) + groupService.SetMembershipService(membershipService) orgKycRepository := postgres.NewOrgKycRepository(dbc) orgKycService := kyc.NewService(orgKycRepository) diff --git a/core/group/mocks/membership_service.go b/core/group/mocks/membership_service.go new file mode 100644 index 000000000..b425239ae --- /dev/null +++ b/core/group/mocks/membership_service.go @@ -0,0 +1,86 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// MembershipService is an autogenerated mock type for the MembershipService type +type MembershipService struct { + mock.Mock +} + +type MembershipService_Expecter struct { + mock *mock.Mock +} + +func (_m *MembershipService) EXPECT() *MembershipService_Expecter { + return &MembershipService_Expecter{mock: &_m.Mock} +} + +// OnGroupCreated provides a mock function with given fields: ctx, groupID, orgID, creatorID, creatorType +func (_m *MembershipService) OnGroupCreated(ctx context.Context, groupID string, orgID string, creatorID string, creatorType string) error { + ret := _m.Called(ctx, groupID, orgID, creatorID, creatorType) + + if len(ret) == 0 { + panic("no return value specified for OnGroupCreated") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, groupID, orgID, creatorID, creatorType) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_OnGroupCreated_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OnGroupCreated' +type MembershipService_OnGroupCreated_Call struct { + *mock.Call +} + +// OnGroupCreated is a helper method to define mock.On call +// - ctx context.Context +// - groupID string +// - orgID string +// - creatorID string +// - creatorType string +func (_e *MembershipService_Expecter) OnGroupCreated(ctx interface{}, groupID interface{}, orgID interface{}, creatorID interface{}, creatorType interface{}) *MembershipService_OnGroupCreated_Call { + return &MembershipService_OnGroupCreated_Call{Call: _e.mock.On("OnGroupCreated", ctx, groupID, orgID, creatorID, creatorType)} +} + +func (_c *MembershipService_OnGroupCreated_Call) Run(run func(ctx context.Context, groupID string, orgID string, creatorID string, creatorType string)) *MembershipService_OnGroupCreated_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_OnGroupCreated_Call) Return(_a0 error) *MembershipService_OnGroupCreated_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_OnGroupCreated_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_OnGroupCreated_Call { + _c.Call.Return(run) + return _c +} + +// NewMembershipService creates a new instance of MembershipService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMembershipService(t interface { + mock.TestingT + Cleanup(func()) +}) *MembershipService { + mock := &MembershipService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/group/service.go b/core/group/service.go index cad606489..54cea3876 100644 --- a/core/group/service.go +++ b/core/group/service.go @@ -38,11 +38,16 @@ type PolicyService interface { GroupMemberCount(ctx context.Context, ids []string) ([]policy.MemberCount, error) } +type MembershipService interface { + OnGroupCreated(ctx context.Context, groupID, orgID, creatorID, creatorType string) error +} + type Service struct { - repository Repository - relationService RelationService - authnService AuthnService - policyService PolicyService + repository Repository + relationService RelationService + authnService AuthnService + policyService PolicyService + membershipService MembershipService } func NewService(repository Repository, relationService RelationService, @@ -55,6 +60,12 @@ func NewService(repository Repository, relationService RelationService, } } +// SetMembershipService sets the membership dependency after construction to break +// the circular init order between group and membership services. +func (s *Service) SetMembershipService(ms MembershipService) { + s.membershipService = ms +} + func (s Service) Create(ctx context.Context, grp Group) (Group, error) { principal, err := s.authnService.GetPrincipal(ctx) if err != nil { @@ -66,17 +77,7 @@ func (s Service) Create(ctx context.Context, grp Group) (Group, error) { return Group{}, err } - // attach group to org - if err = s.addAsOrgMember(ctx, newGroup); err != nil { - return Group{}, err - } - // add relationship between group to org - if err = s.addOrgToGroup(ctx, newGroup); err != nil { - return Group{}, err - } - - // attach current user to group as owner - if err = s.addOwner(ctx, newGroup.ID, principal); err != nil { + if err = s.membershipService.OnGroupCreated(ctx, newGroup.ID, newGroup.OrganizationID, principal.ID, principal.Type); err != nil { return Group{}, err } @@ -190,36 +191,6 @@ func (s Service) AddMember(ctx context.Context, groupID string, principal authen return nil } -// addOwner adds a user as an owner of group by creating a policy of owner role and an owner relation -func (s Service) addOwner(ctx context.Context, groupID string, principal authenticate.Principal) error { - pol := policy.Policy{ - RoleID: schema.GroupOwnerRole, - ResourceID: groupID, - ResourceType: schema.GroupNamespace, - PrincipalID: principal.ID, - PrincipalType: principal.Type, - } - if _, err := s.policyService.Create(ctx, pol); err != nil { - return err - } - // then create a relation between group and user - rel := relation.Relation{ - Object: relation.Object{ - ID: groupID, - Namespace: schema.GroupNamespace, - }, - Subject: relation.Subject{ - ID: principal.ID, - Namespace: principal.Type, - }, - RelationName: schema.OwnerRelationName, - } - if _, err := s.relationService.Create(ctx, rel); err != nil { - return err - } - return nil -} - // add a policy to user as member of group func (s Service) addMemberPolicy(ctx context.Context, groupID string, principal authenticate.Principal) error { pol := policy.Policy{ @@ -235,51 +206,6 @@ func (s Service) addMemberPolicy(ctx context.Context, groupID string, principal return nil } -// addOrgToGroup creates an inverse relation that connects group to org -func (s Service) addOrgToGroup(ctx context.Context, team Group) error { - rel := relation.Relation{ - Object: relation.Object{ - ID: team.ID, - Namespace: schema.GroupNamespace, - }, - Subject: relation.Subject{ - ID: team.OrganizationID, - Namespace: schema.OrganizationNamespace, - }, - RelationName: schema.OrganizationRelationName, - } - - _, err := s.relationService.Create(ctx, rel) - if err != nil { - return err - } - - return nil -} - -// addAsOrgMember connects group as a member to org -func (s Service) addAsOrgMember(ctx context.Context, team Group) error { - rel := relation.Relation{ - Object: relation.Object{ - ID: team.OrganizationID, - Namespace: schema.OrganizationNamespace, - }, - Subject: relation.Subject{ - ID: team.ID, - Namespace: schema.GroupNamespace, - SubRelationName: schema.MemberRelationName, - }, - RelationName: schema.MemberRelationName, - } - - _, err := s.relationService.Create(ctx, rel) - if err != nil { - return err - } - - return nil -} - // ListByOrganization will be useful for nested groups but we don't do that at the moment // so it will not be directly used func (s Service) ListByOrganization(ctx context.Context, id string) ([]Group, error) { diff --git a/core/group/service_test.go b/core/group/service_test.go index 1df317482..130477dd3 100644 --- a/core/group/service_test.go +++ b/core/group/service_test.go @@ -21,21 +21,21 @@ import ( ) func TestService_Create(t *testing.T) { - t.Run("should create group successfully by adding member to org, adding relation between group and org, and making current user owner", func(t *testing.T) { + t.Run("should create group and delegate hierarchy + owner wiring to membership", func(t *testing.T) { mockRepo := mocks.NewRepository(t) mockAuthnSvc := mocks.NewAuthnService(t) mockRelationSvc := mocks.NewRelationService(t) mockPolicySvc := mocks.NewPolicyService(t) + mockMembershipSvc := mocks.NewMembershipService(t) svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc) + svc.SetMembershipService(mockMembershipSvc) - mockUserID := uuid.New() + mockUserID := uuid.New().String() mockAuthnSvc.On("GetPrincipal", mock.Anything).Return(authenticate.Principal{ - ID: mockUserID.String(), - Type: "user", - User: &user.User{ - ID: mockUserID.String(), - }, + ID: mockUserID, + Type: schema.UserPrincipal, + User: &user.User{ID: mockUserID}, }, nil) groupParam := group.Group{ @@ -43,53 +43,13 @@ func TestService_Create(t *testing.T) { Title: "Test Group", OrganizationID: uuid.New().String(), } - groupInRepo := groupParam groupInRepo.ID = uuid.New().String() mockRepo.On("Create", mock.Anything, groupParam).Return(groupInRepo, nil) - // when adding group as org member - mockRelationSvc.On("Create", mock.Anything, mock.AnythingOfType("relation.Relation")).Run(func(args mock.Arguments) { - arg := args.Get(1) - r := arg.(relation.Relation) - assert.Equal(t, r.Object.ID, groupInRepo.OrganizationID) - assert.Equal(t, r.Subject.ID, groupInRepo.ID) - assert.Equal(t, r.RelationName, schema.MemberRelationName) - }).Return(relation.Relation{}, nil).Once() - - // when adding group to org - mockRelationSvc.On("Create", mock.Anything, mock.AnythingOfType("relation.Relation")).Run(func(args mock.Arguments) { - arg := args.Get(1) - r := arg.(relation.Relation) - assert.Equal(t, r.Object.ID, groupInRepo.ID) - assert.Equal(t, r.Subject.ID, groupInRepo.OrganizationID) - assert.Equal(t, r.RelationName, schema.OrganizationRelationName) - }).Return(relation.Relation{}, nil).Once() - - // when adding current user as group owner - mockPolicySvc.On("Create", mock.Anything, mock.AnythingOfType("policy.Policy")).Run(func(args mock.Arguments) { - arg := args.Get(1) - r := arg.(policy.Policy) - assert.Equal(t, r.RoleID, schema.GroupOwnerRole) - assert.Equal(t, r.ResourceID, groupInRepo.ID) - assert.Equal(t, r.ResourceType, schema.GroupNamespace) - assert.Equal(t, r.PrincipalID, mockUserID.String()) - assert.Equal(t, r.PrincipalType, "user") - }).Return(policy.Policy{}, nil).Once() - - // adding relation between group and user - mockRelationSvc.On("Create", mock.Anything, mock.AnythingOfType("relation.Relation")).Run(func(args mock.Arguments) { - arg := args.Get(1) - r := arg.(relation.Relation) - assert.Equal(t, r.Object.ID, groupInRepo.ID) - assert.Equal(t, r.Object.Namespace, schema.GroupNamespace) - assert.Equal(t, r.Subject.ID, mockUserID.String()) - assert.Equal(t, r.Subject.Namespace, "user") - assert.Equal(t, r.RelationName, schema.OwnerRelationName) - }).Return(relation.Relation{}, nil).Once() + mockMembershipSvc.EXPECT().OnGroupCreated(mock.Anything, groupInRepo.ID, groupInRepo.OrganizationID, mockUserID, schema.UserPrincipal).Return(nil) grp, err := svc.Create(context.Background(), groupParam) - assert.Nil(t, err) assert.Equal(t, grp.Name, groupParam.Name) }) @@ -108,6 +68,33 @@ func TestService_Create(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, strings.Contains(err.Error(), authenticate.ErrInvalidID.Error()), true) }) + + t.Run("should propagate error from membership.OnGroupCreated", func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockAuthnSvc := mocks.NewAuthnService(t) + mockRelationSvc := mocks.NewRelationService(t) + mockPolicySvc := mocks.NewPolicyService(t) + mockMembershipSvc := mocks.NewMembershipService(t) + + svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc) + svc.SetMembershipService(mockMembershipSvc) + + mockUserID := uuid.New().String() + mockAuthnSvc.On("GetPrincipal", mock.Anything).Return(authenticate.Principal{ + ID: mockUserID, + Type: schema.UserPrincipal, + User: &user.User{ID: mockUserID}, + }, nil) + + groupParam := group.Group{Name: "g", OrganizationID: uuid.New().String()} + groupInRepo := groupParam + groupInRepo.ID = uuid.New().String() + mockRepo.On("Create", mock.Anything, groupParam).Return(groupInRepo, nil) + mockMembershipSvc.EXPECT().OnGroupCreated(mock.Anything, groupInRepo.ID, groupInRepo.OrganizationID, mockUserID, schema.UserPrincipal).Return(errors.New("spicedb down")) + + _, err := svc.Create(context.Background(), groupParam) + assert.ErrorContains(t, err, "spicedb down") + }) } func TestService_Get(t *testing.T) {