From eef6d857f36298ec58e6a01999500aa4032f46dd Mon Sep 17 00:00:00 2001 From: Nitish Bhat Date: Wed, 22 Apr 2026 14:59:37 -0700 Subject: [PATCH] Create DeviceClass from operator code on OpenShift when DRA is enabled (#1388) * Create DeviceClass from operator code on OpenShift when DRA is enabled On OpenShift, operator-sdk cannot deploy DeviceClass resources via the OLM bundle. This adds handleDeviceClass to the reconciler which creates the gpu.amd.com DeviceClass using an unstructured client when running on OpenShift with DRA driver enabled. The DeviceClass is cluster-scoped and shared, so it is created once (AlreadyExists is handled gracefully) and never deleted on DeviceConfig finalization. * Use deviceClassName constant instead of hardcoded string Address review feedback: extract "gpu.amd.com" into a const and use it throughout handleDeviceClass. (cherry picked from commit d8845e69298bab8e21f2c0955a4b6579bc79ea76) --- ...md-gpu-operator.clusterserviceversion.yaml | 8 +- config/rbac/role.yaml | 6 ++ .../controllers/device_config_reconciler.go | 59 +++++++++++- .../device_config_reconciler_test.go | 94 +++++++++++++++++-- .../mock_device_config_reconciler.go | 14 +++ 5 files changed, 171 insertions(+), 10 deletions(-) diff --git a/bundle/manifests/amd-gpu-operator.clusterserviceversion.yaml b/bundle/manifests/amd-gpu-operator.clusterserviceversion.yaml index e6cf37ecf..2b0cc77fd 100644 --- a/bundle/manifests/amd-gpu-operator.clusterserviceversion.yaml +++ b/bundle/manifests/amd-gpu-operator.clusterserviceversion.yaml @@ -36,7 +36,7 @@ metadata: capabilities: Seamless Upgrades categories: AI/Machine Learning,Monitoring containerImage: registry.test.pensando.io:5000/amd-gpu-operator:dev - createdAt: "2026-04-06T08:31:30Z" + createdAt: "2026-04-22T01:09:34Z" description: |- Operator responsible for deploying AMD GPU kernel drivers, device plugin, device test runner and device metrics exporter For more information, visit [documentation](https://instinct.docs.amd.com/projects/gpu-operator/en/latest/) @@ -1310,6 +1310,12 @@ spec: verbs: - get - update + - apiGroups: + - resource.k8s.io + resources: + - deviceclasses + verbs: + - create - apiGroups: - argoproj.io resources: diff --git a/config/rbac/role.yaml b/config/rbac/role.yaml index 2d0b992aa..73af8d115 100644 --- a/config/rbac/role.yaml +++ b/config/rbac/role.yaml @@ -197,3 +197,9 @@ rules: verbs: - get - update +- apiGroups: + - resource.k8s.io + resources: + - deviceclasses + verbs: + - create diff --git a/internal/controllers/device_config_reconciler.go b/internal/controllers/device_config_reconciler.go index 8c5722099..3c4c33b25 100644 --- a/internal/controllers/device_config_reconciler.go +++ b/internal/controllers/device_config_reconciler.go @@ -48,7 +48,9 @@ import ( k8serrors "k8s.io/apimachinery/pkg/api/errors" meta "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/rest" "k8s.io/client-go/util/retry" @@ -79,8 +81,15 @@ const ( DeviceConfigReconcilerName = "DriverAndPluginReconciler" deviceConfigFinalizer = "amd.node.kubernetes.io/deviceconfig-finalizer" testRunnerNodeLabelPrefix = "testrunner.amd.com" + deviceClassName = "gpu.amd.com" ) +var draDeviceClassGVK = schema.GroupVersionKind{ + Group: "resource.k8s.io", + Version: "v1", + Kind: "DeviceClass", +} + // ModuleReconciler reconciles a Module object type DeviceConfigReconciler struct { client.Client @@ -108,7 +117,7 @@ func NewDeviceConfigReconciler( kmmWatchEnabled bool) *DeviceConfigReconciler { upgradeMgrHandler := newUpgradeMgrHandler(client, k8sConfig, isOpenShift) remediationMgrHandler := newRemediationMgrHandler(client, apiReader, k8sConfig, isOpenShift) - helper := newDeviceConfigReconcilerHelper(client, kmmHandler, dpHandler, nlHandler, upgradeMgrHandler, remediationMgrHandler, metricsHandler, testrunnerHandler, configmanagerHandler, workerMgr, kmmWatchEnabled) + helper := newDeviceConfigReconcilerHelper(client, kmmHandler, dpHandler, nlHandler, upgradeMgrHandler, remediationMgrHandler, metricsHandler, testrunnerHandler, configmanagerHandler, workerMgr, isOpenShift, kmmWatchEnabled) podEventHandler := watchers.NewPodEventHandler(client, workerMgr) nodeEventHandler := watchers.NewNodeEventHandler(client, workerMgr) daemonsetEventHandler := watchers.NewDaemonsetEventHandler(client) @@ -203,6 +212,7 @@ func (r *DeviceConfigReconciler) init(ctx context.Context) { //+kubebuilder:rbac:groups=core,resources=pods/eviction,verbs=delete;get;list;create //+kubebuilder:rbac:groups=apiextensions.k8s.io,resources=customresourcedefinitions,verbs=get;list;watch;delete //+kubebuilder:rbac:groups=monitoring.coreos.com,resources=servicemonitors,verbs=get;list;watch;create;update;patch;delete +//+kubebuilder:rbac:groups=resource.k8s.io,resources=deviceclasses,verbs=create func (r *DeviceConfigReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { res := ctrl.Result{} @@ -299,6 +309,11 @@ func (r *DeviceConfigReconciler) Reconcile(ctx context.Context, req ctrl.Request return res, fmt.Errorf("failed to handle device-plugin for DeviceConfig %s: %v", req.NamespacedName, err) } + logger.Info("start DeviceClass reconciliation") + if err = r.helper.handleDeviceClass(ctx, devConfig); err != nil { + return res, fmt.Errorf("failed to handle DeviceClass for DeviceConfig %s: %v", req.NamespacedName, err) + } + logger.Info("start dra-driver reconciliation") if err = r.helper.handleDRADriver(ctx, devConfig, nodes); err != nil { return res, fmt.Errorf("failed to handle dra-driver for DeviceConfig %s: %v", req.NamespacedName, err) @@ -374,6 +389,7 @@ type deviceConfigReconcilerHelperAPI interface { setFinalizer(ctx context.Context, devConfig *amdv1alpha1.DeviceConfig) error handleKMMModule(ctx context.Context, devConfig *amdv1alpha1.DeviceConfig, nodes *v1.NodeList) error handleDevicePlugin(ctx context.Context, devConfig *amdv1alpha1.DeviceConfig, nodes *v1.NodeList) error + handleDeviceClass(ctx context.Context, devConfig *amdv1alpha1.DeviceConfig) error handleDRADriver(ctx context.Context, devConfig *amdv1alpha1.DeviceConfig, nodes *v1.NodeList) error handleKMMVersionLabel(ctx context.Context, devConfig *amdv1alpha1.DeviceConfig, nodes *v1.NodeList) error handleBuildConfigMap(ctx context.Context, devConfig *amdv1alpha1.DeviceConfig, nodes *v1.NodeList) error @@ -392,6 +408,7 @@ type deviceConfigReconcilerHelperAPI interface { type deviceConfigReconcilerHelper struct { client client.Client kmmWatchEnabled bool + isOpenShift bool kmmHandler kmmmodule.KMMModuleAPI devicePluginHandler plugin.DevicePluginAPI nlHandler nodelabeller.NodeLabeller @@ -418,12 +435,14 @@ func newDeviceConfigReconcilerHelper(client client.Client, testrunnerHandler testrunner.TestRunner, configmanagerHandler configmanager.ConfigManager, workerMgr workermgr.WorkerMgrAPI, + isOpenShift bool, kmmWatchEnabled bool) deviceConfigReconcilerHelperAPI { conditionUpdater := conditions.NewDeviceConfigConditionMgr() validator := validator.NewValidator() return &deviceConfigReconcilerHelper{ client: client, kmmWatchEnabled: kmmWatchEnabled, + isOpenShift: isOpenShift, kmmHandler: kmmHandler, devicePluginHandler: dpHandler, nlHandler: nlHandler, @@ -1234,6 +1253,44 @@ func (dcrh *deviceConfigReconcilerHelper) handleDRADriver(ctx context.Context, d return nil } +func (dcrh *deviceConfigReconcilerHelper) handleDeviceClass(ctx context.Context, devConfig *amdv1alpha1.DeviceConfig) error { + if !dcrh.isOpenShift { + return nil + } + if !devConfig.Spec.DRADriver.IsEnabled() { + return nil + } + + logger := log.FromContext(ctx) + + dc := &unstructured.Unstructured{} + dc.SetGroupVersionKind(draDeviceClassGVK) + dc.SetName(deviceClassName) + dc.SetLabels(map[string]string{ + "app.kubernetes.io/component": "amd-gpu", + "app.kubernetes.io/part-of": "amd-gpu", + }) + dc.Object["spec"] = map[string]interface{}{ + "selectors": []interface{}{ + map[string]interface{}{ + "cel": map[string]interface{}{ + "expression": "device.driver == '" + deviceClassName + "'", + }, + }, + }, + } + + if err := dcrh.client.Create(ctx, dc); err != nil { + if k8serrors.IsAlreadyExists(err) { + return nil + } + return fmt.Errorf("failed to create DeviceClass %s: %v", deviceClassName, err) + } + + logger.Info("Created DeviceClass", "name", deviceClassName) + return nil +} + func (dcrh *deviceConfigReconcilerHelper) handleKMMVersionLabel(ctx context.Context, devConfig *amdv1alpha1.DeviceConfig, nodes *v1.NodeList) error { // label corresponding node with given kmod version // so that KMM could manage the upgrade by watching the node's version label change diff --git a/internal/controllers/device_config_reconciler_test.go b/internal/controllers/device_config_reconciler_test.go index 754f87466..846b1418d 100644 --- a/internal/controllers/device_config_reconciler_test.go +++ b/internal/controllers/device_config_reconciler_test.go @@ -197,7 +197,7 @@ var _ = Describe("getLabelsPerModules", func() { BeforeEach(func() { ctrl := gomock.NewController(GinkgoT()) kubeClient = mock_client.NewMockClient(ctrl) - dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, true) + dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, false, true) }) ctx := context.Background() @@ -241,7 +241,7 @@ var _ = Describe("deviceConfigReconcilerHelper with KMM watch disabled", func() BeforeEach(func() { ctrl := gomock.NewController(GinkgoT()) kubeClient = mock_client.NewMockClient(ctrl) - dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, true) + dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, false, true) }) ctx := context.Background() nn := types.NamespacedName{ @@ -282,7 +282,7 @@ var _ = Describe("setFinalizer", func() { BeforeEach(func() { ctrl := gomock.NewController(GinkgoT()) kubeClient = mock_client.NewMockClient(ctrl) - dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, true) + dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, false, true) }) ctx := context.Background() @@ -318,7 +318,7 @@ var _ = Describe("finalizeDeviceConfig", func() { BeforeEach(func() { ctrl := gomock.NewController(GinkgoT()) kubeClient = mock_client.NewMockClient(ctrl) - dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, true) + dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, false, true) }) ctx := context.Background() @@ -539,7 +539,7 @@ var _ = Describe("handleKMMModule", func() { ctrl := gomock.NewController(GinkgoT()) kubeClient = mock_client.NewMockClient(ctrl) kmmHelper = kmmmodule.NewMockKMMModuleAPI(ctrl) - dcrh = newDeviceConfigReconcilerHelper(kubeClient, kmmHelper, nil, nil, nil, nil, nil, nil, nil, nil, true) + dcrh = newDeviceConfigReconcilerHelper(kubeClient, kmmHelper, nil, nil, nil, nil, nil, nil, nil, nil, false, true) }) ctx := context.Background() @@ -609,7 +609,7 @@ var _ = Describe("handleBuildConfigMap", func() { ctrl := gomock.NewController(GinkgoT()) kubeClient = mock_client.NewMockClient(ctrl) kmmHelper = kmmmodule.NewMockKMMModuleAPI(ctrl) - dcrh = newDeviceConfigReconcilerHelper(kubeClient, kmmHelper, nil, nil, nil, nil, nil, nil, nil, nil, true) + dcrh = newDeviceConfigReconcilerHelper(kubeClient, kmmHelper, nil, nil, nil, nil, nil, nil, nil, nil, false, true) }) ctx := context.Background() @@ -676,7 +676,7 @@ var _ = Describe("handleNodeLabeller", func() { ctrl := gomock.NewController(GinkgoT()) kubeClient = mock_client.NewMockClient(ctrl) nodeLabellerHelper = nodelabeller.NewMockNodeLabeller(ctrl) - dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nodeLabellerHelper, nil, nil, nil, nil, nil, nil, true) + dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nodeLabellerHelper, nil, nil, nil, nil, nil, nil, false, true) }) ctx := context.Background() @@ -762,7 +762,7 @@ var _ = Describe("buildNodeAssignments", func() { BeforeEach(func() { ctrl := gomock.NewController(GinkgoT()) kubeClient := mock_client.NewMockClient(ctrl) - dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, true) + dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, false, true) }) It("skips non-ready DeviceConfigs", func() { @@ -821,3 +821,81 @@ var _ = Describe("buildNodeAssignments", func() { Expect(err).ToNot(HaveOccurred()) }) }) + +var _ = Describe("handleDeviceClass", func() { + var ( + kubeClient *mock_client.MockClient + dcrh deviceConfigReconcilerHelperAPI + ) + + ctx := context.Background() + draEnabled := true + draDisabled := false + + draEnabledConfig := &amdv1alpha1.DeviceConfig{ + ObjectMeta: metav1.ObjectMeta{Name: devConfigName, Namespace: devConfigNamespace}, + Spec: amdv1alpha1.DeviceConfigSpec{ + DRADriver: amdv1alpha1.DRADriverSpec{Enable: &draEnabled}, + }, + } + + draDisabledConfig := &amdv1alpha1.DeviceConfig{ + ObjectMeta: metav1.ObjectMeta{Name: devConfigName, Namespace: devConfigNamespace}, + Spec: amdv1alpha1.DeviceConfigSpec{ + DRADriver: amdv1alpha1.DRADriverSpec{Enable: &draDisabled}, + }, + } + + It("should skip when not on OpenShift", func() { + ctrl := gomock.NewController(GinkgoT()) + kubeClient = mock_client.NewMockClient(ctrl) + dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, false, true) + + err := dcrh.handleDeviceClass(ctx, draEnabledConfig) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should skip when DRA driver is not enabled", func() { + ctrl := gomock.NewController(GinkgoT()) + kubeClient = mock_client.NewMockClient(ctrl) + dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, true, true) + + err := dcrh.handleDeviceClass(ctx, draDisabledConfig) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should create DeviceClass when it does not exist", func() { + ctrl := gomock.NewController(GinkgoT()) + kubeClient = mock_client.NewMockClient(ctrl) + dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, true, true) + + kubeClient.EXPECT().Create(ctx, gomock.Any()).Return(nil) + + err := dcrh.handleDeviceClass(ctx, draEnabledConfig) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should succeed when DeviceClass already exists", func() { + ctrl := gomock.NewController(GinkgoT()) + kubeClient = mock_client.NewMockClient(ctrl) + dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, true, true) + + kubeClient.EXPECT().Create(ctx, gomock.Any()).Return( + k8serrors.NewAlreadyExists(schema.GroupResource{Group: "resource.k8s.io", Resource: "deviceclasses"}, "gpu.amd.com"), + ) + + err := dcrh.handleDeviceClass(ctx, draEnabledConfig) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should return error when Create fails", func() { + ctrl := gomock.NewController(GinkgoT()) + kubeClient = mock_client.NewMockClient(ctrl) + dcrh = newDeviceConfigReconcilerHelper(kubeClient, nil, nil, nil, nil, nil, nil, nil, nil, nil, true, true) + + kubeClient.EXPECT().Create(ctx, gomock.Any()).Return(fmt.Errorf("server error")) + + err := dcrh.handleDeviceClass(ctx, draEnabledConfig) + Expect(err).To(HaveOccurred()) + }) +}) diff --git a/internal/controllers/mock_device_config_reconciler.go b/internal/controllers/mock_device_config_reconciler.go index 79667bcd5..998b6aa89 100644 --- a/internal/controllers/mock_device_config_reconciler.go +++ b/internal/controllers/mock_device_config_reconciler.go @@ -218,6 +218,20 @@ func (mr *MockdeviceConfigReconcilerHelperAPIMockRecorder) handleDRADriver(ctx, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleDRADriver", reflect.TypeOf((*MockdeviceConfigReconcilerHelperAPI)(nil).handleDRADriver), ctx, devConfig, nodes) } +// handleDeviceClass mocks base method. +func (m *MockdeviceConfigReconcilerHelperAPI) handleDeviceClass(ctx context.Context, devConfig *v1alpha1.DeviceConfig) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "handleDeviceClass", ctx, devConfig) + ret0, _ := ret[0].(error) + return ret0 +} + +// handleDeviceClass indicates an expected call of handleDeviceClass. +func (mr *MockdeviceConfigReconcilerHelperAPIMockRecorder) handleDeviceClass(ctx, devConfig any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleDeviceClass", reflect.TypeOf((*MockdeviceConfigReconcilerHelperAPI)(nil).handleDeviceClass), ctx, devConfig) +} + // handleDevicePlugin mocks base method. func (m *MockdeviceConfigReconcilerHelperAPI) handleDevicePlugin(ctx context.Context, devConfig *v1alpha1.DeviceConfig, nodes *v1.NodeList) error { m.ctrl.T.Helper()