Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 201 additions & 72 deletions pkg/plugins/gateway/algorithms/pd_disaggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ import (
"encoding/json"
"fmt"
"io"
"math"
"math/rand"
"net/http"
"strconv"
"time"

"github.com/vllm-project/aibrix/pkg/cache"
"github.com/vllm-project/aibrix/pkg/constants"
"github.com/vllm-project/aibrix/pkg/metrics"
"github.com/vllm-project/aibrix/pkg/types"
"github.com/vllm-project/aibrix/pkg/utils"
"github.com/vllm-project/aibrix/pkg/utils/prefixcacheindexer"
Expand All @@ -48,10 +50,15 @@ const (
RoleReplicaIndex string = "stormservice.orchestration.aibrix.ai/role-replica-index"
PodGroupIndex string = "stormservice.orchestration.aibrix.ai/pod-group-index"
defaultPrefillRequestTimeout int = 30

defaultMaxRequest float64 = 32
defaultMaxTokenThroughputDiff float64 = 2048
)

var (
prefillRequestTimeout int = utils.LoadEnvInt("AIBRIX_PREFILL_REQUEST_TIMEOUT", defaultPrefillRequestTimeout)
prefillRequestTimeout int = utils.LoadEnvInt("AIBRIX_PREFILL_REQUEST_TIMEOUT", defaultPrefillRequestTimeout)
aibrixDecodeMaxRequest float64 = utils.LoadEnvFloat("AIBRIX_DECODE_MAX_REQUEST", defaultMaxRequest)
aibrixDecodeMaxThroughputDiff float64 = utils.LoadEnvFloat("AIBRIX_DECODE_MAX_THROUGHPUT", defaultMaxTokenThroughputDiff)
)

func init() {
Expand Down Expand Up @@ -86,31 +93,34 @@ func NewPDRouter() (types.Router, error) {
}

func (r pdRouter) Route(ctx *types.RoutingContext, readyPodList types.PodList) (string, error) {
prefillPods, decodePods, err := r.filterPrefillDecodePods(readyPodList.All())
prefillPod, decodePod, err := r.filterPrefillDecodePods(ctx, readyPodList.All())
if err != nil {
return "", err
}

prefillPod, err := r.doPrefillRequest(ctx, prefillPods, getLLMEngine(prefillPods[0], LLMEngineIdentifier, VLLMEngine))
err = r.doPrefillRequest(ctx, prefillPod, getLLMEngine(prefillPod, LLMEngineIdentifier, VLLMEngine))
if err != nil {
klog.ErrorS(err, "prefill request failed", "request_id", ctx.RequestID)
return "", err
}

decodePod := r.selectDecodePod(prefillPod, decodePods)
if decodePod == nil {
return "", fmt.Errorf("decode pod not found")
}

klog.InfoS("P/D", "prefill_pod", prefillPod.Name, "decode_pod", decodePod.Name)
klog.InfoS("P/D", "request_id", ctx.RequestID, "prefill_pod", prefillPod.Name, "decode_pod", decodePod.Name)

ctx.SetTargetPod(decodePod)
return ctx.TargetAddress(), nil
}

func (r *pdRouter) filterPrefillDecodePods(readyPods []*v1.Pod) ([]*v1.Pod, []*v1.Pod, error) {
// Scores pairs a candidate pod with the routing score computed for it.
// Lower scores are preferred: the scoring helpers in this file keep the
// lowest-scoring pod per roleset, and the prefill/decode pair with the
// smallest combined score is the one selected.
type Scores struct {
// Pod is the candidate pod the score was computed for.
Pod *v1.Pod
// Score is the pod's routing score; smaller is better.
Score float64
}

func (r *pdRouter) filterPrefillDecodePods(routingCtx *types.RoutingContext, readyPods []*v1.Pod) (*v1.Pod, *v1.Pod, error) {
prefillPods, decodePods := []*v1.Pod{}, []*v1.Pod{}
for _, pod := range readyPods {
if _, ok := pod.Labels[PDRoleSetIdentifier]; !ok {
continue
}
if _, ok := pod.Labels[PDRoleIdentifier]; !ok {
continue
}
Expand All @@ -125,92 +135,130 @@ func (r *pdRouter) filterPrefillDecodePods(readyPods []*v1.Pod) ([]*v1.Pod, []*v
decodePods = append(decodePods, pod)
}
}

if len(prefillPods) == 0 || len(decodePods) == 0 {
return nil, nil, fmt.Errorf("prefill or decodes pods are not ready")
}
return prefillPods, decodePods, nil

// Check for prefill and decode imbalance
// TODO: evaluate prefill/decode imbalance per roleset rather than per individual pod. Corner case:
// if roleset1 has a prefill imbalance and roleset2 has a decode imbalance, the prefill/decode pair is
// always selected from roleset2, which makes roleset2's decode imbalance worse.
targetPod, isImbalanced := getTargetPodOnLoadImbalance(r.cache, prefillPods)
if isImbalanced {
klog.V(4).InfoS("load imbalance detected, selecting least-loaded prefill pod", "request_id", routingCtx.RequestID, "selected_pod", targetPod.Name)
prefillPods = []*v1.Pod{targetPod}
decodePods = utils.FilterPodsByLabel(decodePods, PDRoleSetIdentifier, targetPod.Labels[PDRoleSetIdentifier])
}
targetPod, maxRequestCount, maxThroughput, maxFreeGPUUsage, podRequestCounts, podThroughputs, podFreeGpuUsage := r.loadImbalanceSelectDecodePod(routingCtx, decodePods)
if targetPod != nil {
klog.V(4).InfoS("load imbalance detected, selecting least-loaded decode pod", "request_id", routingCtx.RequestID, "selected_pod", targetPod.Name)
decodePods = []*v1.Pod{targetPod}
if len(prefillPods) > 1 {
prefillPods = utils.FilterPodsByLabel(prefillPods, PDRoleSetIdentifier, targetPod.Labels[PDRoleSetIdentifier])
}
}

prefillScores, prefixHashes := r.scorePrefillPods(routingCtx, prefillPods)
decodeScores := r.scoreDecodePods(decodePods, maxRequestCount, maxThroughput, maxFreeGPUUsage, podRequestCounts, podThroughputs, podFreeGpuUsage)

var targetPrefillPod, targetDecodePod *v1.Pod
minScore := math.MaxFloat64
for roleset, prefillScore := range prefillScores {
decodeScore, ok := decodeScores[roleset]
if !ok {
continue
}

if prefillScore.Score+decodeScore.Score < minScore {
minScore = prefillScore.Score + decodeScore.Score
targetPrefillPod = prefillScore.Pod
targetDecodePod = decodeScore.Pod
}
}
Comment on lines +164 to +177
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

After this loop, targetPrefillPod or targetDecodePod could be nil if prefillScores is empty or no matching rolesets are found in decodeScores. This will lead to a nil pointer dereference in the defer block on line 181 and when prefillPod is used in the Route function on line 101. You should add a check to ensure both pods have been selected.

var targetPrefillPod, targetDecodePod *v1.Pod
	minScore := math.MaxFloat64
	for roleset, prefillScore := range prefillScores {
		decodeScore, ok := decodeScores[roleset]
		if !ok {
			continue
		}

		if prefillScore.Score+decodeScore.Score < minScore {
			minScore = prefillScore.Score + decodeScore.Score
			targetPrefillPod = prefillScore.Pod
			targetDecodePod = decodeScore.Pod
		}
	}

	if targetPrefillPod == nil || targetDecodePod == nil {
		return nil, nil, fmt.Errorf("failed to select a pair of prefill and decode pods")
	}


defer func() {
if len(prefixHashes) > 0 {
r.prefixCacheIndexer.AddPrefix(prefixHashes, routingCtx.Model, targetPrefillPod.Name)
}
}()

return targetPrefillPod, targetDecodePod, nil
}

func (r *pdRouter) evaluatePrefixCache(ctx *types.RoutingContext, prefillPods []*v1.Pod) (*v1.Pod, []uint64, error) {
tokens, err := r.tokenizer.TokenizeInputText(ctx.Message)
func (r *pdRouter) scorePrefillPods(routingCtx *types.RoutingContext, prefillPods []*v1.Pod) (map[string]*Scores, []uint64) {
prefillScores := map[string]*Scores{}
tokens, err := r.tokenizer.TokenizeInputText(routingCtx.Message)
if err != nil {
return nil, nil, err
return nil, nil
}

readyPodsMap := map[string]struct{}{}
for _, pod := range prefillPods {
readyPodsMap[pod.Name] = struct{}{}
}
matchedPods, prefixHashes := r.prefixCacheIndexer.MatchPrefix(tokens, ctx.Model, readyPodsMap)
matchedPods, prefixHashes := r.prefixCacheIndexer.MatchPrefix(tokens, routingCtx.Model, readyPodsMap)

var prefillPod *v1.Pod
// check for load imbalance first
targetPod, isImbalanced := getTargetPodOnLoadImbalance(r.cache, prefillPods)
if isImbalanced {
klog.InfoS("load imbalance detected, selecting least-loaded prefill pod",
"request_id", ctx.RequestID, "selected_pod", targetPod.Name)
prefillPod = targetPod
} else if len(matchedPods) > 0 {
prefillPod = getTargetPodFromMatchedPods(r.cache, prefillPods, matchedPods)
}
if prefillPod == nil {
prefillPod, err = utils.SelectRandomPod(prefillPods, rand.Intn)
if err == nil {
klog.V(4).InfoS("fallback to random prefill pod selection",
"request_id", ctx.RequestID,
"selected_pod", prefillPod.Name)
var maxRequestCount float64 = 0
requestCount := []float64{}
podRequestCount := getRequestCounts(r.cache, prefillPods)
for _, cnt := range podRequestCount {
countFloat := float64(cnt)
requestCount = append(requestCount, countFloat)
if countFloat > maxRequestCount {
maxRequestCount = countFloat
}
}

return prefillPod, prefixHashes, err
}
meanRequestCount := mean(requestCount)
stdDevRequestCount := standardDeviation(requestCount)

func (r *pdRouter) selectDecodePod(prefillPod *v1.Pod, decodePods []*v1.Pod) *v1.Pod {
prefillRoleSet, ok := prefillPod.Labels[PDRoleSetIdentifier]
if !ok {
return nil
}

filteredDecodePods := []*v1.Pod{}
for _, pod := range decodePods {
if podRoleSet, exists := pod.Labels[PDRoleSetIdentifier]; exists && podRoleSet == prefillRoleSet {
filteredDecodePods = append(filteredDecodePods, pod)
for _, pod := range prefillPods {
rolesetName := pod.Labels[PDRoleSetIdentifier]
reqCnt := float64(podRequestCount[pod.Name])
if reqCnt > meanRequestCount+float64(standardDeviationFactor)*stdDevRequestCount {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

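The identifier standardDeviationFactor is used here but is not declared in this file. Go resolves package-level identifiers across every file in a package, so this only fails to compile if no other file in the package declares it — please confirm. If it is missing, define it in the const block at the top of the file. For example: const standardDeviationFactor = 2.0.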

continue
}
}
if len(filteredDecodePods) == 0 {
return nil
}

// prefer decode pod with least running requests
decodePod := selectPodWithLeastRequestCount(r.cache, filteredDecodePods)
if decodePod != nil {
klog.V(5).InfoS("selected decode pod by least request count",
"prefill_pod", prefillPod.Name,
"decode_pod", decodePod.Name)
return decodePod
prefillScore := (100-float64(matchedPods[pod.Name]))*.1 + (reqCnt / maxRequestCount)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

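There is a potential for division by zero here. If all prefill pods are idle, getRequestCounts will return counts of 0 for all of them, making maxRequestCount equal to 0. With the standard Go toolchain a float64 division by zero does not panic; reqCnt / maxRequestCount becomes 0/0 = NaN, so every prefill score is NaN. NaN never compares less than minScore, so no prefill/decode pair is selected and the returned pods stay nil. You should guard against this.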

if maxRequestCount == 0 {
			maxRequestCount = 1 // Avoid division by zero
		}
		prefillScore := (100-float64(matchedPods[pod.Name]))*.1 + (reqCnt / maxRequestCount)

if existingScore, exists := prefillScores[rolesetName]; !exists || prefillScore < existingScore.Score {
prefillScores[rolesetName] = &Scores{
Pod: pod,
Score: prefillScore,
}
}
}

// fallback: random selection pods
decodePod, _ = utils.SelectRandomPod(filteredDecodePods, rand.Intn)
return decodePod
return prefillScores, prefixHashes
}

func (r *pdRouter) doPrefillRequest(routingCtx *types.RoutingContext, prefillPods []*v1.Pod, llmEngine string) (*v1.Pod, error) {
prefillPod, prefixHashes, err := r.evaluatePrefixCache(routingCtx, prefillPods)
if err != nil {
return nil, err
}
defer func() {
if len(prefixHashes) > 0 {
r.prefixCacheIndexer.AddPrefix(prefixHashes, routingCtx.Model, prefillPod.Name)
func (r *pdRouter) scoreDecodePods(filteredDecodePods []*v1.Pod,
maxRequestCount float64, maxThroughput float64, maxFreeGPUUsage float64,
podRequestCounts map[string]float64, podThroughputs map[string]float64, podFreeGpuUsage map[string]float64) map[string]*Scores {
decodeScores := map[string]*Scores{}

for _, pod := range filteredDecodePods {
rolesetName := pod.Labels[PDRoleSetIdentifier]
normalizedRunningReqs := podRequestCounts[pod.Name] / maxRequestCount
normalizedThroughput := 1 - podThroughputs[pod.Name]/maxThroughput
normalizedFreeGPUPercent := podFreeGpuUsage[pod.Name] / maxFreeGPUUsage

decodeScore := ((normalizedRunningReqs) + normalizedThroughput) / normalizedFreeGPUPercent
if existingScore, exists := decodeScores[rolesetName]; !exists || decodeScore < existingScore.Score {
decodeScores[rolesetName] = &Scores{
Pod: pod,
Score: decodeScore,
}
}
}()
}

return decodeScores
}

func (r *pdRouter) doPrefillRequest(routingCtx *types.RoutingContext, prefillPod *v1.Pod, llmEngine string) error {
// Prepare prefill request payload
payload, err := r.preparePrefillPayload(routingCtx, prefillPod, llmEngine)
if err != nil {
return nil, fmt.Errorf("failed to prepare prefill payload: %w", err)
return fmt.Errorf("failed to prepare prefill payload: %w", err)
}

// Execute HTTP request
Expand All @@ -235,23 +283,23 @@ func (r *pdRouter) doPrefillRequest(routingCtx *types.RoutingContext, prefillPod
} else if llmEngine == VLLMEngine {
responseData, err := r.executeHTTPRequest(apiURL, routingCtx, payload)
if err != nil {
return nil, fmt.Errorf("failed to execute prefill request: %w", err)
return fmt.Errorf("failed to execute prefill request: %w", err)
}

// Update routing context with KV transfer params from prefill response for vLLM
if err := r.updateRoutingContextWithKVTransferParams(routingCtx, responseData, prefillPod); err != nil {
return nil, fmt.Errorf("failed to update routing context with KV transfer params: %w", err)
return fmt.Errorf("failed to update routing context with KV transfer params: %w", err)
}

klog.InfoS("prefill_request_complete", "request_id", routingCtx.RequestID, "prefill_pod_ip", prefillPod.Status.PodIP)
} else {
if _, err := r.executeHTTPRequest(apiURL, routingCtx, payload); err != nil {
return nil, fmt.Errorf("failed to execute prefill request: %w", err)
return fmt.Errorf("failed to execute prefill request: %w", err)
}
klog.InfoS("prefill_request_complete", "request_id", routingCtx.RequestID, "prefill_pod_ip", prefillPod.Status.PodIP)
}

return prefillPod, nil
return nil
}

func (r *pdRouter) preparePrefillPayload(routingCtx *types.RoutingContext, pod *v1.Pod, llmEngine string) ([]byte, error) {
Expand Down Expand Up @@ -397,3 +445,84 @@ func getSGLangBootstrapPort(pod *v1.Pod) int64 {
}
return SGLangBootstrapPort // Default port
}

func (r *pdRouter) loadImbalanceSelectDecodePod(ctx *types.RoutingContext, filteredDecodePods []*v1.Pod) (*v1.Pod, float64, float64, float64, map[string]float64, map[string]float64, map[string]float64) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function returns 7 values, which can be difficult to manage and is considered a code smell in Go. Consider refactoring this to return a struct containing these values. This will improve readability and maintainability, and reduce the risk of errors like swapped return values.

podRequestCounts := make(map[string]float64)
podThroughputs := make(map[string]float64)
podFreeGpuUsage := make(map[string]float64)

minRequestPod := filteredDecodePods[0]
minRequestCount := math.MaxFloat64
maxRequestCount := float64(0)

minThroughputPod := filteredDecodePods[0]
minThroughput := float64(math.MaxFloat64)
maxThroughput := float64(0)

minFreeGPUUsage := float64(math.MaxFloat64)
maxFreeGPUUsage := float64(0)

for _, pod := range filteredDecodePods {
runningReqs, err := r.cache.GetMetricValueByPod(pod.Name, pod.Namespace, metrics.RealtimeNumRequestsRunning)
if err != nil {
runningReqs = &metrics.SimpleMetricValue{Value: 0}
}
requestCount := runningReqs.GetSimpleValue()
podRequestCounts[pod.Name] = requestCount
if requestCount < minRequestCount {
minRequestCount = requestCount
minRequestPod = pod
}
maxRequestCount = math.Max(maxRequestCount, requestCount)

tokenThroughput, err := r.cache.GetMetricValueByPodModel(pod.Name, pod.Namespace, ctx.Model, metrics.AvgGenerationThroughputToksPerS)
if err != nil {
tokenThroughput = &metrics.SimpleMetricValue{Value: 0}
}
throughput := tokenThroughput.GetSimpleValue()
podThroughputs[pod.Name] = throughput
if throughput < minThroughput {
minThroughput = throughput
minThroughputPod = pod
}
maxThroughput = math.Max(maxThroughput, throughput)

gpuUsage, err := r.cache.GetMetricValueByPodModel(pod.Name, pod.Namespace, ctx.Model, metrics.GPUCacheUsagePerc)
if err != nil {
gpuUsage = &metrics.SimpleMetricValue{Value: 0}
}
podFreeGpuUsage[pod.Name] = 100 - gpuUsage.GetSimpleValue()*100
if podFreeGpuUsage[pod.Name] <= 0 {
podFreeGpuUsage[pod.Name] = 0.1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The magic number 0.1 is used here to avoid potential division-by-zero issues later. It's better to define this as a named constant to improve code clarity and maintainability. For example: const minFreeGPUUsageEpsilon = 0.1.

podFreeGpuUsage[pod.Name] = 0.1 // TODO: use a const

}
minFreeGPUUsage = math.Min(minFreeGPUUsage, podFreeGpuUsage[pod.Name])
maxFreeGPUUsage = math.Max(maxFreeGPUUsage, podFreeGpuUsage[pod.Name])
}

if minRequestCount == 0 || maxRequestCount-minRequestCount >= aibrixDecodeMaxRequest {
klog.InfoS("REQUEST_SELECTED_DECODE_POD", "request_id", ctx.RequestID,
"min_request_count", minRequestCount, "max_request_count", maxRequestCount,
"min_throughput", minThroughput, "max_throughput", maxThroughput,
"free_gpu_percent", podFreeGpuUsage[minRequestPod.Name],
"decode_pod", minRequestPod.Name)
return minRequestPod, maxRequestCount, maxFreeGPUUsage, maxThroughput, podRequestCounts, podThroughputs, podFreeGpuUsage
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The order of the returned values maxFreeGPUUsage and maxThroughput is swapped compared to the function signature. The signature expects (..., maxThroughput, maxFreeGPUUsage, ...) but this returns (..., maxFreeGPUUsage, maxThroughput, ...). This will cause incorrect values to be used in the scoring logic.

return minRequestPod, maxRequestCount, maxThroughput, maxFreeGPUUsage, podRequestCounts, podThroughputs, podFreeGpuUsage

}

if maxThroughput-minThroughput > aibrixDecodeMaxThroughputDiff {
klog.InfoS("THROUGHPUT_SELECTED_DECODE_POD", "request_id", ctx.RequestID,
"min_request_count", minRequestCount, "max_request_count", maxRequestCount,
"min_throughput", minThroughput, "max_throughput", maxThroughput,
"free_gpu_percent", podFreeGpuUsage[minThroughputPod.Name],
"decode_pod", minThroughputPod.Name)
return minThroughputPod, maxRequestCount, maxFreeGPUUsage, maxThroughput, podRequestCounts, podThroughputs, podFreeGpuUsage
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the other return path in this function, the order of the returned values maxFreeGPUUsage and maxThroughput is swapped compared to the function signature. This will lead to incorrect scoring calculations.

return minThroughputPod, maxRequestCount, maxThroughput, maxFreeGPUUsage, podRequestCounts, podThroughputs, podFreeGpuUsage

}

if maxRequestCount == 0 {
maxRequestCount = 1
}
if maxThroughput == 0 {
maxThroughput = 1
}

return nil, maxRequestCount, maxThroughput, maxFreeGPUUsage, podRequestCounts, podThroughputs, podFreeGpuUsage
}
11 changes: 11 additions & 0 deletions pkg/utils/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,17 @@ func FilterPodByName(podname string, pods []*v1.Pod) (*v1.Pod, bool) {
return nil, false
}

// FilterPodsByLabel returns the pods whose labels contain labelKey with a
// value equal to labelValue, preserving the input order. Pods that lack the
// label, or carry a different value for it, are dropped. The result is nil
// when no pod matches.
func FilterPodsByLabel(pods []*v1.Pod, labelKey, labelValue string) []*v1.Pod {
	var matched []*v1.Pod
	for i := range pods {
		candidate := pods[i]
		got, present := candidate.Labels[labelKey]
		if !present || got != labelValue {
			continue
		}
		matched = append(matched, candidate)
	}
	return matched
}

// DeploymentNameFromPod extracts the deployment name from the pod using two methods:
// 1. If the pod has a label with the key "app.kubernetes.io/name", its value is considered the deployment name.
// 2. If the pod has an owner reference of kind "ReplicaSet", the deployment name is extracted from the owner reference's name.
Expand Down
Loading