Skip to content

Commit 52fae2d

Browse files
committed
feat(scheduler):add support for kai scheduler
Signed-off-by: Harshal292004 <[email protected]>
1 parent b71a690 commit 52fae2d

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package kai
2+
3+
import (
4+
"context"
5+
6+
"github.com/go-logr/logr"
7+
"k8s.io/apimachinery/pkg/api/meta"
8+
apiruntime "k8s.io/apimachinery/pkg/runtime"
9+
ctrl "sigs.k8s.io/controller-runtime"
10+
"sigs.k8s.io/controller-runtime/pkg/client"
11+
12+
trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
13+
"github.com/kubeflow/trainer/pkg/constants"
14+
"github.com/kubeflow/trainer/pkg/runtime"
15+
"github.com/kubeflow/trainer/pkg/runtime/framework"
16+
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
17+
)
18+
19+
type KAIScheduler struct {
20+
client client.Client
21+
restMapper meta.RESTMapper
22+
scheme *apiruntime.Scheme
23+
logger logr.Logger
24+
}
25+
26+
// Implementing interfaces required for GangScheduling
27+
var _ framework.EnforcePodGroupPolicyPlugin = (*KAIScheduler)(nil)
28+
var _ framework.WatchExtensionPlugin = (*KAIScheduler)(nil)
29+
var _ framework.ComponentBuilderPlugin = (*KAIScheduler)(nil)
30+
31+
const Name = "KAIScheduler"
32+
33+
func New(ctx context.Context, client client.Client, indexer client.FieldIndexer) (framework.Plugin, error) {
34+
// No need of indexing for KAI
35+
return &KAIScheduler{
36+
client: client,
37+
restMapper: client.RESTMapper(),
38+
scheme: client.Scheme(),
39+
logger: ctrl.LoggerFrom(ctx).WithValues("pluginName", constants.JobSetKind),
40+
}, nil
41+
}
42+
43+
func (k *KAIScheduler) Name() string {
44+
return Name
45+
}
46+
47+
func (k *KAIScheduler) EnforcePodGroupPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error {
48+
if info == nil || info.RuntimePolicy.PodGroupPolicy == nil || trainJob == nil {
49+
return nil
50+
}
51+
52+
if info.Scheduler.PodLabels == nil {
53+
info.Scheduler.PodLabels = make(map[string]string, 1)
54+
}
55+
info.Scheduler.PodLabels[schedulerpluginsv1alpha1.PodGroupLabel] = trainJob.Name
56+
return nil
57+
}
58+
59+
func (k *KAIScheduler) Build(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]any, error) {
60+
return []any{}, nil
61+
}
62+
63+
func (k *KAIScheduler) ReconcilerBuilders() []runtime.ReconcilerBuilder {
64+
return []runtime.ReconcilerBuilder{}
65+
}

0 commit comments

Comments
 (0)