Merged
32 commits
5d32ae2
use `copy_for_forward`
hnyls2002 Oct 4, 2025
e727062
try `run_batch_exp`
hnyls2002 Oct 4, 2025
2553bdb
tiny fix typo
hnyls2002 Oct 4, 2025
15cff1b
wait for default stream
hnyls2002 Oct 4, 2025
5728f0b
try fix pd-disagg and idle batch
hnyls2002 Oct 4, 2025
1ea1338
add todo comments
hnyls2002 Oct 4, 2025
4e6cccd
tiny skip-sample adjust
hnyls2002 Oct 5, 2025
7905df0
minor style adjust
hnyls2002 Oct 5, 2025
58d04c4
disable prefill only optimizations in spec
hnyls2002 Oct 5, 2025
78f1719
fix
hnyls2002 Oct 5, 2025
d530623
update
hnyls2002 Oct 5, 2025
8f2f2e7
Merge branch 'main' into lsyin/remove-overlap-thread
hnyls2002 Oct 5, 2025
d61df9b
Merge branch 'lsyin/tiny-skip-sample-adjust' into lsyin/remove-overla…
hnyls2002 Oct 5, 2025
5083846
try fix grammar sync
hnyls2002 Oct 5, 2025
a5a7159
fix prefill only and delay sample when eagle
hnyls2002 Oct 5, 2025
a42bd28
Merge branch 'main' into lsyin/remove-overlap-thread
hnyls2002 Oct 5, 2025
0edfa65
try fix data overwirte by cuda buffer (introduce real copy done later)
hnyls2002 Oct 5, 2025
d04ae3c
Merge branch 'main' into lsyin/remove-overlap-thread
hnyls2002 Oct 6, 2025
320571a
Merge branch 'main' into lsyin/remove-overlap-thread
hnyls2002 Oct 6, 2025
0d5ec64
Merge branch 'main' into lsyin/remove-overlap-thread
zhyncs Oct 6, 2025
dda6db3
tiny fix command show
hnyls2002 Oct 6, 2025
85b6c80
add copy done function
hnyls2002 Oct 6, 2025
3ba61b1
Merge branch 'main' into lsyin/remove-overlap-thread
hnyls2002 Oct 6, 2025
b5b6b4c
remove launch_done
hnyls2002 Oct 6, 2025
1c1973b
fix ascend
hnyls2002 Oct 6, 2025
d623726
remove forward batch output
hnyls2002 Oct 6, 2025
76d57fc
rename
hnyls2002 Oct 6, 2025
5faf1eb
keep reference of launched batch
hnyls2002 Oct 6, 2025
a10f20b
fix pd-disagg
hnyls2002 Oct 6, 2025
b46991e
fix ascend
hnyls2002 Oct 7, 2025
d649de8
Merge branch 'main' into lsyin/remove-overlap-thread
hnyls2002 Oct 7, 2025
c28121e
Merge branch 'main' into lsyin/remove-overlap-thread
hnyls2002 Oct 7, 2025
129 changes: 124 additions & 5 deletions python/sglang/srt/managers/scheduler.py
@@ -112,8 +112,10 @@
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
ModelWorkerBatch,
MultimodalInputs,
Req,
RequestStage,
@@ -210,6 +212,9 @@ class GenerationBatchResult:
extend_input_len_per_req: List[int]
extend_logprob_start_len_per_req: List[int]

# For overlap scheduling
copy_done: Optional[torch.cuda.Event] = None

@classmethod
def from_forward_batch_output(
cls,
@@ -226,6 +231,7 @@ def from_forward_batch_output(
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
copy_done=forward_batch_output.copy_done,
)

@classmethod
@@ -386,12 +392,8 @@ def __init__(
logger.info("Overlap scheduler is disabled for embedding models.")

# Launch a tensor parallel worker
if self.enable_overlap:
TpWorkerClass = TpModelWorkerClient
else:
TpWorkerClass = TpModelWorker

self.tp_worker = TpWorkerClass(
self.tp_worker = TpModelWorker(
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
@@ -616,6 +618,9 @@ def __init__(
# Init prefill kv split size when deterministic inference is enabled with various attention backends
self.init_deterministic_inference_config()

# Init overlap
self.init_overlap()

# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
@@ -928,6 +933,21 @@ def init_disaggregation(self):
# The prefill requests that are in the middle of kv sending
self.disagg_prefill_inflight_queue: List[Req] = []

def init_overlap(self):
if not self.enable_overlap:
return

self.forward_stream = torch.get_device_module(self.device).Stream()
self.forward_stream_ctx = torch.get_device_module(self.device).stream(
self.forward_stream
)
self.copy_stream = torch.get_device_module(self.device).Stream()
self.copy_stream_ctx = torch.get_device_module(self.device).stream(
self.copy_stream
)

self.future_map = FutureMap(self.max_running_requests, self.device)

def init_moe_config(self):
if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
initialize_moe_config(self.server_args)
@@ -2031,10 +2051,109 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
batch.prepare_for_decode()
return batch

def run_batch_exp(
self, batch: ScheduleBatch
) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
"""Run a batch."""
self.forward_ct += 1

# Whether to run the profiler
self._profile_batch_predicate(batch)
if self.forward_sleep_time is not None:
logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
time.sleep(self.forward_sleep_time)

# Run forward
if self.is_generation:

batch_or_worker_batch = batch

if self.spec_algorithm.is_none():
# FIXME(lsyin): remove this if and finally unify the abstraction
batch_or_worker_batch = batch.get_model_worker_batch()

if self.enable_overlap:
# FIXME: remove this assert
assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
model_worker_batch = batch_or_worker_batch

# Sampling info will be modified during forward
model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = (
model_worker_batch.sampling_info.copy_for_forward()
)

bs = len(model_worker_batch.seq_lens)
cur_future_map_ct = self.future_map.update_ct(bs)

with self.forward_stream_ctx:
self.future_map.resolve_future(model_worker_batch)
forward_batch_output = self.model_worker.forward_batch_generation(
batch_or_worker_batch
)
next_token_ids = forward_batch_output.next_token_ids
self.future_map.store_to_map(cur_future_map_ct, bs, next_token_ids)

copy_done = torch.cuda.Event()
copy_done.record()

# FIXME(lsyin): move copy_done elsewhere
forward_batch_output.copy_done = copy_done

# FIXME(lsyin): move this assignment elsewhere
forward_batch_output.next_token_ids = (
self.future_map.update_next_future(cur_future_map_ct, bs)
)
else:
forward_batch_output = self.model_worker.forward_batch_generation(
batch_or_worker_batch
)
copy_done = None

if not self.spec_algorithm.is_none():
# TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
self.udpate_spec_metrics(
batch.batch_size(), forward_batch_output.num_accepted_tokens
)

# update batch's output ids
batch.output_ids = forward_batch_output.next_token_ids

# print(f"[Run Batch]: {batch.seq_lens_cpu=}")
# print(f"[Run Batch]: {batch.input_ids=}")
# print(f"[Output Ids]: {batch.output_ids}")

# These 2 values are needed for processing the output, but the values can be
# modified by overlap schedule. So we have to copy them here so that
# we can use the correct values in output processing.
if batch.return_logprob or self.spec_algorithm.is_eagle():
extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
else:
extend_input_len_per_req = None

if batch.return_logprob:
extend_logprob_start_len_per_req = [
req.extend_logprob_start_len for req in batch.reqs
]
else:
extend_logprob_start_len_per_req = None

return GenerationBatchResult.from_forward_batch_output(
forward_batch_output=forward_batch_output,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
)
else: # embedding or reward model
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = EmbeddingBatchResult(embeddings=embeddings)
return ret

def run_batch(
self, batch: ScheduleBatch
) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
"""Run a batch."""
return self.run_batch_exp(batch)

self.forward_ct += 1

# Whether to run the profiler
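The overlap path above replaces the old TpModelWorkerClient background thread with stream-level overlap inside the scheduler itself: the forward is launched on a dedicated CUDA stream, the not-yet-produced token ids are represented as placeholder slots in a FutureMap, and a recorded copy_done event lets output processing wait only when the results are actually read. The following is a minimal standalone sketch of that pattern, assuming a CUDA device is available; TinyFutureMap, fake_forward, and the batch sizes are illustrative stand-ins, not SGLang APIs.

import torch

class TinyFutureMap:
    # Pre-allocated buffer of future token ids; negative indices mark
    # "not produced yet" slots that are resolved once the GPU has written them.
    def __init__(self, capacity: int, device: str = "cuda"):
        self.capacity = capacity
        self.ct = 0
        self.buf = torch.zeros(capacity, dtype=torch.int64, device=device)

    def alloc(self, bs: int) -> int:
        assert self.ct + bs <= self.capacity, "sketch does not wrap around"
        start, self.ct = self.ct, self.ct + bs
        return start

    def store(self, start: int, values: torch.Tensor):
        self.buf[start : start + values.numel()] = values

    def placeholders(self, start: int, bs: int) -> torch.Tensor:
        return -torch.arange(start + 1, start + bs + 1, device=self.buf.device)

    def resolve(self, ids: torch.Tensor) -> torch.Tensor:
        slots = (-ids - 1).clamp(min=0)
        return torch.where(ids < 0, self.buf[slots], ids)

def fake_forward(bs: int) -> torch.Tensor:
    # Stand-in for forward_batch_generation: next-token ids produced on GPU.
    return torch.randint(0, 32000, (bs,), device="cuda")

def run_overlapped(batch_sizes):
    forward_stream = torch.cuda.Stream()
    future_map = TinyFutureMap(capacity=1024)
    pending = []

    for bs in batch_sizes:
        start = future_map.alloc(bs)
        with torch.cuda.stream(forward_stream):
            # Launch on a side stream so the CPU can keep scheduling batches.
            future_map.store(start, fake_forward(bs))
            copy_done = torch.cuda.Event()
            copy_done.record()
        pending.append((future_map.placeholders(start, bs), copy_done))

    for placeholder_ids, copy_done in pending:
        copy_done.synchronize()  # block the host only when results are needed
        print(future_map.resolve(placeholder_ids).tolist())

if __name__ == "__main__":
    if torch.cuda.is_available():
        run_overlapped([4, 8, 2])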
46 changes: 22 additions & 24 deletions python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -49,29 +49,29 @@ def process_batch_result_prefill(
next_token_ids,
extend_input_len_per_req,
extend_logprob_start_len_per_req,
copy_done,
) = (
result.logits_output,
result.next_token_ids,
result.extend_input_len_per_req,
result.extend_logprob_start_len_per_req,
result.copy_done,
)

if self.enable_overlap:
logits_output, next_token_ids, _ = (
self.tp_worker.resolve_last_batch_result(launch_done)
)
else:
# Move next_token_ids and logprobs to cpu
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
if copy_done is not None:
copy_done.synchronize()

# Move next_token_ids and logprobs to cpu
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)

hidden_state_offset = 0

@@ -206,20 +206,18 @@ def process_batch_result_decode(
result: GenerationBatchResult,
launch_done: Optional[threading.Event] = None,
):
logits_output, next_token_ids, can_run_cuda_graph = (
logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
result.logits_output,
result.next_token_ids,
result.can_run_cuda_graph,
result.copy_done,
)
self.num_generated_tokens += len(batch.reqs)

if self.enable_overlap:
logits_output, next_token_ids, can_run_cuda_graph = (
self.tp_worker.resolve_last_batch_result(launch_done)
)
next_token_logprobs = logits_output.next_token_logprobs
elif batch.spec_algorithm.is_none():
# spec decoding handles output logprobs inside verify process.
if copy_done is not None:
copy_done.synchronize()

if batch.spec_algorithm.is_none():
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs.tolist()
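In both process_batch_result_prefill and process_batch_result_decode above, the former resolve_last_batch_result(launch_done) round trip through the worker thread is replaced by waiting on the recorded copy_done event and then reading the tensors directly. Below is a hedged sketch of that event-gated readback; the function and tensor names are illustrative, not the actual result fields.

from typing import Optional

import torch

def read_next_token_ids(
    next_token_ids: torch.Tensor, copy_done: Optional[torch.cuda.Event]
) -> list:
    if copy_done is not None:
        # The host blocks here until all GPU work recorded before the event has
        # finished, so the device-to-host copy below sees the final values.
        copy_done.synchronize()
    return next_token_ids.tolist()

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        ids = torch.randint(0, 32000, (4,), device="cuda")
        done = torch.cuda.Event()
        done.record()
    print(read_next_token_ids(ids, done))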
8 changes: 2 additions & 6 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -234,12 +234,8 @@ def forward_batch_generation(
self, model_worker_batch: ModelWorkerBatch
) -> ForwardBatchOutput:
# Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
sampling_info = model_worker_batch.sampling_info
sampling_info.update_penalties()
model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
sampling_info,
sampling_info_done=threading.Event(),
penalizer_orchestrator=None,
model_worker_batch.sampling_info = self.cur_sampling_info = (
model_worker_batch.sampling_info.copy_for_forward()
)

# A cuda stream sync here to avoid the cuda illegal memory access error.
3 changes: 3 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -910,6 +910,9 @@ class ForwardBatchOutput:
pp_proxy_tensors: Optional[PPProxyTensors] = None
can_run_cuda_graph: bool = False

# For overlap scheduling
copy_done: Optional[torch.cuda.Event] = None


def enable_num_token_non_padded(server_args):
return get_moe_expert_parallel_world_size() > 1
9 changes: 9 additions & 0 deletions python/sglang/srt/sampling/sampling_batch_info.py
@@ -370,6 +370,15 @@ def merge_batch(self, other: "SamplingBatchInfo"):
self.need_top_k_sampling |= other.need_top_k_sampling
self.need_min_p_sampling |= other.need_min_p_sampling

def copy_for_forward(self):
# Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later
self.update_penalties()
return dataclasses.replace(
self,
sampling_info_done=threading.Event(),
penalizer_orchestrator=None,
)


def merge_bias_tensor(
lhs: Optional[torch.Tensor],
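copy_for_forward snapshots the sampling state that the forward pass will read, so the scheduler can keep rebinding fields on the live SamplingBatchInfo for the next batch without racing the in-flight forward. Below is a minimal sketch of the idea using a simplified, hypothetical dataclass; the field names are illustrative and do not match the real SamplingBatchInfo layout.

import dataclasses
import threading
from typing import List, Optional

@dataclasses.dataclass
class MiniSamplingInfo:
    temperatures: List[float]
    accumulated_penalties: List[float]
    pending_penalty: float = 0.0
    sampling_info_done: Optional[threading.Event] = None
    penalizer_orchestrator: Optional[object] = None

    def update_penalties(self):
        # Fold pending penalty state into a plain buffer so the frozen copy no
        # longer needs the orchestrator object at forward time.
        self.accumulated_penalties = [
            p + self.pending_penalty for p in self.accumulated_penalties
        ]
        self.pending_penalty = 0.0

    def copy_for_forward(self) -> "MiniSamplingInfo":
        self.update_penalties()
        return dataclasses.replace(
            self,
            sampling_info_done=threading.Event(),
            penalizer_orchestrator=None,
        )

live = MiniSamplingInfo(temperatures=[0.7, 1.0], accumulated_penalties=[0.0, 0.0])
frozen = live.copy_for_forward()
live.temperatures = [0.9, 1.0]            # scheduler rebinds fields for the next batch
assert frozen.temperatures == [0.7, 1.0]  # the forward copy keeps its snapshot
assert frozen.penalizer_orchestrator is None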