diff --git a/controllers/upgrade_controller.go b/controllers/upgrade_controller.go index 7481239d1..e82e2de35 100644 --- a/controllers/upgrade_controller.go +++ b/controllers/upgrade_controller.go @@ -36,6 +36,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/reconcile" "sigs.k8s.io/controller-runtime/pkg/source" + upgrade_v1alpha1 "github.com/NVIDIA/k8s-operator-libs/api/upgrade/v1alpha1" "github.com/NVIDIA/k8s-operator-libs/pkg/consts" "github.com/NVIDIA/k8s-operator-libs/pkg/upgrade" "github.com/go-logr/logr" @@ -169,6 +170,9 @@ func (r *UpgradeReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct // if the operator is evicted and can't be rescheduled to any other node, e.g. in a single-node cluster. // It's safe to do because the goal of the node draining during the upgrade is to // evict pods that might use driver and operator doesn't use in its own pod. + if clusterPolicy.Spec.Driver.UpgradePolicy.DrainSpec == nil { + clusterPolicy.Spec.Driver.UpgradePolicy.DrainSpec = &upgrade_v1alpha1.DrainSpec{} + } if clusterPolicy.Spec.Driver.UpgradePolicy.DrainSpec.PodSelector == "" { clusterPolicy.Spec.Driver.UpgradePolicy.DrainSpec.PodSelector = UpgradeSkipDrainLabelSelector } else { diff --git a/controllers/upgrade_controller_test.go b/controllers/upgrade_controller_test.go new file mode 100644 index 000000000..3b72da082 --- /dev/null +++ b/controllers/upgrade_controller_test.go @@ -0,0 +1,71 @@ +/* + * Copyright (c) NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package controllers + +import ( + "fmt" + "testing" + + upgrade_v1alpha1 "github.com/NVIDIA/k8s-operator-libs/api/upgrade/v1alpha1" + "github.com/stretchr/testify/assert" +) + +func TestSetDrainSpecPodSelector(t *testing.T) { + tests := []struct { + name string + drainSpec *upgrade_v1alpha1.DrainSpec + expectedSelector string + }{ + { + name: "nil DrainSpec should be initialized with default PodSelector", + drainSpec: nil, + expectedSelector: UpgradeSkipDrainLabelSelector, + }, + { + name: "empty PodSelector should be set to default", + drainSpec: &upgrade_v1alpha1.DrainSpec{}, + expectedSelector: UpgradeSkipDrainLabelSelector, + }, + { + name: "existing PodSelector should be appended", + drainSpec: &upgrade_v1alpha1.DrainSpec{PodSelector: "app=myapp"}, + expectedSelector: fmt.Sprintf("app=myapp,%s", UpgradeSkipDrainLabelSelector), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + upgradePolicy := &upgrade_v1alpha1.DriverUpgradePolicySpec{ + AutoUpgrade: true, + DrainSpec: tt.drainSpec, + } + + if upgradePolicy.DrainSpec == nil { + upgradePolicy.DrainSpec = &upgrade_v1alpha1.DrainSpec{} + } + if upgradePolicy.DrainSpec.PodSelector == "" { + upgradePolicy.DrainSpec.PodSelector = UpgradeSkipDrainLabelSelector + } else { + upgradePolicy.DrainSpec.PodSelector = + fmt.Sprintf("%s,%s", upgradePolicy.DrainSpec.PodSelector, UpgradeSkipDrainLabelSelector) + } + + assert.NotNil(t, upgradePolicy.DrainSpec) + assert.Equal(t, tt.expectedSelector, upgradePolicy.DrainSpec.PodSelector) + }) + } +}