Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions pkg/runtime/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ import (
testingutil "github.com/kubeflow/trainer/v2/pkg/util/testing"
)

// ignoreSyncPodSets is a cmp option that excludes the SyncPodSets field of
// runtime.Info from comparisons of expected and actual runtime.Info values.
var ignoreSyncPodSets = cmpopts.IgnoreFields(runtime.Info{}, "SyncPodSets")

// TODO: We should introduce mock plugins and use plugins in this framework testing.
// After we migrate the actual plugins to mock one for testing data,
// we can delegate the actual plugin testing to each plugin directories, and implement detailed unit testing.
Expand Down Expand Up @@ -329,7 +331,7 @@ func TestRunEnforceMLPolicyPlugins(t *testing.T) {
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got): %s", diff)
}
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, cmpopts.EquateEmpty()); len(diff) != 0 {
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, cmpopts.EquateEmpty(), ignoreSyncPodSets); len(diff) != 0 {
t.Errorf("Unexpected runtime.Info (-want,+got): %s", diff)
}
})
Expand Down Expand Up @@ -424,7 +426,7 @@ func TestRunEnforcePodGroupPolicyPlugins(t *testing.T) {
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got): %s", diff)
}
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo); len(diff) != 0 {
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, ignoreSyncPodSets); len(diff) != 0 {
t.Errorf("Unexpected runtime.Info (-want,+got): %s", diff)
}
})
Expand Down Expand Up @@ -2215,7 +2217,7 @@ test-job-node-0-1.test-job slots=1
t.Errorf("Unexpected errors (-want,+got):\n%s", diff)
}

if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, testingutil.PodSetEndpointsCmpOpts); len(diff) != 0 {
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, testingutil.PodSetEndpointsCmpOpts, ignoreSyncPodSets); len(diff) != 0 {
t.Errorf("Unexpected runtime.Info (-want,+got)\n%s", diff)
}

Expand Down Expand Up @@ -2617,7 +2619,7 @@ func TestPodNetworkPlugins(t *testing.T) {
if diff := cmp.Diff(tc.wantError, err); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, testingutil.PodSetEndpointsCmpOpts); len(diff) != 0 {
if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo, testingutil.PodSetEndpointsCmpOpts, ignoreSyncPodSets); len(diff) != 0 {
t.Errorf("Unexpected runtimeInfo (-want,+got):\n%s", diff)
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ import (
utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing"
)

// ignoreSyncPodSets is a cmp option that excludes the SyncPodSets field of
// runtime.Info from comparisons of expected and actual runtime.Info values.
var ignoreSyncPodSets = cmpopts.IgnoreFields(runtime.Info{}, "SyncPodSets")

func TestCoScheduling(t *testing.T) {
objCmpOpts := []gocmp.Option{
cmpopts.SortSlices(func(a, b apiruntime.Object) int {
Expand Down Expand Up @@ -565,10 +567,10 @@ func TestCoScheduling(t *testing.T) {
if diff := gocmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b int) bool { return a < b }),
ignoreSyncPodSets,
); len(diff) != 0 {
t.Errorf("Unexpected info from EnforcePodGroupPolicy (-want,+got):\n%s", diff)
}

var objs []any
objs, err = plugin.(framework.ComponentBuilderPlugin).Build(ctx, tc.info, tc.trainJob)
if diff := gocmp.Diff(tc.wantBuildError, err, cmpopts.EquateErrors()); len(diff) != 0 {
Expand Down
3 changes: 3 additions & 0 deletions pkg/runtime/framework/plugins/jobset/jobset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ import (
utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing"
)

// ignoreSyncPodSets is a cmp option that excludes the SyncPodSets field of
// runtime.Info from comparisons of expected and actual runtime.Info values.
var ignoreSyncPodSets = cmpopts.IgnoreFields(runtime.Info{}, "SyncPodSets")

// TODO: Add tests for all Interfaces.
// REF: https://github.com/kubeflow/trainer/issues/2468

Expand Down Expand Up @@ -327,6 +329,7 @@ func TestJobSet(t *testing.T) {
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
utiltesting.PodSetEndpointsCmpOpts,
ignoreSyncPodSets,
); len(diff) != 0 {
t.Errorf("Unexpected Info from IdentifyPodNetwork (-want,+got):\n%s", diff)
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/runtime/framework/plugins/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ import (
utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing"
)

// ignoreSyncPodSets is a cmp option that excludes the SyncPodSets field of
// runtime.Info from comparisons of expected and actual runtime.Info values.
var ignoreSyncPodSets = cmpopts.IgnoreFields(runtime.Info{}, "SyncPodSets")

func TestMPI(t *testing.T) {
objCmpOpts := []gocmp.Option{
cmpopts.SortSlices(func(a, b apiruntime.Object) int {
Expand Down Expand Up @@ -860,6 +862,7 @@ trainJob-node-1-0.trainJob slots=1
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b int) bool { return a < b }),
utiltesting.PodSetEndpointsCmpOpts,
ignoreSyncPodSets,
); len(diff) != 0 {
t.Errorf("Unexpected info from EnforceMLPolicy (-want, +got): %s", diff)
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/runtime/framework/plugins/plainml/plainml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ import (
utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing"
)

// ignoreSyncPodSets is a cmp option that excludes the SyncPodSets field of
// runtime.Info from comparisons of expected and actual runtime.Info values.
var ignoreSyncPodSets = cmpopts.IgnoreFields(runtime.Info{}, "SyncPodSets")

func TestPlainML(t *testing.T) {
cases := map[string]struct {
trainJob *trainer.TrainJob
Expand Down Expand Up @@ -183,6 +185,7 @@ func TestPlainML(t *testing.T) {
if diff := cmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
ignoreSyncPodSets,
); len(diff) != 0 {
t.Errorf("Unexpected RuntimeInfo (-want,+got):\n%s", diff)
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/runtime/framework/plugins/torch/torch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ import (
)

func TestTorch(t *testing.T) {
var ignoreSyncPodSets = cmpopts.IgnoreFields(runtime.Info{}, "SyncPodSets")

cases := map[string]struct {
info *runtime.Info
trainJob *trainer.TrainJob
Expand Down Expand Up @@ -1402,6 +1404,7 @@ func TestTorch(t *testing.T) {
if diff := cmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
ignoreSyncPodSets,
); len(diff) != 0 {
t.Errorf("Unexpected RuntimeInfo (-want,+got):\n%s", diff)
}
Expand Down
20 changes: 13 additions & 7 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@ import (
"github.com/kubeflow/trainer/v2/pkg/constants"
)

// Package-level PodSets syncer state.
var (
// defaultPodSetsSyncer is a syncer that does nothing.
defaultPodSetsSyncer = func(*Info) {}
// syncPodSets is the syncer currently in effect; it starts out as the no-op default.
// NOTE(review): as a package-level variable it is shared by every Info in the
// process, so any write to it is visible to all of them.
syncPodSets = defaultPodSetsSyncer
)

type Info struct {
// Labels and Annotations to add to the RuntimeJobTemplate.
Labels map[string]string
Expand All @@ -49,6 +44,9 @@ type Info struct {
// TemplateSpec is TrainingRuntime Template object.
// ObjApply podSpecs and this PodSets should be kept in sync by info.SyncPodSetsToTemplateSpec().
TemplateSpec TemplateSpec
// SyncPodSets is the function to sync PodSets to TemplateSpec.
// This is stored per-instance to avoid data races when RuntimeInfo is called concurrently.
SyncPodSets func(*Info)
}

type RuntimePolicy struct {
Expand Down Expand Up @@ -99,6 +97,7 @@ type InfoOptions struct {
annotations map[string]string
runtimePolicy RuntimePolicy
templateSpec TemplateSpec
syncPodSets func(*Info)
}

type InfoOption func(options *InfoOptions)
Expand Down Expand Up @@ -174,7 +173,7 @@ func toPodSetContainer(containerApply ...corev1ac.ContainerApplyConfiguration) i

// WithPodSetSyncer returns an InfoOption that records syncer on the
// InfoOptions being assembled.
//
// The syncer is stored only on the options value, never in a package-level
// variable: writing shared state from an option would make concurrent
// construction of Info values race and would let one caller's syncer leak
// into unrelated Info values.
func WithPodSetSyncer(syncer func(*Info)) InfoOption {
	return func(o *InfoOptions) {
		o.syncPodSets = syncer
	}
}

Expand All @@ -192,6 +191,11 @@ func NewInfo(opts ...InfoOption) *Info {
PodLabels: make(map[string]string),
},
TemplateSpec: options.templateSpec,
SyncPodSets: options.syncPodSets,
}
// Set default no-op syncer if none provided
if info.SyncPodSets == nil {
info.SyncPodSets = func(*Info) {}
}
if options.labels != nil {
info.Labels = options.labels
Expand All @@ -203,7 +207,9 @@ func NewInfo(opts ...InfoOption) *Info {
}

// SyncPodSetsToTemplateSpec runs the syncer stored in i.SyncPodSets, passing
// i itself as the argument.
//
// Only the per-instance syncer is invoked; the method does not consult any
// package-level syncer, so the syncer runs at most once per call and the
// method reads no state shared between Info values. When i.SyncPodSets is
// nil the call is a no-op.
func (i *Info) SyncPodSetsToTemplateSpec() {
	if i.SyncPodSets == nil {
		return
	}
	i.SyncPodSets(i)
}

func TemplateSpecApply[A any](info *Info) (*A, bool) {
Expand Down
4 changes: 4 additions & 0 deletions pkg/runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ import (
)

func TestNewInfo(t *testing.T) {
// IgnoreSyncPodSets is a cmp option that skips the SyncPodSets field when comparing Info structs.
// The field holds a func value, which cmp treats as equal only when both sides are nil.
var IgnoreSyncPodSets = cmpopts.IgnoreFields(Info{}, "SyncPodSets")
cases := map[string]struct {
infoOpts []InfoOption
wantInfo *Info
Expand Down Expand Up @@ -431,6 +434,7 @@ func TestNewInfo(t *testing.T) {
}
cmpOpts := []cmp.Option{
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
IgnoreSyncPodSets,
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
Expand Down
Loading