Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions pkg/runtime/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ import (
testingutil "github.com/kubeflow/trainer/v2/pkg/util/testing"
)

// ignoreSyncPodSets is a cmp option that excludes the SyncPodSets field of
// runtime.Info from comparisons. The field holds a function value, which
// cannot be meaningfully compared between the wanted and the actual Info.
var ignoreSyncPodSets = cmpopts.IgnoreFields(runtime.Info{}, "SyncPodSets")

// TODO: We should introduce mock plugins and use them in this framework testing.
// After we migrate from the actual plugins to mock ones for testing data,
// we can delegate the actual plugin testing to each plugin directory, and implement detailed unit testing there.
Expand Down Expand Up @@ -329,7 +331,7 @@ func TestRunEnforceMLPolicyPlugins(t *testing.T) {
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got): %s", diff)
}
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, cmpopts.EquateEmpty()); len(diff) != 0 {
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, cmpopts.EquateEmpty(), ignoreSyncPodSets); len(diff) != 0 {
t.Errorf("Unexpected runtime.Info (-want,+got): %s", diff)
}
})
Expand Down Expand Up @@ -424,7 +426,7 @@ func TestRunEnforcePodGroupPolicyPlugins(t *testing.T) {
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got): %s", diff)
}
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo); len(diff) != 0 {
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, ignoreSyncPodSets); len(diff) != 0 {
t.Errorf("Unexpected runtime.Info (-want,+got): %s", diff)
}
})
Expand Down Expand Up @@ -2215,7 +2217,7 @@ test-job-node-0-1.test-job slots=1
t.Errorf("Unexpected errors (-want,+got):\n%s", diff)
}

if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, testingutil.PodSetEndpointsCmpOpts); len(diff) != 0 {
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, testingutil.PodSetEndpointsCmpOpts, ignoreSyncPodSets); len(diff) != 0 {
t.Errorf("Unexpected runtime.Info (-want,+got)\n%s", diff)
}

Expand Down Expand Up @@ -2617,7 +2619,7 @@ func TestPodNetworkPlugins(t *testing.T) {
if diff := cmp.Diff(tc.wantError, err); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, testingutil.PodSetEndpointsCmpOpts); len(diff) != 0 {
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, testingutil.PodSetEndpointsCmpOpts, ignoreSyncPodSets); len(diff) != 0 {
t.Errorf("Unexpected runtimeInfo (-want,+got):\n%s", diff)
}
})
Expand Down
26 changes: 14 additions & 12 deletions pkg/runtime/framework/plugins/coscheduling/coscheduling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ import (
utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing"
)

// ignoreSyncPodSets is a cmp option that excludes the SyncPodSets field of
// runtime.Info from comparisons. The field holds a function value, which
// cannot be meaningfully compared between the wanted and the actual Info.
var ignoreSyncPodSets = cmpopts.IgnoreFields(runtime.Info{}, "SyncPodSets")

func TestCoScheduling(t *testing.T) {
objCmpOpts := []gocmp.Option{
cmpopts.SortSlices(func(a, b apiruntime.Object) int {
Expand Down Expand Up @@ -558,18 +560,18 @@ func TestCoScheduling(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create plugin: %v", err)
}
err = plugin.(framework.EnforcePodGroupPolicyPlugin).EnforcePodGroupPolicy(tc.info, tc.trainJob)
if diff := gocmp.Diff(tc.wantPodGroupPolicyError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error from EnforcePodGroupPolicy (-want,+got):\n%s", diff)
}
if diff := gocmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b int) bool { return a < b }),
); len(diff) != 0 {
t.Errorf("Unexpected info from EnforcePodGroupPolicy (-want,+got):\n%s", diff)
}

var objs []any
err = plugin.(framework.EnforcePodGroupPolicyPlugin).EnforcePodGroupPolicy(tc.info, tc.trainJob)
if diff := gocmp.Diff(tc.wantPodGroupPolicyError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error from EnforcePodGroupPolicy (-want,+got):\n%s", diff)
}
if diff := gocmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b int) bool { return a < b }),
ignoreSyncPodSets,
); len(diff) != 0 {
t.Errorf("Unexpected info from EnforcePodGroupPolicy (-want,+got):\n%s", diff)
}
var objs []any
objs, err = plugin.(framework.ComponentBuilderPlugin).Build(ctx, tc.info, tc.trainJob)
if diff := gocmp.Diff(tc.wantBuildError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error from Build (-want, +got): %s", diff)
Expand Down
36 changes: 19 additions & 17 deletions pkg/runtime/framework/plugins/jobset/jobset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ import (
utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing"
)

// ignoreSyncPodSets is a cmp option that excludes the SyncPodSets field of
// runtime.Info from comparisons. The field holds a function value, which
// cannot be meaningfully compared between the wanted and the actual Info.
var ignoreSyncPodSets = cmpopts.IgnoreFields(runtime.Info{}, "SyncPodSets")

// TODO: Add tests for all Interfaces.
// REF: https://github.com/kubeflow/trainer/issues/2468

Expand Down Expand Up @@ -315,25 +317,25 @@ func TestJobSet(t *testing.T) {
ctx, cancel = context.WithCancel(ctx)
t.Cleanup(cancel)
cli := utiltesting.NewClientBuilder().Build()
p, err := New(ctx, cli, nil)
if err != nil {
t.Fatalf("Failed to initialize JobSet plugin: %v", err)
}
err = p.(framework.PodNetworkPlugin).IdentifyPodNetwork(tc.info, tc.trainJob)
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}
if diff := cmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
utiltesting.PodSetEndpointsCmpOpts,
); len(diff) != 0 {
t.Errorf("Unexpected Info from IdentifyPodNetwork (-want,+got):\n%s", diff)
}
})
p, err := New(ctx, cli, nil)
if err != nil {
t.Fatalf("Failed to initialize JobSet plugin: %v", err)
}
err = p.(framework.PodNetworkPlugin).IdentifyPodNetwork(tc.info, tc.trainJob)
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}
if diff := cmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
utiltesting.PodSetEndpointsCmpOpts,
ignoreSyncPodSets,
); len(diff) != 0 {
t.Errorf("Unexpected Info from IdentifyPodNetwork (-want,+got):\n%s", diff)
}
})
}
}

func TestValidate(t *testing.T) {
cases := map[string]struct {
info *runtime.Info
Expand Down
13 changes: 8 additions & 5 deletions pkg/runtime/framework/plugins/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ import (
utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing"
)

// ignoreSyncPodSets is a cmp option that excludes the SyncPodSets field of
// runtime.Info from comparisons. The field holds a function value, which
// cannot be meaningfully compared between the wanted and the actual Info.
var ignoreSyncPodSets = cmpopts.IgnoreFields(runtime.Info{}, "SyncPodSets")

func TestMPI(t *testing.T) {
objCmpOpts := []gocmp.Option{
cmpopts.SortSlices(func(a, b apiruntime.Object) int {
Expand Down Expand Up @@ -856,11 +858,12 @@ trainJob-node-1-0.trainJob slots=1
if diff := gocmp.Diff(tc.wantMLPolicyError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error from EnforceMLPolicy (-want, +got): %s", diff)
}
if diff := gocmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b int) bool { return a < b }),
utiltesting.PodSetEndpointsCmpOpts,
); len(diff) != 0 {
if diff := gocmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b int) bool { return a < b }),
utiltesting.PodSetEndpointsCmpOpts,
ignoreSyncPodSets,
); len(diff) != 0 {
t.Errorf("Unexpected info from EnforceMLPolicy (-want, +got): %s", diff)
}
var objs []any
Expand Down
11 changes: 7 additions & 4 deletions pkg/runtime/framework/plugins/plainml/plainml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ import (
utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing"
)

// ignoreSyncPodSets is a cmp option that excludes the SyncPodSets field of
// runtime.Info from comparisons. The field holds a function value, which
// cannot be meaningfully compared between the wanted and the actual Info.
var ignoreSyncPodSets = cmpopts.IgnoreFields(runtime.Info{}, "SyncPodSets")

func TestPlainML(t *testing.T) {
cases := map[string]struct {
trainJob *trainer.TrainJob
Expand Down Expand Up @@ -180,10 +182,11 @@ func TestPlainML(t *testing.T) {
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error from EnforceMLPolicy (-want,+got):\n%s", diff)
}
if diff := cmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
); len(diff) != 0 {
if diff := cmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
ignoreSyncPodSets,
); len(diff) != 0 {
t.Errorf("Unexpected RuntimeInfo (-want,+got):\n%s", diff)
}
})
Expand Down
30 changes: 16 additions & 14 deletions pkg/runtime/framework/plugins/torch/torch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ import (
utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing"
)

// ignoreSyncPodSets is a cmp option that excludes the SyncPodSets field of
// runtime.Info from comparisons. The field holds a function value, which
// cannot be meaningfully compared between the wanted and the actual Info.
var ignoreSyncPodSets = cmpopts.IgnoreFields(runtime.Info{}, "SyncPodSets")

func TestTorch(t *testing.T) {
cases := map[string]struct {
info *runtime.Info
Expand Down Expand Up @@ -1392,23 +1394,23 @@ func TestTorch(t *testing.T) {
t.Fatalf("Failed to initialize Torch plugin: %v", err)
}

// Test EnforceMLPolicy
err = p.(framework.EnforceMLPolicyPlugin).EnforceMLPolicy(tc.info, tc.trainJob)
if diff := cmp.Diff(tc.wantMLPolicyError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error from EnforceMLPolicy (-want,+got):\n%s", diff)
}
// Test EnforceMLPolicy
err = p.(framework.EnforceMLPolicyPlugin).EnforceMLPolicy(tc.info, tc.trainJob)
if diff := cmp.Diff(tc.wantMLPolicyError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error from EnforceMLPolicy (-want,+got):\n%s", diff)
}

// Validate the entire info object
if diff := cmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
); len(diff) != 0 {
t.Errorf("Unexpected RuntimeInfo (-want,+got):\n%s", diff)
}
})
// Validate the entire info object
if diff := cmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
ignoreSyncPodSets,
); len(diff) != 0 {
t.Errorf("Unexpected RuntimeInfo (-want,+got):\n%s", diff)
}
})
}
}

func TestValidate(t *testing.T) {
cases := map[string]struct {
info *runtime.Info
Expand Down
20 changes: 13 additions & 7 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@ import (
"github.com/kubeflow/trainer/v2/pkg/constants"
)

// defaultPodSetsSyncer is a no-op syncer: it accepts an *Info and does nothing.
// syncPodSets is a package-level syncer that starts out as the no-op default.
// NOTE(review): this is mutable package-level state shared by every Info;
// concurrent reassignment would be a data race, so per-Info storage is preferable.
var (
defaultPodSetsSyncer = func(*Info) {}
syncPodSets = defaultPodSetsSyncer
)

type Info struct {
// Labels and Annotations to add to the RuntimeJobTemplate.
Labels map[string]string
Expand All @@ -49,6 +44,9 @@ type Info struct {
// TemplateSpec is TrainingRuntime Template object.
// ObjApply podSpecs and this PodSets should be kept in sync by info.SyncPodSetsToTemplateSpec().
TemplateSpec TemplateSpec
// SyncPodSets is the function to sync PodSets to TemplateSpec.
// This is stored per-instance to avoid data races when RuntimeInfo is called concurrently.
SyncPodSets func(*Info)
}

type RuntimePolicy struct {
Expand Down Expand Up @@ -99,6 +97,7 @@ type InfoOptions struct {
annotations map[string]string
runtimePolicy RuntimePolicy
templateSpec TemplateSpec
syncPodSets func(*Info)
}

type InfoOption func(options *InfoOptions)
Expand Down Expand Up @@ -174,7 +173,7 @@ func toPodSetContainer(containerApply ...corev1ac.ContainerApplyConfiguration) i

// WithPodSetSyncer returns an InfoOption that records syncer on the
// InfoOptions being assembled.
//
// The syncer is stored on the options value only. It is deliberately not
// written to any package-level variable: shared mutable state would race when
// several Infos are built concurrently and would leak one caller's syncer into
// every other Info.
func WithPodSetSyncer(syncer func(*Info)) InfoOption {
	return func(o *InfoOptions) {
		o.syncPodSets = syncer
	}
}

Expand All @@ -192,6 +191,11 @@ func NewInfo(opts ...InfoOption) *Info {
PodLabels: make(map[string]string),
},
TemplateSpec: options.templateSpec,
SyncPodSets: options.syncPodSets,
}
// Set default no-op syncer if none provided
if info.SyncPodSets == nil {
info.SyncPodSets = func(*Info) {}
}
if options.labels != nil {
info.Labels = options.labels
Expand All @@ -203,7 +207,9 @@ func NewInfo(opts ...InfoOption) *Info {
}

// SyncPodSetsToTemplateSpec runs the syncer stored on this Info, passing the
// Info itself as the argument.
//
// Only the per-instance SyncPodSets function is consulted; no package-level
// syncer is invoked, so the syncer runs at most once per call and concurrent
// calls on different Infos do not touch shared state. An Info whose
// SyncPodSets is nil is left untouched.
func (i *Info) SyncPodSetsToTemplateSpec() {
	if i.SyncPodSets == nil {
		return
	}
	i.SyncPodSets(i)
}

func TemplateSpecApply[A any](info *Info) (*A, bool) {
Expand Down
5 changes: 5 additions & 0 deletions pkg/runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ import (
jobsetplgconsts "github.com/kubeflow/trainer/v2/pkg/runtime/framework/plugins/jobset/constants"
)

// IgnoreSyncPodSets is a cmp option that excludes the SyncPodSets field when
// comparing Info structs. The field holds a function value, which cannot be
// meaningfully compared between the wanted and the actual Info.
var IgnoreSyncPodSets = cmpopts.IgnoreFields(Info{}, "SyncPodSets")

func TestNewInfo(t *testing.T) {
cases := map[string]struct {
infoOpts []InfoOption
Expand Down Expand Up @@ -431,6 +435,7 @@ func TestNewInfo(t *testing.T) {
}
cmpOpts := []cmp.Option{
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
IgnoreSyncPodSets,
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
Expand Down
Loading