From 737bf16043f3d32c25148c56094197449bacdfe0 Mon Sep 17 00:00:00 2001 From: Illia Litovchenko Date: Thu, 19 Feb 2026 10:12:40 +0000 Subject: [PATCH 1/3] Mocks for k8s client/node drainer --- internal/k8s/kubernetes.go | 43 +++++--- internal/k8s/mock/kubernetes.go | 170 ++++++++++++++++++++++++++++++ internal/nodes/mock/kubernetes.go | 67 ++++++++++++ internal/nodes/node_drainer.go | 14 +-- 4 files changed, 276 insertions(+), 18 deletions(-) create mode 100644 internal/k8s/mock/kubernetes.go create mode 100644 internal/nodes/mock/kubernetes.go diff --git a/internal/k8s/kubernetes.go b/internal/k8s/kubernetes.go index d677bbe9..b8d5ac65 100644 --- a/internal/k8s/kubernetes.go +++ b/internal/k8s/kubernetes.go @@ -1,3 +1,5 @@ +//go:generate mockgen -destination ./mock/kubernetes.go . Client + package k8s import ( @@ -44,32 +46,49 @@ const ( DefaultMaxRetriesK8SOperation = 5 ) +type Client interface { + PatchNode(ctx context.Context, node *v1.Node, changeFn func(*v1.Node)) error + PatchNodeStatus(ctx context.Context, name string, patch []byte) error + EvictPod(ctx context.Context, pod v1.Pod, podEvictRetryDelay time.Duration, version schema.GroupVersion) error + CordonNode(ctx context.Context, node *v1.Node) error + GetNodeByIDs(ctx context.Context, nodeName, nodeID, providerID string) (*v1.Node, error) + ExecuteBatchPodActions( + ctx context.Context, + pods []*v1.Pod, + action func(context.Context, v1.Pod) error, + actionName string, + ) ([]*v1.Pod, []PodActionFailure) + DeletePod(ctx context.Context, options metav1.DeleteOptions, pod v1.Pod, podDeleteRetries int, podDeleteRetryDelay time.Duration) error + Clientset() kubernetes.Interface + Log() logrus.FieldLogger +} + // Client provides Kubernetes operations with common dependencies. -type Client struct { +type client struct { clientset kubernetes.Interface log logrus.FieldLogger } // NewClient creates a new K8s client with the given dependencies. -func NewClient(clientset kubernetes.Interface, log logrus.FieldLogger) *Client { - return &Client{ +func NewClient(clientset kubernetes.Interface, log logrus.FieldLogger) Client { + return &client{ clientset: clientset, log: log, } } // Clientset returns the underlying kubernetes.Interface. -func (c *Client) Clientset() kubernetes.Interface { +func (c *client) Clientset() kubernetes.Interface { return c.clientset } // Log returns the logger. -func (c *Client) Log() logrus.FieldLogger { +func (c *client) Log() logrus.FieldLogger { return c.log } // PatchNode patches a node with the given change function. -func (c *Client) PatchNode(ctx context.Context, node *v1.Node, changeFn func(*v1.Node)) error { +func (c *client) PatchNode(ctx context.Context, node *v1.Node, changeFn func(*v1.Node)) error { logger := logger.FromContext(ctx, c.log) oldData, err := json.Marshal(node) if err != nil { @@ -108,7 +127,7 @@ func (c *Client) PatchNode(ctx context.Context, node *v1.Node, changeFn func(*v1 } // PatchNodeStatus patches the status of a node. 
-func (c *Client) PatchNodeStatus(ctx context.Context, name string, patch []byte) error { +func (c *client) PatchNodeStatus(ctx context.Context, name string, patch []byte) error { logger := logger.FromContext(ctx, c.log) err := waitext.Retry( @@ -134,7 +153,7 @@ func (c *Client) PatchNodeStatus(ctx context.Context, name string, patch []byte) return nil } -func (c *Client) CordonNode(ctx context.Context, node *v1.Node) error { +func (c *client) CordonNode(ctx context.Context, node *v1.Node) error { if node.Spec.Unschedulable { return nil } @@ -149,7 +168,7 @@ func (c *Client) CordonNode(ctx context.Context, node *v1.Node) error { } // GetNodeByIDs retrieves a node by name and validates its ID and provider ID. -func (c *Client) GetNodeByIDs(ctx context.Context, nodeName, nodeID, providerID string) (*v1.Node, error) { +func (c *client) GetNodeByIDs(ctx context.Context, nodeName, nodeID, providerID string) (*v1.Node, error) { if nodeID == "" && providerID == "" { return nil, fmt.Errorf("node and provider IDs are empty %w", ErrAction) } @@ -178,7 +197,7 @@ func (c *Client) GetNodeByIDs(ctx context.Context, nodeName, nodeID, providerID // It does internal throttling to avoid spawning a goroutine-per-pod on large lists. // Returns two sets of pods - the ones that successfully executed the action and the ones that failed. // actionName might be used to distinguish what is the operation (for logs, debugging, etc.) but is optional. -func (c *Client) ExecuteBatchPodActions( +func (c *client) ExecuteBatchPodActions( ctx context.Context, pods []*v1.Pod, action func(context.Context, v1.Pod) error, @@ -250,7 +269,7 @@ func (c *Client) ExecuteBatchPodActions( // EvictPod evicts a pod from a k8s node. Error handling is based on eviction api documentation: // https://kubernetes.io/docs/tasks/administer-cluster/safely-drain-node/#the-eviction-api -func (c *Client) EvictPod(ctx context.Context, pod v1.Pod, podEvictRetryDelay time.Duration, version schema.GroupVersion) error { +func (c *client) EvictPod(ctx context.Context, pod v1.Pod, podEvictRetryDelay time.Duration, version schema.GroupVersion) error { logger := logger.FromContext(ctx, c.log) b := waitext.NewConstantBackoff(podEvictRetryDelay) @@ -306,7 +325,7 @@ func (c *Client) EvictPod(ctx context.Context, pod v1.Pod, podEvictRetryDelay ti } // DeletePod deletes a pod from the cluster. -func (c *Client) DeletePod(ctx context.Context, options metav1.DeleteOptions, pod v1.Pod, podDeleteRetries int, podDeleteRetryDelay time.Duration) error { +func (c *client) DeletePod(ctx context.Context, options metav1.DeleteOptions, pod v1.Pod, podDeleteRetries int, podDeleteRetryDelay time.Duration) error { logger := logger.FromContext(ctx, c.log) b := waitext.NewConstantBackoff(podDeleteRetryDelay) diff --git a/internal/k8s/mock/kubernetes.go b/internal/k8s/mock/kubernetes.go new file mode 100644 index 00000000..a5c2f3dd --- /dev/null +++ b/internal/k8s/mock/kubernetes.go @@ -0,0 +1,170 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/castai/cluster-controller/internal/k8s (interfaces: Client) + +// Package mock_k8s is a generated GoMock package. 
+package mock_k8s + +import ( + context "context" + reflect "reflect" + time "time" + + k8s "github.com/castai/cluster-controller/internal/k8s" + gomock "github.com/golang/mock/gomock" + logrus "github.com/sirupsen/logrus" + v1 "k8s.io/api/core/v1" + v10 "k8s.io/apimachinery/pkg/apis/meta/v1" + schema "k8s.io/apimachinery/pkg/runtime/schema" + kubernetes "k8s.io/client-go/kubernetes" +) + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// Clientset mocks base method. +func (m *MockClient) Clientset() kubernetes.Interface { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Clientset") + ret0, _ := ret[0].(kubernetes.Interface) + return ret0 +} + +// Clientset indicates an expected call of Clientset. +func (mr *MockClientMockRecorder) Clientset() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clientset", reflect.TypeOf((*MockClient)(nil).Clientset)) +} + +// CordonNode mocks base method. +func (m *MockClient) CordonNode(arg0 context.Context, arg1 *v1.Node) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CordonNode", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// CordonNode indicates an expected call of CordonNode. +func (mr *MockClientMockRecorder) CordonNode(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CordonNode", reflect.TypeOf((*MockClient)(nil).CordonNode), arg0, arg1) +} + +// DeletePod mocks base method. +func (m *MockClient) DeletePod(arg0 context.Context, arg1 v10.DeleteOptions, arg2 v1.Pod, arg3 int, arg4 time.Duration) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePod", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePod indicates an expected call of DeletePod. +func (mr *MockClientMockRecorder) DeletePod(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePod", reflect.TypeOf((*MockClient)(nil).DeletePod), arg0, arg1, arg2, arg3, arg4) +} + +// EvictPod mocks base method. +func (m *MockClient) EvictPod(arg0 context.Context, arg1 v1.Pod, arg2 time.Duration, arg3 schema.GroupVersion) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EvictPod", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// EvictPod indicates an expected call of EvictPod. +func (mr *MockClientMockRecorder) EvictPod(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EvictPod", reflect.TypeOf((*MockClient)(nil).EvictPod), arg0, arg1, arg2, arg3) +} + +// ExecuteBatchPodActions mocks base method. 
+func (m *MockClient) ExecuteBatchPodActions(arg0 context.Context, arg1 []*v1.Pod, arg2 func(context.Context, v1.Pod) error, arg3 string) ([]*v1.Pod, []k8s.PodActionFailure) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExecuteBatchPodActions", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]*v1.Pod) + ret1, _ := ret[1].([]k8s.PodActionFailure) + return ret0, ret1 +} + +// ExecuteBatchPodActions indicates an expected call of ExecuteBatchPodActions. +func (mr *MockClientMockRecorder) ExecuteBatchPodActions(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteBatchPodActions", reflect.TypeOf((*MockClient)(nil).ExecuteBatchPodActions), arg0, arg1, arg2, arg3) +} + +// GetNodeByIDs mocks base method. +func (m *MockClient) GetNodeByIDs(arg0 context.Context, arg1, arg2, arg3 string) (*v1.Node, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNodeByIDs", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*v1.Node) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNodeByIDs indicates an expected call of GetNodeByIDs. +func (mr *MockClientMockRecorder) GetNodeByIDs(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNodeByIDs", reflect.TypeOf((*MockClient)(nil).GetNodeByIDs), arg0, arg1, arg2, arg3) +} + +// Log mocks base method. +func (m *MockClient) Log() logrus.FieldLogger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Log") + ret0, _ := ret[0].(logrus.FieldLogger) + return ret0 +} + +// Log indicates an expected call of Log. +func (mr *MockClientMockRecorder) Log() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Log", reflect.TypeOf((*MockClient)(nil).Log)) +} + +// PatchNode mocks base method. +func (m *MockClient) PatchNode(arg0 context.Context, arg1 *v1.Node, arg2 func(*v1.Node)) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PatchNode", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// PatchNode indicates an expected call of PatchNode. +func (mr *MockClientMockRecorder) PatchNode(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PatchNode", reflect.TypeOf((*MockClient)(nil).PatchNode), arg0, arg1, arg2) +} + +// PatchNodeStatus mocks base method. +func (m *MockClient) PatchNodeStatus(arg0 context.Context, arg1 string, arg2 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PatchNodeStatus", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// PatchNodeStatus indicates an expected call of PatchNodeStatus. +func (mr *MockClientMockRecorder) PatchNodeStatus(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PatchNodeStatus", reflect.TypeOf((*MockClient)(nil).PatchNodeStatus), arg0, arg1, arg2) +} diff --git a/internal/nodes/mock/kubernetes.go b/internal/nodes/mock/kubernetes.go new file mode 100644 index 00000000..5a739075 --- /dev/null +++ b/internal/nodes/mock/kubernetes.go @@ -0,0 +1,67 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/castai/cluster-controller/internal/nodes (interfaces: Drainer) + +// Package mock_nodes is a generated GoMock package. 
+package mock_nodes + +import ( + context "context" + reflect "reflect" + + nodes "github.com/castai/cluster-controller/internal/nodes" + gomock "github.com/golang/mock/gomock" + v1 "k8s.io/api/core/v1" +) + +// MockDrainer is a mock of Drainer interface. +type MockDrainer struct { + ctrl *gomock.Controller + recorder *MockDrainerMockRecorder +} + +// MockDrainerMockRecorder is the mock recorder for MockDrainer. +type MockDrainerMockRecorder struct { + mock *MockDrainer +} + +// NewMockDrainer creates a new mock instance. +func NewMockDrainer(ctrl *gomock.Controller) *MockDrainer { + mock := &MockDrainer{ctrl: ctrl} + mock.recorder = &MockDrainerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDrainer) EXPECT() *MockDrainerMockRecorder { + return m.recorder +} + +// Drain mocks base method. +func (m *MockDrainer) Drain(arg0 context.Context, arg1 nodes.DrainRequest) ([]*v1.Pod, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Drain", arg0, arg1) + ret0, _ := ret[0].([]*v1.Pod) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Drain indicates an expected call of Drain. +func (mr *MockDrainerMockRecorder) Drain(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Drain", reflect.TypeOf((*MockDrainer)(nil).Drain), arg0, arg1) +} + +// Evict mocks base method. +func (m *MockDrainer) Evict(arg0 context.Context, arg1 nodes.EvictRequest) ([]*v1.Pod, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Evict", arg0, arg1) + ret0, _ := ret[0].([]*v1.Pod) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Evict indicates an expected call of Evict. +func (mr *MockDrainerMockRecorder) Evict(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Evict", reflect.TypeOf((*MockDrainer)(nil).Evict), arg0, arg1) +} diff --git a/internal/nodes/node_drainer.go b/internal/nodes/node_drainer.go index 55b888f4..2bf4029a 100644 --- a/internal/nodes/node_drainer.go +++ b/internal/nodes/node_drainer.go @@ -1,3 +1,5 @@ +//go:generate mockgen -destination ./mock/kubernetes.go . 
Drainer + package nodes import ( @@ -44,14 +46,14 @@ type DrainerConfig struct { type drainer struct { pods informer.PodInformer - client *k8s.Client + client k8s.Client cfg DrainerConfig log logrus.FieldLogger } func NewDrainer( pods informer.PodInformer, - client *k8s.Client, + client k8s.Client, log logrus.FieldLogger, cfg DrainerConfig, ) Drainer { @@ -75,7 +77,7 @@ func (d *drainer) Drain(ctx context.Context, data DrainRequest) ([]*core.Pod, er toEvict := d.prioritizePods(pods, data.CastNamespace, data.SkipDeletedTimeoutSeconds) if len(toEvict) == 0 { - return []*core.Pod{}, nil + return nil, nil } _, failed, err := d.tryDrain(ctx, toEvict, data.DeleteOptions) @@ -85,7 +87,7 @@ func (d *drainer) Drain(ctx context.Context, data DrainRequest) ([]*core.Pod, er err = d.waitTerminaition(ctx, data.Node, failed) if err != nil { - return []*core.Pod{}, err + return nil, err } logger.Info("drain finished") @@ -115,7 +117,7 @@ func (d *drainer) Evict(ctx context.Context, data EvictRequest) ([]*core.Pod, er toEvict := d.prioritizePods(pods, data.CastNamespace, data.SkipDeletedTimeoutSeconds) if len(toEvict) == 0 { - return []*core.Pod{}, nil + return nil, nil } _, ignored, err := d.tryEvict(ctx, toEvict) @@ -125,7 +127,7 @@ func (d *drainer) Evict(ctx context.Context, data EvictRequest) ([]*core.Pod, er err = d.waitTerminaition(ctx, data.Node, ignored) if err != nil { - return []*core.Pod{}, err + return nil, err } logger.Info("eviction finished") From 02b6cdbb9023a0a828b71ba0be4683836aff5451 Mon Sep 17 00:00:00 2001 From: Illia Litovchenko Date: Thu, 19 Feb 2026 11:50:14 +0000 Subject: [PATCH 2/3] drain node informer added --- .../actions/drain_node_informers_handler.go | 313 +++++++++++ .../drain_node_informers_handler_test.go | 501 ++++++++++++++++++ 2 files changed, 814 insertions(+) create mode 100644 internal/actions/drain_node_informers_handler.go create mode 100644 internal/actions/drain_node_informers_handler_test.go diff --git a/internal/actions/drain_node_informers_handler.go b/internal/actions/drain_node_informers_handler.go new file mode 100644 index 00000000..f97af726 --- /dev/null +++ b/internal/actions/drain_node_informers_handler.go @@ -0,0 +1,313 @@ +package actions + +import ( + "context" + "errors" + "fmt" + "reflect" + "time" + + "github.com/sirupsen/logrus" + v1 "k8s.io/api/core/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + + "github.com/castai/cluster-controller/internal/castai" + "github.com/castai/cluster-controller/internal/informer" + "github.com/castai/cluster-controller/internal/k8s" + "github.com/castai/cluster-controller/internal/logger" + "github.com/castai/cluster-controller/internal/nodes" + "github.com/castai/cluster-controller/internal/volume" +) + +var _ ActionHandler = &DrainNodeInfomerHandler{} + +const ( + defaultPodsDeleteTimeout = 2 * time.Minute + defaultPodDeleteRetries = 5 + defaultPodDeleteRetryDelay = 5 * time.Second + defaultPodEvictRetryDelay = 5 * time.Second + defaultPodsTerminationWaitRetryDelay = 10 * time.Second + defaultSkipDeletedTimeoutSeconds = 60 +) + +func newDefaultDrainNodeConfig(castNamespace string) drainNodeConfig { + return drainNodeConfig{ + podsDeleteTimeout: defaultPodsDeleteTimeout, + podDeleteRetries: defaultPodDeleteRetries, + podDeleteRetryDelay: defaultPodDeleteRetryDelay, + podEvictRetryDelay: defaultPodEvictRetryDelay, + podsTerminationWaitRetryDelay: defaultPodsTerminationWaitRetryDelay, + castNamespace: castNamespace, + 
skipDeletedTimeoutSeconds: defaultSkipDeletedTimeoutSeconds, + } +} + +func NewDrainNodeInformerHandler( + log logrus.FieldLogger, + clientset kubernetes.Interface, + castNamespace string, + vaWaiter volume.DetachmentWaiter, + podInformer informer.PodInformer, + nodeInformer informer.NodeInformer, +) *DrainNodeInfomerHandler { + client := k8s.NewClient(clientset, log) + nodeManager := nodes.NewDrainer(podInformer, client, log, nodes.DrainerConfig{ + PodEvictRetryDelay: defaultPodEvictRetryDelay, + PodsTerminationWaitRetryDelay: defaultPodDeleteRetryDelay, + PodDeleteRetries: defaultPodDeleteRetries, + }) + + return &DrainNodeInfomerHandler{ + log: log, + vaWaiter: vaWaiter, + cfg: newDefaultDrainNodeConfig(castNamespace), + nodeManager: nodeManager, + nodeInformer: nodeInformer, + client: client, + } +} + +type DrainNodeInfomerHandler struct { + log logrus.FieldLogger + vaWaiter volume.DetachmentWaiter + cfg drainNodeConfig + nodeManager nodes.Drainer + nodeInformer informer.NodeInformer + client k8s.Client +} + +func (h *DrainNodeInfomerHandler) Handle(ctx context.Context, action *castai.ClusterAction) error { + req, err := h.validateAction(action) + if err != nil { + return err + } + + log := h.createDrainNodeLogger(action, req) + log.Info("draining kubernetes node") + + ctx = logger.WithLogger(ctx, log) + drainTimeout := k8s.GetDrainTimeout(action) + + node, err := h.getAndValidateNode(ctx, req) + if err != nil { + return err + } + if node == nil { + return nil + } + + if err = h.cordonNode(ctx, node); err != nil { + return err + } + + log.Infof("draining node, drain_timeout_seconds=%f, force=%v created_at=%s", drainTimeout.Seconds(), req.Force, action.CreatedAt) + + return h.drainNode(ctx, node.Name, req, drainTimeout) +} + +func (h *DrainNodeInfomerHandler) drainNode(ctx context.Context, nodeName string, req *castai.ActionDrainNode, drainTimeout time.Duration) error { + log := logger.FromContext(ctx, h.log) + + nonEvictablePods, err := h.tryEviction(ctx, nodeName, drainTimeout) + if err == nil { + log.Info("node fully drained via graceful eviction") + h.waitForVolumeDetachIfEnabled(ctx, nodeName, req, nonEvictablePods) + return nil + } + + if !req.Force { + return fmt.Errorf("node failed to drain via graceful eviction, force=%v, timeout=%f, will not force delete pods: %w", req.Force, drainTimeout.Seconds(), err) + } + + if !h.shouldForceDrain(ctx, err, drainTimeout, req.Force) { + return fmt.Errorf("evicting node pods: %w", err) + } + + nonEvictablePods, drainErr := h.forceDrain(ctx, nodeName) + if drainErr == nil { + log.Info("node drained forcefully") + h.waitForVolumeDetachIfEnabled(ctx, nodeName, req, nonEvictablePods) + } else { + log.Warnf("node failed to fully force drain: %v", drainErr) + } + + return drainErr +} + +func (h *DrainNodeInfomerHandler) validateAction(action *castai.ClusterAction) (*castai.ActionDrainNode, error) { + if action == nil { + return nil, fmt.Errorf("action is nil %w", k8s.ErrAction) + } + + req, ok := action.Data().(*castai.ActionDrainNode) + if !ok { + return nil, newUnexpectedTypeErr(action.Data(), req) + } + + if req.NodeName == "" || (req.NodeID == "" && req.ProviderId == "") { + return nil, fmt.Errorf("node name or node ID/provider ID is empty %w", k8s.ErrAction) + } + + return req, nil +} + +func (h *DrainNodeInfomerHandler) createDrainNodeLogger(action *castai.ClusterAction, req *castai.ActionDrainNode) logrus.FieldLogger { + return h.log.WithFields(logrus.Fields{ + "node_name": req.NodeName, + "node_id": req.NodeID, + "provider_id": 
req.ProviderId, + "action": reflect.TypeOf(action.Data().(*castai.ActionDrainNode)).String(), + ActionIDLogField: action.ID, + }) +} + +func (h *DrainNodeInfomerHandler) getAndValidateNode(ctx context.Context, req *castai.ActionDrainNode) (*v1.Node, error) { + log := logger.FromContext(ctx, h.log) + + // Try to get node from informer cache first + node, err := h.nodeInformer.Get(req.NodeName) + if err != nil { + if k8serrors.IsNotFound(err) { + // Fallback to API if not in cache + return h.getNodeFromAPI(ctx, req) + } + return nil, err + } + + if node == nil { + log.Info("node not found, skipping draining") + return nil, nil + } + + if err := k8s.IsNodeIDProviderIDValid(node, req.NodeID, req.ProviderId); err != nil { + if errors.Is(err, k8s.ErrNodeDoesNotMatch) { + log.Info("node does not match expected IDs, skipping draining") + return nil, nil + } + return nil, err + } + + return node, nil +} + +func (h *DrainNodeInfomerHandler) getNodeFromAPI(ctx context.Context, req *castai.ActionDrainNode) (*v1.Node, error) { + log := logger.FromContext(ctx, h.log) + log.Debug("node not found in cache, fetching directly from API") + + node, err := h.client.GetNodeByIDs(ctx, req.NodeName, req.NodeID, req.ProviderId) + if err != nil { + if errors.Is(err, k8s.ErrNodeNotFound) { + log.Info("node not found in API, skipping draining") + return nil, nil + } + if errors.Is(err, k8s.ErrNodeDoesNotMatch) { + log.Info("node does not match expected IDs, skipping draining") + return nil, nil + } + return nil, fmt.Errorf("failed to get node from API: %w", err) + } + + return node, nil +} + +func (h *DrainNodeInfomerHandler) cordonNode(ctx context.Context, node *v1.Node) error { + log := logger.FromContext(ctx, h.log) + log.Info("cordoning node for draining") + if err := h.client.CordonNode(ctx, node); err != nil { + return fmt.Errorf("cordoning node %q: %w", node.Name, err) + } + return nil +} + +func (h *DrainNodeInfomerHandler) tryEviction(ctx context.Context, nodeName string, timeout time.Duration) ([]*v1.Pod, error) { + evictCtx, evictCancel := context.WithTimeout(ctx, timeout) + defer evictCancel() + + return h.nodeManager.Evict(evictCtx, nodes.EvictRequest{ + Node: nodeName, + SkipDeletedTimeoutSeconds: h.cfg.skipDeletedTimeoutSeconds, + CastNamespace: h.cfg.castNamespace, + }) +} + +func (h *DrainNodeInfomerHandler) shouldForceDrain(ctx context.Context, evictionErr error, drainTimeout time.Duration, force bool) bool { + log := logger.FromContext(ctx, h.log) + + // Check if error is recoverable through force drain + var podsFailedEvictionErr *k8s.PodFailedActionError + + if errors.Is(evictionErr, context.DeadlineExceeded) { + log.Infof("eviction timeout=%f exceeded, force=%v, proceeding with force drain", drainTimeout.Seconds(), force) + return true + } + + if errors.As(evictionErr, &podsFailedEvictionErr) { + log.Infof("some pods failed eviction, force=%v, proceeding with force drain: %v", force, evictionErr) + return true + } + + // Unrecoverable errors (e.g., missing permissions, connectivity issues) + return false +} + +func (h *DrainNodeInfomerHandler) forceDrain(ctx context.Context, nodeName string) ([]*v1.Pod, error) { + deleteOptions := []metav1.DeleteOptions{ + {}, + *metav1.NewDeleteOptions(0), + } + + var nonEvictablePods []*v1.Pod + var lastErr error + + for _, opts := range deleteOptions { + deleteCtx, cancel := context.WithTimeout(ctx, h.cfg.podsDeleteTimeout) + defer cancel() + + nonEvictablePods, lastErr = h.nodeManager.Drain(deleteCtx, nodes.DrainRequest{ + Node: nodeName, + CastNamespace: 
h.cfg.castNamespace, + SkipDeletedTimeoutSeconds: h.cfg.skipDeletedTimeoutSeconds, + DeleteOptions: opts, + }) + + if lastErr == nil { + return nonEvictablePods, nil + } + + var podsFailedDeletionErr *k8s.PodFailedActionError + if errors.Is(lastErr, context.DeadlineExceeded) || errors.As(lastErr, &podsFailedDeletionErr) { + continue + } + + return nil, fmt.Errorf("forcefully deleting pods: %w", lastErr) + } + + return nonEvictablePods, lastErr +} + +// waitForVolumeDetachIfEnabled waits for VolumeAttachments to be deleted if the feature is enabled. +// This is called after successful drain to give CSI drivers time to clean up volumes. +// nonEvictablePods are pods that won't be evicted (DaemonSet, static) - their volumes are excluded from waiting. +func (h *DrainNodeInfomerHandler) waitForVolumeDetachIfEnabled(ctx context.Context, nodeName string, req *castai.ActionDrainNode, nonEvictablePods []*v1.Pod) { + if !ShouldWaitForVolumeDetach(req) || h.vaWaiter == nil { + return + } + + log := logger.FromContext(ctx, h.log) + + var timeout time.Duration + if req.VolumeDetachTimeoutSeconds != nil && *req.VolumeDetachTimeoutSeconds > 0 { + timeout = time.Duration(*req.VolumeDetachTimeoutSeconds) * time.Second + } + + err := h.vaWaiter.Wait(ctx, log, volume.DetachmentWaitOptions{ + NodeName: nodeName, + Timeout: timeout, + PodsToExclude: nonEvictablePods, + }) + if err != nil { + log.Warnf("error waiting for volume detachment: %v", err) + } +} diff --git a/internal/actions/drain_node_informers_handler_test.go b/internal/actions/drain_node_informers_handler_test.go new file mode 100644 index 00000000..bcbf0c3d --- /dev/null +++ b/internal/actions/drain_node_informers_handler_test.go @@ -0,0 +1,501 @@ +package actions + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + v1 "k8s.io/api/core/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/castai/cluster-controller/internal/castai" + "github.com/castai/cluster-controller/internal/informer" + "github.com/castai/cluster-controller/internal/k8s" + mock_k8s "github.com/castai/cluster-controller/internal/k8s/mock" + mock_nodes "github.com/castai/cluster-controller/internal/nodes/mock" + "github.com/castai/cluster-controller/internal/volume" +) + +// stubNodeInformer is a simple test implementation of informer.NodeInformer. 
+type stubNodeInformer struct { + node *v1.Node + err error +} + +func (s *stubNodeInformer) Get(_ string) (*v1.Node, error) { + return s.node, s.err +} + +func (s *stubNodeInformer) List() ([]*v1.Node, error) { + return nil, nil +} + +func (s *stubNodeInformer) Wait(_ context.Context, _ string, _ informer.Predicate) chan error { + return make(chan error, 1) +} + +func newDrainTestNode() *v1.Node { + return &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + Labels: map[string]string{ + castai.LabelNodeID: nodeID, + }, + }, + Spec: v1.NodeSpec{ + ProviderID: providerID, + }, + } +} + +// nolint +func newActionDrainNodeWithVolumeDetach(name, nID, pID string, drainTimeoutSeconds int, force bool) *castai.ClusterAction { + action := newActionDrainNode(name, nID, pID, drainTimeoutSeconds, force) + waitForVA := true + action.ActionDrainNode.WaitForVolumeDetach = &waitForVA + return action +} + +func TestDrainNodeInformerHandler_Handle(t *testing.T) { + t.Parallel() + + type setupFn func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) + + tests := []struct { + name string + action *castai.ClusterAction + cfg drainNodeConfig + nodeInformer informer.NodeInformer + vaWaiter volume.DetachmentWaiter + setup setupFn + wantErrIs error + wantErrorContains string + wantVolumeWait bool + }{ + { + name: "nil action returns error", + nodeInformer: &stubNodeInformer{}, + wantErrIs: k8s.ErrAction, + }, + { + name: "wrong action type returns error", + action: &castai.ClusterAction{ + ActionDeleteNode: &castai.ActionDeleteNode{}, + }, + nodeInformer: &stubNodeInformer{}, + wantErrIs: k8s.ErrAction, + }, + { + name: "empty node name returns error", + action: newActionDrainNode("", nodeID, providerID, 1, true), + nodeInformer: &stubNodeInformer{}, + wantErrIs: k8s.ErrAction, + }, + { + name: "empty node ID and provider ID returns error", + action: newActionDrainNode(nodeName, "", "", 1, true), + nodeInformer: &stubNodeInformer{}, + wantErrIs: k8s.ErrAction, + }, + { + name: "node found in informer cache but both IDs do not match, skip drain", + action: newActionDrainNode(nodeName, "another-node-id", "another-provider-id", 1, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + }, + { + name: "node found in informer cache, nodeID matches but providerID does not, skip drain", + action: newActionDrainNode(nodeName, nodeID, "another-provider-id", 1, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + }, + { + name: "node found in informer cache, providerID matches but nodeID does not, skip drain", + action: newActionDrainNode(nodeName, "another-node-id", providerID, 1, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + }, + { + name: "informer returns unexpected non-NotFound error, propagate error", + action: newActionDrainNode(nodeName, nodeID, providerID, 1, true), + nodeInformer: &stubNodeInformer{ + err: errors.New("internal informer error"), + }, + wantErrorContains: "internal informer error", + }, + { + name: "informer returns nil node with no error, skip drain", + action: newActionDrainNode(nodeName, nodeID, providerID, 1, true), + nodeInformer: &stubNodeInformer{}, + }, + { + name: "node not found in informer, fallback to API returns not found, skip drain", + action: newActionDrainNode(nodeName, nodeID, providerID, 1, true), + nodeInformer: &stubNodeInformer{ + err: k8serrors.NewNotFound(v1.Resource("nodes"), nodeName), + }, + setup: func(mockClient *mock_k8s.MockClient, _ *mock_nodes.MockDrainer) { + 
mockClient.EXPECT(). + GetNodeByIDs(gomock.Any(), nodeName, nodeID, providerID). + Return(nil, k8s.ErrNodeNotFound) + }, + }, + { + name: "node not found in informer, fallback to API IDs do not match, skip drain", + action: newActionDrainNode(nodeName, nodeID, providerID, 1, true), + nodeInformer: &stubNodeInformer{ + err: k8serrors.NewNotFound(v1.Resource("nodes"), nodeName), + }, + setup: func(mockClient *mock_k8s.MockClient, _ *mock_nodes.MockDrainer) { + mockClient.EXPECT(). + GetNodeByIDs(gomock.Any(), nodeName, nodeID, providerID). + Return(nil, k8s.ErrNodeDoesNotMatch) + }, + }, + { + name: "node not found in informer, API returns unexpected error, propagate error", + action: newActionDrainNode(nodeName, nodeID, providerID, 1, true), + nodeInformer: &stubNodeInformer{ + err: k8serrors.NewNotFound(v1.Resource("nodes"), nodeName), + }, + setup: func(mockClient *mock_k8s.MockClient, _ *mock_nodes.MockDrainer) { + mockClient.EXPECT(). + GetNodeByIDs(gomock.Any(), nodeName, nodeID, providerID). + Return(nil, errors.New("api connectivity error")) + }, + wantErrorContains: "failed to get node from API", + }, + { + name: "node not found in informer, found via API fallback, drain succeeds", + action: newActionDrainNode(nodeName, nodeID, providerID, 10, true), + nodeInformer: &stubNodeInformer{ + err: k8serrors.NewNotFound(v1.Resource("nodes"), nodeName), + }, + cfg: drainNodeConfig{ + podsDeleteTimeout: 10 * time.Second, + }, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT(). + GetNodeByIDs(gomock.Any(), nodeName, nodeID, providerID). + Return(newDrainTestNode(), nil) + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()).Return(nil, nil) + }, + }, + { + name: "cordon node fails, return error", + action: newActionDrainNode(nodeName, nodeID, providerID, 10, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + setup: func(mockClient *mock_k8s.MockClient, _ *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()). 
+ Return(errors.New("cordon failed")) + }, + wantErrorContains: "cordoning node", + }, + { + name: "drain node successfully via eviction", + action: newActionDrainNode(nodeName, nodeID, providerID, 10, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + cfg: drainNodeConfig{ + podsDeleteTimeout: 10 * time.Second, + }, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()).Return(nil, nil) + }, + }, + { + name: "drain node successfully via eviction with volume detach wait", + action: newActionDrainNodeWithVolumeDetach(nodeName, nodeID, providerID, 10, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + cfg: drainNodeConfig{ + podsDeleteTimeout: 10 * time.Second, + }, + vaWaiter: &mockVolumeDetachmentWaiter{}, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()).Return(nil, nil) + }, + wantVolumeWait: true, + }, + { + name: "eviction fails with pod failure, force=false returns error", + action: newActionDrainNode(nodeName, nodeID, providerID, 10, false), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()). + Return(nil, &k8s.PodFailedActionError{Action: "evict", Errors: []error{errors.New("evict failed")}}) + }, + wantErrorContains: "node failed to drain via graceful eviction", + }, + { + name: "eviction timeout, force=false returns error", + action: newActionDrainNode(nodeName, nodeID, providerID, 0, false), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()). + Return(nil, context.DeadlineExceeded) + }, + wantErrIs: context.DeadlineExceeded, + wantErrorContains: "node failed to drain via graceful eviction", + }, + { + name: "eviction timeout, force=true, force drain succeeds", + action: newActionDrainNode(nodeName, nodeID, providerID, 0, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + cfg: drainNodeConfig{ + podsDeleteTimeout: 10 * time.Second, + }, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()). + Return(nil, context.DeadlineExceeded) + mockDrainer.EXPECT().Drain(gomock.Any(), gomock.Any()).Return(nil, nil) + }, + }, + { + name: "eviction fails with pod failure, force=true, force drain succeeds", + action: newActionDrainNode(nodeName, nodeID, providerID, 10, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + cfg: drainNodeConfig{ + podsDeleteTimeout: 10 * time.Second, + }, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()). 
+ Return(nil, &k8s.PodFailedActionError{Action: "evict"}) + mockDrainer.EXPECT().Drain(gomock.Any(), gomock.Any()).Return(nil, nil) + }, + }, + { + name: "force drain both attempts time out, return error", + action: newActionDrainNode(nodeName, nodeID, providerID, 0, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + cfg: drainNodeConfig{ + podsDeleteTimeout: 10 * time.Second, + }, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()). + Return(nil, context.DeadlineExceeded) + first := mockDrainer.EXPECT().Drain(gomock.Any(), gomock.Any()). + Return(nil, context.DeadlineExceeded) + second := mockDrainer.EXPECT().Drain(gomock.Any(), gomock.Any()). + Return(nil, context.DeadlineExceeded) + gomock.InOrder(first, second) + }, + wantErrIs: context.DeadlineExceeded, + }, + { + name: "force drain first attempt fails with pod deletion failure, second attempt succeeds", + action: newActionDrainNode(nodeName, nodeID, providerID, 0, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + cfg: drainNodeConfig{ + podsDeleteTimeout: 10 * time.Second, + }, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()). + Return(nil, context.DeadlineExceeded) + first := mockDrainer.EXPECT().Drain(gomock.Any(), gomock.Any()). + Return(nil, &k8s.PodFailedActionError{Action: "delete"}) + second := mockDrainer.EXPECT().Drain(gomock.Any(), gomock.Any()). + Return(nil, nil) + gomock.InOrder(first, second) + }, + }, + { + name: "force drain first attempt times out, second attempt with forced grace period succeeds", + action: newActionDrainNode(nodeName, nodeID, providerID, 0, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + cfg: drainNodeConfig{ + podsDeleteTimeout: 10 * time.Second, + }, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()). + Return(nil, context.DeadlineExceeded) + first := mockDrainer.EXPECT().Drain(gomock.Any(), gomock.Any()). + Return(nil, context.DeadlineExceeded) + second := mockDrainer.EXPECT().Drain(gomock.Any(), gomock.Any()). + Return(nil, nil) + gomock.InOrder(first, second) + }, + }, + { + name: "eviction returns unrecoverable error, force=true does not proceed to force drain", + action: newActionDrainNode(nodeName, nodeID, providerID, 10, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()). 
+ Return(nil, errors.New("connection refused")) + }, + wantErrorContains: "evicting node pods", + }, + { + name: "force drain fails with unrecoverable error returns error", + action: newActionDrainNode(nodeName, nodeID, providerID, 0, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + cfg: drainNodeConfig{ + podsDeleteTimeout: 10 * time.Second, + }, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()). + Return(nil, context.DeadlineExceeded) + mockDrainer.EXPECT().Drain(gomock.Any(), gomock.Any()). + Return(nil, errors.New("internal error")) + }, + wantErrorContains: "forcefully deleting pods", + }, + { + name: "force drain succeeds with volume detach wait enabled", + action: newActionDrainNodeWithVolumeDetach(nodeName, nodeID, providerID, 0, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + cfg: drainNodeConfig{ + podsDeleteTimeout: 10 * time.Second, + }, + vaWaiter: &mockVolumeDetachmentWaiter{}, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()). + Return(nil, context.DeadlineExceeded) + mockDrainer.EXPECT().Drain(gomock.Any(), gomock.Any()).Return(nil, nil) + }, + wantVolumeWait: true, + }, + { + name: "volume detach wait not called when vaWaiter is nil even if enabled in request", + action: newActionDrainNodeWithVolumeDetach(nodeName, nodeID, providerID, 10, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + cfg: drainNodeConfig{ + podsDeleteTimeout: 10 * time.Second, + }, + // vaWaiter intentionally left nil + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()).Return(nil, nil) + }, + }, + { + name: "volume detach wait called with custom timeout", + action: func() *castai.ClusterAction { + a := newActionDrainNodeWithVolumeDetach(nodeName, nodeID, providerID, 10, true) + timeout := 120 + a.ActionDrainNode.VolumeDetachTimeoutSeconds = &timeout + return a + }(), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + cfg: drainNodeConfig{ + podsDeleteTimeout: 10 * time.Second, + }, + vaWaiter: &mockVolumeDetachmentWaiter{}, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()).Return(nil, nil) + }, + wantVolumeWait: true, + }, + { + name: "volume detach wait error is logged but not returned to caller", + action: newActionDrainNodeWithVolumeDetach(nodeName, nodeID, providerID, 10, true), + nodeInformer: &stubNodeInformer{ + node: newDrainTestNode(), + }, + cfg: drainNodeConfig{ + podsDeleteTimeout: 10 * time.Second, + }, + vaWaiter: &mockVolumeDetachmentWaiter{waitErr: errors.New("detach failed")}, + setup: func(mockClient *mock_k8s.MockClient, mockDrainer *mock_nodes.MockDrainer) { + mockClient.EXPECT().CordonNode(gomock.Any(), gomock.Any()).Return(nil) + mockDrainer.EXPECT().Evict(gomock.Any(), gomock.Any()).Return(nil, nil) + }, + wantVolumeWait: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl 
:= gomock.NewController(t) + mockClient := mock_k8s.NewMockClient(ctrl) + mockDrainer := mock_nodes.NewMockDrainer(ctrl) + + if tt.setup != nil { + tt.setup(mockClient, mockDrainer) + } + + h := &DrainNodeInfomerHandler{ + log: logrus.New(), + nodeInformer: tt.nodeInformer, + nodeManager: mockDrainer, + client: mockClient, + cfg: tt.cfg, + vaWaiter: tt.vaWaiter, + } + + err := h.Handle(context.Background(), tt.action) + + switch { + case tt.wantErrIs != nil: + require.ErrorIs(t, err, tt.wantErrIs) + if tt.wantErrorContains != "" { + require.ErrorContains(t, err, tt.wantErrorContains) + } + case tt.wantErrorContains != "": + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErrorContains) + default: + require.NoError(t, err) + } + + if tt.wantVolumeWait { + waiter, ok := tt.vaWaiter.(*mockVolumeDetachmentWaiter) + require.True(t, ok) + require.True(t, waiter.waitCalled, "expected volume detachment Wait to be called") + } + }) + } +} From fe6b7bae5c57e22c7afe014f5c952a523cda1d47 Mon Sep 17 00:00:00 2001 From: Illia Litovchenko Date: Thu, 19 Feb 2026 11:53:12 +0000 Subject: [PATCH 3/3] using drain node informer if enabled --- cmd/controller/run.go | 5 +++-- internal/actions/actions.go | 5 +++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/cmd/controller/run.go b/cmd/controller/run.go index 611d1bad..c2c15464 100644 --- a/cmd/controller/run.go +++ b/cmd/controller/run.go @@ -142,7 +142,7 @@ func runController( informerOpts = append(informerOpts, informer.EnableNodeInformer()) } if cfg.Informer.EnablePod { - informerOpts = append(informerOpts, informer.EnablePodInformer()) + informerOpts = append(informerOpts, informer.EnablePodInformer(), informer.WithDefaultPodNodeNameIndexer()) } informerManager := informer.NewManager( @@ -165,7 +165,8 @@ func runController( clientset, dynamicClient, helmClient, - informerManager.GetNodeInformer(), + informerManager.Nodes(), + informerManager.Pods(), vaWaiter, ) diff --git a/internal/actions/actions.go b/internal/actions/actions.go index b0a3a390..a287ab4b 100644 --- a/internal/actions/actions.go +++ b/internal/actions/actions.go @@ -23,6 +23,7 @@ func NewDefaultActionHandlers( dynamicClient dynamic.Interface, helmClient helm.Client, nodeInformer informer.NodeInformer, + podInformer informer.PodInformer, vaWaiter volume.DetachmentWaiter, ) ActionHandlers { handlers := ActionHandlers{ @@ -46,6 +47,10 @@ func NewDefaultActionHandlers( handlers[reflect.TypeFor[*castai.ActionCheckNodeStatus]()] = NewCheckNodeStatusInformerHandler(log, clientset, nodeInformer) } + if podInformer != nil && nodeInformer != nil { + handlers[reflect.TypeFor[*castai.ActionDrainNode]()] = NewDrainNodeInformerHandler(log, clientset, castNamespace, vaWaiter, podInformer, nodeInformer) + } + return handlers }
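
A quick compile-time sanity check (a sketch, not part of the patches; the package name and file placement are hypothetical) can confirm that the generated mocks keep tracking the interfaces they stand in for, so any drift between k8s.Client / nodes.Drainer and their MockGen output fails the build rather than surfacing later in a test run:

package mockcheck // hypothetical package; place in any package that may import both mock packages

import (
	"github.com/castai/cluster-controller/internal/k8s"
	mock_k8s "github.com/castai/cluster-controller/internal/k8s/mock"
	"github.com/castai/cluster-controller/internal/nodes"
	mock_nodes "github.com/castai/cluster-controller/internal/nodes/mock"
)

// Compile-time assertions: the generated mocks must satisfy the interfaces they mock.
var (
	_ k8s.Client    = (*mock_k8s.MockClient)(nil)
	_ nodes.Drainer = (*mock_nodes.MockDrainer)(nil)
)

After an interface change, the mock files themselves are regenerated from the //go:generate directives added in the first patch, e.g. go generate ./internal/k8s/... ./internal/nodes/... with mockgen (github.com/golang/mock/mockgen) available on PATH.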