diff --git a/pkg/component/list_based_registry.go b/pkg/component/list_based_registry.go index b5c843fe04..cf1b248613 100644 --- a/pkg/component/list_based_registry.go +++ b/pkg/component/list_based_registry.go @@ -58,6 +58,26 @@ func (r *listBasedRegistry[PA, P, V, MD]) MustRegister( } } +// GetAll implements PredicateBasedRegistry. +func (r *listBasedRegistry[PA, P, V, MD]) GetAll( + ctx context.Context, + arg PA, +) ([]PredicateBasedRegistration[PA, P, V, MD], error) { + r.mu.RLock() + defer r.mu.RUnlock() + var matches []PredicateBasedRegistration[PA, P, V, MD] + for _, reg := range r.registrations { + res, err := reg.Predicate(ctx, arg) + if err != nil { + return nil, err + } + if res { + matches = append(matches, reg) + } + } + return matches, nil +} + // Get implements PredicateBasedRegistry. func (r *listBasedRegistry[PA, P, V, MD]) Get( ctx context.Context, diff --git a/pkg/component/list_based_registry_test.go b/pkg/component/list_based_registry_test.go index 839016d27c..586b2cfa4e 100644 --- a/pkg/component/list_based_registry_test.go +++ b/pkg/component/list_based_registry_test.go @@ -182,6 +182,94 @@ func TestListBasedRegistry_Get(t *testing.T) { } } +func TestListBasedRegistry_GetAll(t *testing.T) { + testCases := []struct { + name string + registry *predicateRegistry + assertions func(*testing.T, []predicateRegistration, error) + }{ + { + name: "error evaluating predicate", + registry: &predicateRegistry{ + registrations: []predicateRegistration{{ + Predicate: func(context.Context, string) (bool, error) { + return false, errors.New("something went wrong") + }, + }}, + }, + assertions: func(t *testing.T, regs []predicateRegistration, err error) { + require.ErrorContains(t, err, "something went wrong") + require.Nil(t, regs) + }, + }, + { + name: "no matches", + registry: &predicateRegistry{}, + assertions: func(t *testing.T, regs []predicateRegistration, err error) { + require.NoError(t, err) + require.Empty(t, regs) + }, + }, + { + name: "one match", + registry: &predicateRegistry{ + registrations: []predicateRegistration{ + { + Predicate: func(context.Context, string) (bool, error) { + return true, nil + }, + Value: "match", + Metadata: "meta1", + }, + { + Predicate: func(context.Context, string) (bool, error) { + return false, nil + }, + Value: "no-match", + Metadata: "meta2", + }, + }, + }, + assertions: func(t *testing.T, regs []predicateRegistration, err error) { + require.NoError(t, err) + require.Len(t, regs, 1) + require.Equal(t, "match", regs[0].Value) + }, + }, + { + name: "multiple matches", + registry: &predicateRegistry{ + registrations: []predicateRegistration{ + { + Predicate: func(context.Context, string) (bool, error) { + return true, nil + }, + Value: "first", + }, + { + Predicate: func(context.Context, string) (bool, error) { + return true, nil + }, + Value: "second", + }, + }, + }, + assertions: func(t *testing.T, regs []predicateRegistration, err error) { + require.NoError(t, err) + require.Len(t, regs, 2) + require.Equal(t, "first", regs[0].Value) + require.Equal(t, "second", regs[1].Value) + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + regs, err := testCase.registry.GetAll(t.Context(), "test") + testCase.assertions(t, regs, err) + }) + } +} + func TestListBasedRegistry_WithFunctionValues(t *testing.T) { // Test with function types to ensure the registry works with factory // functions diff --git a/pkg/component/predicate_based_registry.go b/pkg/component/predicate_based_registry.go index 5ed7c9aeb4..6fa9bc035e 100644 --- a/pkg/component/predicate_based_registry.go +++ b/pkg/component/predicate_based_registry.go @@ -42,6 +42,13 @@ type PredicateBasedRegistry[PA any, P Predicate[PA], V, MD any] interface { PredicateBasedRegistration[PA, P, V, MD], error, ) + // GetAll searches for all matching registrations by evaluating each + // registration's predicate against the provided input. Returns all matching + // registrations, or an empty slice if none match. + GetAll(context.Context, PA) ( + []PredicateBasedRegistration[PA, P, V, MD], + error, + ) } // NewPredicateBasedRegistry returns a default implementation of the diff --git a/pkg/controller/management/projects/projects.go b/pkg/controller/management/projects/projects.go index 743949437a..52cf397023 100644 --- a/pkg/controller/management/projects/projects.go +++ b/pkg/controller/management/projects/projects.go @@ -1023,6 +1023,20 @@ func (r *reconciler) ensureDefaultUserRoles( }, }, } + for i, role := range roles { + contribs, err := defaultRoleRulesContributorRegistry.GetAll(ctx, role.Name) + if err != nil { + return fmt.Errorf( + "error getting role rules contributors for role %q: %w", + role.Name, err, + ) + } + for _, contrib := range contribs { + if extra := contrib.Value(role.Name); len(extra) > 0 { + roles[i].Rules = append(roles[i].Rules, extra...) + } + } + } for _, role := range roles { roleLogger := logger.WithValues( "name", role.Name, diff --git a/pkg/controller/management/projects/projects_test.go b/pkg/controller/management/projects/projects_test.go index 25fc9ca88d..dbc25f75cc 100644 --- a/pkg/controller/management/projects/projects_test.go +++ b/pkg/controller/management/projects/projects_test.go @@ -21,6 +21,7 @@ import ( rbacapi "github.com/akuity/kargo/api/rbac/v1alpha1" kargoapi "github.com/akuity/kargo/api/v1alpha1" + "github.com/akuity/kargo/pkg/component" "github.com/akuity/kargo/pkg/conditions" "github.com/akuity/kargo/pkg/kubernetes" ) @@ -2007,6 +2008,144 @@ func TestReconciler_ensureDefaultUserRoles(t *testing.T) { } } +func TestReconciler_ensureDefaultUserRoles_contributors(t *testing.T) { + // Save and restore the global registry around each sub-test. + origRegistry := defaultRoleRulesContributorRegistry + + testCases := []struct { + name string + setup func() + assertions func(*testing.T, []*rbacv1.Role, error) + }{ + { + name: "contributor predicate error propagates", + setup: func() { + defaultRoleRulesContributorRegistry = + component.MustNewPredicateBasedRegistry[ + string, + roleRulesContributorPredicate, + roleRulesContributorFunc, + struct{}, + ](RoleRulesContributorRegistration{ + Predicate: func(context.Context, string) (bool, error) { + return false, errors.New("something went wrong") + }, + Value: func(string) []rbacv1.PolicyRule { return nil }, + }) + }, + assertions: func(t *testing.T, _ []*rbacv1.Role, err error) { + require.ErrorContains(t, err, "error getting role rules contributors") + require.ErrorContains(t, err, "something went wrong") + }, + }, + { + name: "contributor rules are appended to matching roles", + setup: func() { + defaultRoleRulesContributorRegistry = + component.MustNewPredicateBasedRegistry[ + string, + roleRulesContributorPredicate, + roleRulesContributorFunc, + struct{}, + ](RoleRulesContributorRegistration{ + Predicate: func(_ context.Context, roleName string) (bool, error) { + return roleName == "kargo-admin", nil + }, + Value: func(string) []rbacv1.PolicyRule { + return []rbacv1.PolicyRule{{ + APIGroups: []string{"ee.kargo.akuity.io"}, + Resources: []string{"messagechannels"}, + Verbs: []string{"*"}, + }} + }, + }) + }, + assertions: func(t *testing.T, createdRoles []*rbacv1.Role, err error) { + require.NoError(t, err) + var adminRole *rbacv1.Role + var viewerRole *rbacv1.Role + for _, r := range createdRoles { + switch r.Name { + case "kargo-admin": + adminRole = r + case "kargo-viewer": + viewerRole = r + } + } + require.NotNil(t, adminRole) + require.NotNil(t, viewerRole) + + // Admin role should contain the EE rule as the last entry. + lastRule := adminRole.Rules[len(adminRole.Rules)-1] + require.Equal(t, []string{"ee.kargo.akuity.io"}, lastRule.APIGroups) + require.Equal(t, []string{"messagechannels"}, lastRule.Resources) + require.Equal(t, []string{"*"}, lastRule.Verbs) + + // Viewer role should not contain the EE rule. + for _, rule := range viewerRole.Rules { + for _, apiGroup := range rule.APIGroups { + require.NotEqual(t, "ee.kargo.akuity.io", apiGroup) + } + } + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Cleanup(func() { + defaultRoleRulesContributorRegistry = origRegistry + }) + testCase.setup() + + var createdRoles []*rbacv1.Role + r := &reconciler{ + createServiceAccountFn: func( + context.Context, + client.Object, + ...client.CreateOption, + ) error { + return apierrors.NewAlreadyExists(schema.GroupResource{}, "") + }, + createRoleFn: func( + _ context.Context, + obj client.Object, + _ ...client.CreateOption, + ) error { + role, ok := obj.(*rbacv1.Role) + require.True(t, ok) + createdRoles = append(createdRoles, role) + return nil + }, + createRoleBindingFn: func( + context.Context, + client.Object, + ...client.CreateOption, + ) error { + return apierrors.NewAlreadyExists(schema.GroupResource{}, "") + }, + createClusterRoleFn: func( + context.Context, + client.Object, + ...client.CreateOption, + ) error { + return apierrors.NewAlreadyExists(schema.GroupResource{}, "") + }, + createClusterRoleBindingFn: func( + context.Context, + client.Object, + ...client.CreateOption, + ) error { + return apierrors.NewAlreadyExists(schema.GroupResource{}, "") + }, + } + p := &kargoapi.Project{ + ObjectMeta: metav1.ObjectMeta{Name: "test-project"}, + } + testCase.assertions(t, createdRoles, r.ensureDefaultUserRoles(t.Context(), p)) + }) + } +} + func TestReconciler_ensureExtendedPermissions(t *testing.T) { testProject := &kargoapi.Project{ ObjectMeta: metav1.ObjectMeta{ diff --git a/pkg/controller/management/projects/role_rules_contributor.go b/pkg/controller/management/projects/role_rules_contributor.go new file mode 100644 index 0000000000..23121885f9 --- /dev/null +++ b/pkg/controller/management/projects/role_rules_contributor.go @@ -0,0 +1,42 @@ +package projects + +import ( + "context" + + rbacv1 "k8s.io/api/rbac/v1" + + "github.com/akuity/kargo/pkg/component" +) + +type ( + // roleRulesContributorPredicate returns true if the contributor has + // PolicyRules to contribute for the given role name. + roleRulesContributorPredicate = func(context.Context, string) (bool, error) + + // roleRulesContributorFunc returns additional PolicyRules for a given role + // name. + roleRulesContributorFunc = func(roleName string) []rbacv1.PolicyRule + + // RoleRulesContributorRegistration associates a predicate with a + // contributor function. + RoleRulesContributorRegistration = component.PredicateBasedRegistration[ + string, + roleRulesContributorPredicate, + roleRulesContributorFunc, + struct{}, + ] +) + +var defaultRoleRulesContributorRegistry = component.MustNewPredicateBasedRegistry[ + string, + roleRulesContributorPredicate, + roleRulesContributorFunc, + struct{}, +]() + +// RegisterRoleRulesContributor adds a contributor to the global registry used +// by the project reconciler when creating default project roles. It should be +// called before SetupReconcilerWithManager (e.g. at program startup). +func RegisterRoleRulesContributor(reg RoleRulesContributorRegistration) { + defaultRoleRulesContributorRegistry.MustRegister(reg) +}