From 826f08749d29dfa67223c443a611c618ae542284 Mon Sep 17 00:00:00 2001 From: knave Date: Sun, 28 Sep 2025 19:32:21 +0800 Subject: [PATCH 1/4] fix: only one workload is generated for the Deployment, with its name adjusted based on the root controller --- internal/utils/compose.go | 5 - internal/utils/owner_ref_utils.go | 40 ++++++ internal/utils/owner_ref_utils_test.go | 124 ++++++++++++++++ internal/webhook/v1/pod_webhook.go | 67 +++++---- internal/webhook/v1/pod_webhook_test.go | 184 ++++++++++++++++++++++++ internal/webhook/v1/tf_parser.go | 18 +-- 6 files changed, 397 insertions(+), 41 deletions(-) diff --git a/internal/utils/compose.go b/internal/utils/compose.go index 8802c6ce..cef36f89 100644 --- a/internal/utils/compose.go +++ b/internal/utils/compose.go @@ -78,11 +78,6 @@ type TensorFusionInfo struct { EnabledReplicas *int32 WorkloadName string ContainerNames []string - GenWorkload bool - - // Pod mutating webhook can not get Pod UID sometimes, - // thus need pod controller to set the owner reference - PendingSetPodAsOwner bool } func AddOrOverrideTFClientMissingAnnotationsBeforePatch(pod *v1.Pod, tfInfo TensorFusionInfo) { diff --git a/internal/utils/owner_ref_utils.go b/internal/utils/owner_ref_utils.go index 5a4fe9df..97c8cf4f 100644 --- a/internal/utils/owner_ref_utils.go +++ b/internal/utils/owner_ref_utils.go @@ -96,3 +96,43 @@ func FindFirstLevelOwnerReference(obj metav1.Object) *metav1.OwnerReference { } return &ownerRef } + +// FindRootControllerRef recursively finds the root controller reference for a given object (e.g. Pod). +func FindRootControllerRef(ctx context.Context, c client.Client, obj metav1.Object) (*metav1.OwnerReference, error) { + if metav1.GetControllerOfNoCopy(obj) == nil { + return nil, nil + } + + namespace := obj.GetNamespace() + current := obj + for { + controllerRef := metav1.GetControllerOf(current) + if controllerRef == nil { + if rObj, ok := current.(runtime.Object); ok { + gvk := rObj.GetObjectKind().GroupVersionKind() + return metav1.NewControllerRef(current, gvk), nil + } else { + return nil, fmt.Errorf("not a runtime.Object") + } + } + + unObj := &unstructured.Unstructured{} + unObj.SetAPIVersion(controllerRef.APIVersion) + unObj.SetKind(controllerRef.Kind) + err := c.Get(ctx, client.ObjectKey{Name: controllerRef.Name, Namespace: namespace}, unObj) + if err != nil { + // if not found, return controllerRef as root + if errors.IsNotFound(err) { + return controllerRef, nil + } + return nil, fmt.Errorf("get controller object: %w", err) + } + + // Cast back to metav1.Object if possible + if metaObj, ok := any(unObj).(metav1.Object); ok { + current = metaObj + } else { + return nil, fmt.Errorf("unexpected type for controller object %s/%s", controllerRef.Kind, controllerRef.Name) + } + } +} diff --git a/internal/utils/owner_ref_utils_test.go b/internal/utils/owner_ref_utils_test.go index 5b77b531..16d8386b 100644 --- a/internal/utils/owner_ref_utils_test.go +++ b/internal/utils/owner_ref_utils_test.go @@ -140,3 +140,127 @@ func TestFindRootOwnerReference(t *testing.T) { require.Equal(t, "ReplicaSet", rootRef.Kind) }) } + +func TestFindRootControllerRef(t *testing.T) { + // Prepare the scheme + sch := runtime.NewScheme() + require.NoError(t, corev1.AddToScheme(sch)) + require.NoError(t, appsv1.AddToScheme(sch)) + + t.Run("no controller returns nil", func(t *testing.T) { + pod := &corev1.Pod{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "v1", + Kind: "Pod", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "mypod", + Namespace: "default", + UID: "uid-pod", + }, + } + + c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod).Build() + + rootRef, err := utils.FindRootControllerRef(context.TODO(), c, pod) + require.NoError(t, err) + require.Nil(t, rootRef) + }) + + t.Run("hierarchy returns deployment", func(t *testing.T) { + controller := true + deployment := &appsv1.Deployment{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "mydeploy", + Namespace: "default", + UID: "uid-deploy", + }, + } + + rs := &appsv1.ReplicaSet{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "apps/v1", + Kind: "ReplicaSet", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "myrs", + Namespace: "default", + UID: "uid-rs", + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "mydeploy", + UID: deployment.UID, + Controller: &controller, + }, + }, + }, + } + + pod := &corev1.Pod{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "v1", + Kind: "Pod", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "mypod", + Namespace: "default", + UID: "uid-pod", + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "ReplicaSet", + Name: "myrs", + UID: rs.UID, + Controller: &controller, + }, + }, + }, + } + + c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod, rs, deployment).Build() + + rootRef, err := utils.FindRootControllerRef(context.TODO(), c, pod) + require.NoError(t, err) + require.NotNil(t, rootRef) + require.Equal(t, "mydeploy", rootRef.Name) + require.Equal(t, "Deployment", rootRef.Kind) + }) + + t.Run("missing controller returns last found ref", func(t *testing.T) { + controller := true + pod := &corev1.Pod{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "v1", + Kind: "Pod", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "mypod", + Namespace: "default", + UID: "uid-pod", + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "ReplicaSet", + Name: "missing-rs", + UID: "uid-missing", + Controller: &controller, + }, + }, + }, + } + + c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod).Build() + + rootRef, err := utils.FindRootControllerRef(context.TODO(), c, pod) + require.NoError(t, err) + require.NotNil(t, rootRef) + require.Equal(t, "missing-rs", rootRef.Name) + require.Equal(t, "ReplicaSet", rootRef.Kind) + }) +} diff --git a/internal/webhook/v1/pod_webhook.go b/internal/webhook/v1/pod_webhook.go index 6c54113d..9c496c27 100644 --- a/internal/webhook/v1/pod_webhook.go +++ b/internal/webhook/v1/pod_webhook.go @@ -122,19 +122,18 @@ func (m *TensorFusionPodMutator) Handle(ctx context.Context, req admission.Reque podCounterAnnotationKey = podCounterKey } - if tfInfo.PendingSetPodAsOwner { - pod.Annotations[constants.SetPendingOwnedWorkloadAnnotation] = tfInfo.WorkloadName - } - pool := &tfv1.GPUPool{} if err := m.Client.Get(ctx, client.ObjectKey{Name: tfInfo.Profile.PoolName}, pool); err != nil { return admission.Errored(http.StatusInternalServerError, fmt.Errorf("gpu pool(%s) does not exist", tfInfo.Profile.PoolName)) } - workload := &tfv1.TensorFusionWorkload{} - if tfInfo.GenWorkload { - if err := m.createOrUpdateWorkload(ctx, pod, &tfInfo, workload, pool); err != nil { - return admission.Errored(http.StatusInternalServerError, fmt.Errorf("create tf workload: %w", err)) + if workload, err := m.createOrUpdateWorkload(ctx, pod, &tfInfo, pool); err != nil { + return admission.Errored(http.StatusInternalServerError, fmt.Errorf("create tf workload: %w", err)) + } else { + // Pod mutating webhook can not get Pod UID, + // thus need pod controller to set the controller reference + if controllerRef := metav1.GetControllerOfNoCopy(workload); controllerRef == nil { + pod.Annotations[constants.SetPendingOwnedWorkloadAnnotation] = tfInfo.WorkloadName } } @@ -201,7 +200,11 @@ func (m *TensorFusionPodMutator) InjectDecoder(d admission.Decoder) error { return nil } -func (m *TensorFusionPodMutator) createOrUpdateWorkload(ctx context.Context, pod *corev1.Pod, tfInfo *utils.TensorFusionInfo, workload *tfv1.TensorFusionWorkload, pool *tfv1.GPUPool) error { +func (m *TensorFusionPodMutator) createOrUpdateWorkload( + ctx context.Context, + pod *corev1.Pod, + tfInfo *utils.TensorFusionInfo, + pool *tfv1.GPUPool) (*tfv1.TensorFusionWorkload, error) { // Create the desired spec for comparison desiredSpec := tfv1.WorkloadProfileSpec{ Replicas: nil, @@ -214,13 +217,12 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload(ctx context.Context, pod AutoScalingConfig: tfInfo.Profile.AutoScalingConfig, } + workload := &tfv1.TensorFusionWorkload{} err := m.Client.Get(ctx, client.ObjectKey{Name: tfInfo.WorkloadName, Namespace: pod.Namespace}, workload) if err != nil { if !errors.IsNotFound(err) { - return fmt.Errorf("failed to get workload: %w", err) + return nil, fmt.Errorf("failed to get workload: %w", err) } - // find root owner references of pod - firstLevelOwnerRef := utils.FindFirstLevelOwnerReference(pod) // Create a new workload workload = &tfv1.TensorFusionWorkload{ @@ -242,25 +244,42 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload(ctx context.Context, pod workload.Annotations[constants.DisableFeaturesAnnotation] = pod.Labels[constants.DisableFeaturesAnnotation] } - if firstLevelOwnerRef != nil { - workload.OwnerReferences = []metav1.OwnerReference{*firstLevelOwnerRef} + if controllerRef := metav1.GetControllerOf(pod); controllerRef != nil { + workload.OwnerReferences = []metav1.OwnerReference{*controllerRef} } if err := m.Client.Create(ctx, workload); err != nil { - return fmt.Errorf("failed to create workload: %w", err) + return nil, fmt.Errorf("failed to create workload: %w", err) + } + return workload, nil + } + + podControllerRef := metav1.GetControllerOf(pod) + workloadControllerRef := metav1.GetControllerOf(workload) + if !isSameControllerRef(podControllerRef, workloadControllerRef) || + !equality.Semantic.DeepEqual(workload.Spec, desiredSpec) { + patch := client.MergeFrom(workload.DeepCopy()) + if podControllerRef != nil { + workload.OwnerReferences = []metav1.OwnerReference{*podControllerRef} + } else { + workload.OwnerReferences = []metav1.OwnerReference{} } - return nil - } - - // Compare the entire spec at once - if !equality.Semantic.DeepEqual(workload.Spec, desiredSpec) { workload.Spec = desiredSpec - // TODO retry on conflict - if err := m.Client.Update(ctx, workload); err != nil { - return fmt.Errorf("failed to update workload: %w", err) + if err := m.Client.Patch(ctx, workload, patch); err != nil { + return nil, fmt.Errorf("failed to patch workload: %w", err) } } - return nil + return workload, nil +} + +func isSameControllerRef(a, b *metav1.OwnerReference) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return a.UID == b.UID } func (m *TensorFusionPodMutator) patchTFClient( diff --git a/internal/webhook/v1/pod_webhook_test.go b/internal/webhook/v1/pod_webhook_test.go index 374f2620..ac93f21c 100644 --- a/internal/webhook/v1/pod_webhook_test.go +++ b/internal/webhook/v1/pod_webhook_test.go @@ -595,6 +595,57 @@ var _ = Describe("TensorFusionPodMutator", func() { Expect(tfInfo.Profile.Qos).To(Equal(tfv1.QoSHigh)) Expect(*tfInfo.EnabledReplicas).To(Equal(int32(3))) }) + + It("should treat generateName as workload name if the pod has no controllerRef", func() { + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "default", + GenerateName: "test-name", + Annotations: map[string]string{ + constants.GpuPoolKey: "mock", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + }, + }, + }, + } + tfInfo, _ := ParseTensorFusionInfo(ctx, k8sClient, pod) + Expect(tfInfo.WorkloadName).To(HavePrefix("test-name")) + }) + + It("should treat controller name as workload name if the pod has controllerRef", func() { + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "default", + GenerateName: "test-name", + Annotations: map[string]string{ + constants.GpuPoolKey: "mock", + }, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "ReplicaSet", + Name: "test-rs", + UID: "rs-uid", + Controller: ptr.To(true), + }, + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + }, + }, + }, + } + tfInfo, _ := ParseTensorFusionInfo(ctx, k8sClient, pod) + Expect(tfInfo.WorkloadName).To(Equal("test-rs")) + }) }) Context("patchTFClient", func() { @@ -622,4 +673,137 @@ var _ = Describe("TensorFusionPodMutator", func() { Expect(len(patch)).To(BeNumerically(">=", 2)) }) }) + + Context("when handling workload", func() { + It("should update workload's controllerRef same with Pod's controllerRef", func() { + expectedRef := metav1.OwnerReference{ + APIVersion: "apps/v1", + Kind: "ReplicaSet", + Name: "test-rs", + UID: "rs-uid", + Controller: ptr.To(true), + } + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "test-name", + Labels: map[string]string{ + constants.TensorFusionEnabledLabelKey: "true", + }, + Annotations: map[string]string{ + constants.GpuPoolKey: "mock", + }, + OwnerReferences: []metav1.OwnerReference{expectedRef}, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "main", + Image: "test-image", + }}, + }, + } + podBytes, err := json.Marshal(pod) + Expect(err).NotTo(HaveOccurred()) + + req := admission.Request{ + AdmissionRequest: admissionv1.AdmissionRequest{ + Object: runtime.RawExtension{ + Raw: podBytes, + }, + Operation: admissionv1.Create, + Namespace: "default", + }, + } + + resp := mutator.Handle(ctx, req) + Expect(resp.Allowed).To(BeTrue()) + Expect(pod.Annotations[constants.SetPendingOwnedWorkloadAnnotation]).To(BeEmpty()) + + Eventually(func(g Gomega) { + workload := &tfv1.TensorFusionWorkload{} + g.Expect(k8sClient.Get(ctx, + client.ObjectKey{ + Name: expectedRef.Name, + Namespace: "default", + }, workload)).To(Succeed()) + gotRef := metav1.GetControllerOfNoCopy(workload) + g.Expect(*gotRef).To(Equal(expectedRef)) + }).Should(Succeed()) + + newExpectedRef := metav1.OwnerReference{ + APIVersion: "apps/v1", + Kind: "ReplicaSet", + Name: "new-test-rs", + UID: "new-rs-uid", + Controller: ptr.To(true), + } + pod.OwnerReferences = []metav1.OwnerReference{newExpectedRef} + podBytes, err = json.Marshal(pod) + Expect(err).NotTo(HaveOccurred()) + + req = admission.Request{ + AdmissionRequest: admissionv1.AdmissionRequest{ + Object: runtime.RawExtension{ + Raw: podBytes, + }, + Operation: admissionv1.Create, + Namespace: "default", + }, + } + + resp = mutator.Handle(ctx, req) + Expect(resp.Allowed).To(BeTrue()) + Expect(pod.Annotations[constants.SetPendingOwnedWorkloadAnnotation]).To(BeEmpty()) + + Eventually(func(g Gomega) { + workload := &tfv1.TensorFusionWorkload{} + g.Expect(k8sClient.Get(ctx, + client.ObjectKey{ + Name: newExpectedRef.Name, + Namespace: "default", + }, workload)).To(Succeed()) + gotRef := metav1.GetControllerOfNoCopy(workload) + g.Expect(*gotRef).To(Equal(newExpectedRef)) + }).Should(Succeed()) + }) + + It("should add SetPendingOwnedWorkload annotation to pod when workload has no controllerRef", func() { + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "test-name", + Labels: map[string]string{ + constants.TensorFusionEnabledLabelKey: "true", + }, + Annotations: map[string]string{ + constants.GpuPoolKey: "mock", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "main", + Image: "test-image", + }}, + }, + } + podBytes, err := json.Marshal(pod) + Expect(err).NotTo(HaveOccurred()) + + req := admission.Request{ + AdmissionRequest: admissionv1.AdmissionRequest{ + Object: runtime.RawExtension{ + Raw: podBytes, + }, + Operation: admissionv1.Create, + Namespace: "default", + }, + } + + resp := mutator.Handle(ctx, req) + Expect(resp.Allowed).To(BeTrue()) + annotation, found := lo.Find(resp.Patches, func(patch jsonpatch.JsonPatchOperation) bool { + return patch.Path == "/metadata/annotations/tensor-fusion.ai~1pending-owned-workload" + }) + Expect(found).To(BeTrue()) + Expect(annotation.Value).To(HavePrefix("test-name")) + }) + }) }) diff --git a/internal/webhook/v1/tf_parser.go b/internal/webhook/v1/tf_parser.go index dfd8fd19..c2ba3720 100644 --- a/internal/webhook/v1/tf_parser.go +++ b/internal/webhook/v1/tf_parser.go @@ -47,25 +47,19 @@ func ParseTensorFusionInfo( info.EnabledReplicas = &val32 } - workloadName, ok := pod.Annotations[constants.WorkloadKey] - if !ok { - // auto generate a workload with owner name - info.GenWorkload = true - owner := utils.FindFirstLevelOwnerReference(pod) - if owner == nil { + // Generate the workload name: if the Pod has no controller, use the Pod's name; otherwise, use the root controller's name. + if controllerRef, err := utils.FindRootControllerRef(ctx, k8sClient, pod); err == nil { + if controllerRef != nil { + info.WorkloadName = controllerRef.Name + } else { if pod.Name == "" { info.WorkloadName = pod.GenerateName + "-" + utils.NewShortID(8) } else { info.WorkloadName = pod.Name } - info.PendingSetPodAsOwner = true - } else { - info.WorkloadName = owner.Name } } else { - // when workload is manually created, user can specify workload's replicas - // it remotely connects to lease connection worker when SelectWorker - info.WorkloadName = workloadName + return info, err } workloadProfileName, ok := pod.Annotations[constants.WorkloadProfileAnnotation] From f7e68699fe916b7951e78e17050b3a19ae862ee8 Mon Sep 17 00:00:00 2001 From: knave Date: Tue, 30 Sep 2025 19:29:59 +0800 Subject: [PATCH 2/4] fix: implement specific logic for the Deployment --- internal/webhook/v1/pod_webhook.go | 3 +-- internal/webhook/v1/tf_parser.go | 41 ++++++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/internal/webhook/v1/pod_webhook.go b/internal/webhook/v1/pod_webhook.go index 9c496c27..21f01ec9 100644 --- a/internal/webhook/v1/pod_webhook.go +++ b/internal/webhook/v1/pod_webhook.go @@ -29,7 +29,6 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/equality" "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/strategicpatch" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -53,7 +52,7 @@ func SetupPodWebhookWithManager(mgr ctrl.Manager, portAllocator *portallocator.P webhookServer.Register("/mutate-v1-pod", &admission.Webhook{ Handler: &TensorFusionPodMutator{ - decoder: admission.NewDecoder(runtime.NewScheme()), + decoder: admission.NewDecoder(mgr.GetScheme()), Client: mgr.GetClient(), portAllocator: portAllocator, }, diff --git a/internal/webhook/v1/tf_parser.go b/internal/webhook/v1/tf_parser.go index c2ba3720..b0fceb5f 100644 --- a/internal/webhook/v1/tf_parser.go +++ b/internal/webhook/v1/tf_parser.go @@ -10,8 +10,11 @@ import ( "github.com/NexusGPU/tensor-fusion/internal/constants" "github.com/NexusGPU/tensor-fusion/internal/gpuallocator" "github.com/NexusGPU/tensor-fusion/internal/utils" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -47,8 +50,11 @@ func ParseTensorFusionInfo( info.EnabledReplicas = &val32 } - // Generate the workload name: if the Pod has no controller, use the Pod's name; otherwise, use the root controller's name. - if controllerRef, err := utils.FindRootControllerRef(ctx, k8sClient, pod); err == nil { + // Generate the workload name: + // If the Pod has no controller, use the Pod's name; + // if it is controlled by a Deployment, return the Deployment's name; + // otherwise, return the name of the first-level controller. + if controllerRef, err := getPodControllerRef(ctx, k8sClient, pod); err == nil { if controllerRef != nil { info.WorkloadName = controllerRef.Name } else { @@ -254,3 +260,34 @@ func handleDedicatedGPU(pod *corev1.Pod, workloadProfile *tfv1.WorkloadProfile) workloadProfile.Spec.Resources.Limits.Vram = resource.Vram return nil } + +func getPodControllerRef(ctx context.Context, c client.Client, pod *corev1.Pod) (*metav1.OwnerReference, error) { + podControllerRef := metav1.GetControllerOf(pod) + if podControllerRef == nil { + return nil, nil + } + + switch podControllerRef.Kind { + case "ReplicaSet": + { + // Special handling for Deployment resources + rs := &appsv1.ReplicaSet{} + if err := c.Get(ctx, client.ObjectKey{ + Namespace: pod.Namespace, + Name: podControllerRef.Name, + }, rs); err != nil { + if errors.IsNotFound(err) { + return podControllerRef, nil + } + return nil, fmt.Errorf("failed to get ReplicaSet: %w", err) + } + rsContollerRef := metav1.GetControllerOf(rs) + if rsContollerRef != nil && rsContollerRef.Kind == "Deployment" { + // If controlled by a Deployment, return the controllerRef of rs + return rsContollerRef, nil + } + } + } + + return podControllerRef, nil +} From df6b7da7bd29021caec7a8ef71c4d69864e363ca Mon Sep 17 00:00:00 2001 From: knave Date: Tue, 30 Sep 2025 19:39:48 +0800 Subject: [PATCH 3/4] fix: check return value when calling addToScheme --- cmd/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/main.go b/cmd/main.go index 3b75421f..216b3ff7 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -109,7 +109,7 @@ func init() { } karpenterScheme.Register(&karpv1.NodeClaim{}, &karpv1.NodeClaimList{}) karpenterScheme.Register(&karpv1.NodePool{}, &karpv1.NodePoolList{}) - karpenterScheme.AddToScheme(scheme) + utilruntime.Must(karpenterScheme.AddToScheme(scheme)) } //nolint:gocyclo From 2dece9ed46aaa886c2763179e0c4c4e5d9b1743f Mon Sep 17 00:00:00 2001 From: knave Date: Fri, 10 Oct 2025 11:57:08 +0800 Subject: [PATCH 4/4] fix: align the Workload owner with the Pod controller --- internal/utils/compose.go | 12 ++- internal/utils/owner_ref_utils.go | 45 +++++++++ internal/utils/owner_ref_utils_test.go | 128 ++++++++++++++++++++++++ internal/webhook/v1/pod_webhook.go | 24 +---- internal/webhook/v1/pod_webhook_test.go | 85 +++++++++++++--- internal/webhook/v1/tf_parser.go | 42 +------- 6 files changed, 259 insertions(+), 77 deletions(-) diff --git a/internal/utils/compose.go b/internal/utils/compose.go index cef36f89..f9a9a3d8 100644 --- a/internal/utils/compose.go +++ b/internal/utils/compose.go @@ -12,6 +12,7 @@ import ( "github.com/samber/lo" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/utils/ptr" ) @@ -73,11 +74,12 @@ var featureShortcutMap = map[string]struct { } type TensorFusionInfo struct { - Profile *tfv1.WorkloadProfileSpec - DynamicReplicas bool - EnabledReplicas *int32 - WorkloadName string - ContainerNames []string + Profile *tfv1.WorkloadProfileSpec + DynamicReplicas bool + EnabledReplicas *int32 + WorkloadName string + PodControllerRef *metav1.OwnerReference + ContainerNames []string } func AddOrOverrideTFClientMissingAnnotationsBeforePatch(pod *v1.Pod, tfInfo TensorFusionInfo) { diff --git a/internal/utils/owner_ref_utils.go b/internal/utils/owner_ref_utils.go index 97c8cf4f..e440f40b 100644 --- a/internal/utils/owner_ref_utils.go +++ b/internal/utils/owner_ref_utils.go @@ -4,6 +4,9 @@ import ( "context" "fmt" + appsv1 "k8s.io/api/apps/v1" + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -136,3 +139,45 @@ func FindRootControllerRef(ctx context.Context, c client.Client, obj metav1.Obje } } } + +// GetPodControllerRef returns the controller reference for a Pod. +// For Pods that are indirectly controlled (e.g., by a Deployment or CronJob), return the indirect controller. +// For other cases, it returns the direct controller reference of the Pod. +// If the Pod has no controller reference, it returns nil. +func GetPodControllerRef(ctx context.Context, c client.Client, pod *corev1.Pod) (*metav1.OwnerReference, error) { + podControllerRef := metav1.GetControllerOf(pod) + if podControllerRef == nil { + return nil, nil + } + + getControllerRef := func(obj client.Object) (*metav1.OwnerReference, error) { + if err := c.Get(ctx, client.ObjectKey{ + Namespace: pod.Namespace, + Name: podControllerRef.Name, + }, obj); err != nil { + if errors.IsNotFound(err) { + return podControllerRef, nil + } + return nil, fmt.Errorf("failed to get %T: %w", obj, err) + } + return metav1.GetControllerOf(obj), nil + } + + switch podControllerRef.Kind { + case "ReplicaSet": + if parentRef, err := getControllerRef(&appsv1.ReplicaSet{}); err != nil { + return nil, err + } else if parentRef != nil && parentRef.Kind == "Deployment" { + return parentRef, nil + } + + case "Job": + if parentRef, err := getControllerRef(&batchv1.Job{}); err != nil { + return nil, err + } else if parentRef != nil && parentRef.Kind == "CronJob" { + return parentRef, nil + } + } + + return podControllerRef, nil +} diff --git a/internal/utils/owner_ref_utils_test.go b/internal/utils/owner_ref_utils_test.go index 16d8386b..6cca7f76 100644 --- a/internal/utils/owner_ref_utils_test.go +++ b/internal/utils/owner_ref_utils_test.go @@ -5,6 +5,7 @@ import ( "testing" appsv1 "k8s.io/api/apps/v1" + batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -264,3 +265,130 @@ func TestFindRootControllerRef(t *testing.T) { require.Equal(t, "ReplicaSet", rootRef.Kind) }) } + +func TestGetPodControllerRef(t *testing.T) { + // Prepare the scheme + sch := runtime.NewScheme() + require.NoError(t, corev1.AddToScheme(sch)) + require.NoError(t, appsv1.AddToScheme(sch)) + require.NoError(t, batchv1.AddToScheme(sch)) + + t.Run("pod with no controller returns nil", func(t *testing.T) { + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "mypod", + Namespace: "default", + }, + } + + c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod).Build() + + ref, err := utils.GetPodControllerRef(context.TODO(), c, pod) + require.NoError(t, err) + require.Nil(t, ref) + }) + + t.Run("pod owned by replicaset owned by deployment returns deployment ref", func(t *testing.T) { + controller := true + deployment := &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "mydeploy", + Namespace: "default", + UID: "uid-deploy", + }, + } + + rs := &appsv1.ReplicaSet{ + ObjectMeta: metav1.ObjectMeta{ + Name: "myrs", + Namespace: "default", + UID: "uid-rs", + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "mydeploy", + UID: deployment.UID, + Controller: &controller, + }, + }, + }, + } + + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "mypod", + Namespace: "default", + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "ReplicaSet", + Name: "myrs", + UID: rs.UID, + Controller: &controller, + }, + }, + }, + } + + c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod, rs, deployment).Build() + + ref, err := utils.GetPodControllerRef(context.TODO(), c, pod) + require.NoError(t, err) + require.NotNil(t, ref) + require.Equal(t, "mydeploy", ref.Name) + require.Equal(t, "Deployment", ref.Kind) + }) + + t.Run("pod owned by job owned by cronjob returns cronjob ref", func(t *testing.T) { + controller := true + cronjob := &batchv1.CronJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "mycronjob", + Namespace: "default", + UID: "uid-cronjob", + }, + } + + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: "myjob", + Namespace: "default", + UID: "uid-job", + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "batch/v1", + Kind: "CronJob", + Name: "mycronjob", + UID: cronjob.UID, + Controller: &controller, + }, + }, + }, + } + + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "mypod", + Namespace: "default", + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "batch/v1", + Kind: "Job", + Name: "myjob", + UID: job.UID, + Controller: &controller, + }, + }, + }, + } + + c := fake.NewClientBuilder().WithScheme(sch).WithObjects(pod, job, cronjob).Build() + + ref, err := utils.GetPodControllerRef(context.TODO(), c, pod) + require.NoError(t, err) + require.NotNil(t, ref) + require.Equal(t, "mycronjob", ref.Name) + require.Equal(t, "CronJob", ref.Kind) + }) +} diff --git a/internal/webhook/v1/pod_webhook.go b/internal/webhook/v1/pod_webhook.go index 21f01ec9..cdb8cfae 100644 --- a/internal/webhook/v1/pod_webhook.go +++ b/internal/webhook/v1/pod_webhook.go @@ -243,8 +243,8 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload( workload.Annotations[constants.DisableFeaturesAnnotation] = pod.Labels[constants.DisableFeaturesAnnotation] } - if controllerRef := metav1.GetControllerOf(pod); controllerRef != nil { - workload.OwnerReferences = []metav1.OwnerReference{*controllerRef} + if tfInfo.PodControllerRef != nil { + workload.OwnerReferences = []metav1.OwnerReference{*tfInfo.PodControllerRef} } if err := m.Client.Create(ctx, workload); err != nil { @@ -253,16 +253,8 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload( return workload, nil } - podControllerRef := metav1.GetControllerOf(pod) - workloadControllerRef := metav1.GetControllerOf(workload) - if !isSameControllerRef(podControllerRef, workloadControllerRef) || - !equality.Semantic.DeepEqual(workload.Spec, desiredSpec) { + if !equality.Semantic.DeepEqual(workload.Spec, desiredSpec) { patch := client.MergeFrom(workload.DeepCopy()) - if podControllerRef != nil { - workload.OwnerReferences = []metav1.OwnerReference{*podControllerRef} - } else { - workload.OwnerReferences = []metav1.OwnerReference{} - } workload.Spec = desiredSpec if err := m.Client.Patch(ctx, workload, patch); err != nil { return nil, fmt.Errorf("failed to patch workload: %w", err) @@ -271,16 +263,6 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload( return workload, nil } -func isSameControllerRef(a, b *metav1.OwnerReference) bool { - if a == nil && b == nil { - return true - } - if a == nil || b == nil { - return false - } - return a.UID == b.UID -} - func (m *TensorFusionPodMutator) patchTFClient( pod *corev1.Pod, pool *tfv1.GPUPool, diff --git a/internal/webhook/v1/pod_webhook_test.go b/internal/webhook/v1/pod_webhook_test.go index ac93f21c..5d8a46d8 100644 --- a/internal/webhook/v1/pod_webhook_test.go +++ b/internal/webhook/v1/pod_webhook_test.go @@ -675,11 +675,11 @@ var _ = Describe("TensorFusionPodMutator", func() { }) Context("when handling workload", func() { - It("should update workload's controllerRef same with Pod's controllerRef", func() { + It("should set the workload owner same with Pod's controllerRef", func() { expectedRef := metav1.OwnerReference{ APIVersion: "apps/v1", Kind: "ReplicaSet", - Name: "test-rs", + Name: "my-rs", UID: "rs-uid", Controller: ptr.To(true), } @@ -728,19 +728,78 @@ var _ = Describe("TensorFusionPodMutator", func() { gotRef := metav1.GetControllerOfNoCopy(workload) g.Expect(*gotRef).To(Equal(expectedRef)) }).Should(Succeed()) + }) - newExpectedRef := metav1.OwnerReference{ + It("should set the workload owner to controlling deployment if the pod controlled by a deployment", func() { + expectedRef := metav1.OwnerReference{ APIVersion: "apps/v1", - Kind: "ReplicaSet", - Name: "new-test-rs", - UID: "new-rs-uid", + Kind: "Deployment", + Name: "test-deployment", + UID: "deployment-uid", Controller: ptr.To(true), } - pod.OwnerReferences = []metav1.OwnerReference{newExpectedRef} - podBytes, err = json.Marshal(pod) + rs := &appsv1.ReplicaSet{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-rs", + Namespace: "default", + UID: "rs-uid", + OwnerReferences: []metav1.OwnerReference{expectedRef}, + }, + Spec: appsv1.ReplicaSetSpec{ + Selector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "app": "test-app", + }, + }, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + "app": "test-app", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "test-image", + }, + }, + }, + }, + }, + } + + Expect(k8sClient.Create(ctx, rs)).To(Succeed()) + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "test-name", + Labels: map[string]string{ + constants.TensorFusionEnabledLabelKey: "true", + }, + Annotations: map[string]string{ + constants.GpuPoolKey: "mock", + }, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "ReplicaSet", + Name: "test-rs", + UID: "rs-uid", + Controller: ptr.To(true), + }, + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "main", + Image: "test-image", + }}, + }, + } + podBytes, err := json.Marshal(pod) Expect(err).NotTo(HaveOccurred()) - req = admission.Request{ + req := admission.Request{ AdmissionRequest: admissionv1.AdmissionRequest{ Object: runtime.RawExtension{ Raw: podBytes, @@ -750,7 +809,7 @@ var _ = Describe("TensorFusionPodMutator", func() { }, } - resp = mutator.Handle(ctx, req) + resp := mutator.Handle(ctx, req) Expect(resp.Allowed).To(BeTrue()) Expect(pod.Annotations[constants.SetPendingOwnedWorkloadAnnotation]).To(BeEmpty()) @@ -758,12 +817,14 @@ var _ = Describe("TensorFusionPodMutator", func() { workload := &tfv1.TensorFusionWorkload{} g.Expect(k8sClient.Get(ctx, client.ObjectKey{ - Name: newExpectedRef.Name, + Name: expectedRef.Name, Namespace: "default", }, workload)).To(Succeed()) gotRef := metav1.GetControllerOfNoCopy(workload) - g.Expect(*gotRef).To(Equal(newExpectedRef)) + g.Expect(*gotRef).To(Equal(expectedRef)) }).Should(Succeed()) + + Expect(k8sClient.Delete(ctx, rs)).Should(Succeed()) }) It("should add SetPendingOwnedWorkload annotation to pod when workload has no controllerRef", func() { diff --git a/internal/webhook/v1/tf_parser.go b/internal/webhook/v1/tf_parser.go index b0fceb5f..73540667 100644 --- a/internal/webhook/v1/tf_parser.go +++ b/internal/webhook/v1/tf_parser.go @@ -10,11 +10,8 @@ import ( "github.com/NexusGPU/tensor-fusion/internal/constants" "github.com/NexusGPU/tensor-fusion/internal/gpuallocator" "github.com/NexusGPU/tensor-fusion/internal/utils" - appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/resource" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -50,13 +47,11 @@ func ParseTensorFusionInfo( info.EnabledReplicas = &val32 } - // Generate the workload name: - // If the Pod has no controller, use the Pod's name; - // if it is controlled by a Deployment, return the Deployment's name; - // otherwise, return the name of the first-level controller. - if controllerRef, err := getPodControllerRef(ctx, k8sClient, pod); err == nil { + // Generate the workload name + if controllerRef, err := utils.GetPodControllerRef(ctx, k8sClient, pod); err == nil { if controllerRef != nil { info.WorkloadName = controllerRef.Name + info.PodControllerRef = controllerRef } else { if pod.Name == "" { info.WorkloadName = pod.GenerateName + "-" + utils.NewShortID(8) @@ -260,34 +255,3 @@ func handleDedicatedGPU(pod *corev1.Pod, workloadProfile *tfv1.WorkloadProfile) workloadProfile.Spec.Resources.Limits.Vram = resource.Vram return nil } - -func getPodControllerRef(ctx context.Context, c client.Client, pod *corev1.Pod) (*metav1.OwnerReference, error) { - podControllerRef := metav1.GetControllerOf(pod) - if podControllerRef == nil { - return nil, nil - } - - switch podControllerRef.Kind { - case "ReplicaSet": - { - // Special handling for Deployment resources - rs := &appsv1.ReplicaSet{} - if err := c.Get(ctx, client.ObjectKey{ - Namespace: pod.Namespace, - Name: podControllerRef.Name, - }, rs); err != nil { - if errors.IsNotFound(err) { - return podControllerRef, nil - } - return nil, fmt.Errorf("failed to get ReplicaSet: %w", err) - } - rsContollerRef := metav1.GetControllerOf(rs) - if rsContollerRef != nil && rsContollerRef.Kind == "Deployment" { - // If controlled by a Deployment, return the controllerRef of rs - return rsContollerRef, nil - } - } - } - - return podControllerRef, nil -}