diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index a0b6de7f406..cf87d62d7b3 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -747,11 +747,13 @@ def event_loop_normal_disagg_decode(self: Scheduler): @torch.no_grad() def event_loop_overlap_disagg_decode(self: Scheduler): - result_queue = deque() + self.result_queue = deque() self.last_batch: Optional[ScheduleBatch] = None self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track if it's extend while True: + self.launch_last_batch_sample_if_needed() + recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) # polling and allocating kv cache @@ -774,13 +776,13 @@ def event_loop_overlap_disagg_decode(self: Scheduler): None, delay_process=True ) if batch_: - result_queue.append((batch_.copy(), result)) + self.result_queue.append((batch_.copy(), result)) last_batch_in_queue = True else: if prepare_mlp_sync_flag: self.prepare_mlp_sync_batch(batch) result = self.run_batch(batch) - result_queue.append((batch.copy(), result)) + self.result_queue.append((batch.copy(), result)) if (self.last_batch is None) or (not self.last_batch_in_queue): # Create a dummy first batch to start the pipeline for overlap schedule. @@ -798,12 +800,12 @@ def event_loop_overlap_disagg_decode(self: Scheduler): None, delay_process=True ) if batch: - result_queue.append((batch.copy(), result)) + self.result_queue.append((batch.copy(), result)) last_batch_in_queue = True # Process the results of the previous batch but skip if the last batch is extend if self.last_batch and self.last_batch_in_queue: - tmp_batch, tmp_result = result_queue.popleft() + tmp_batch, tmp_result = self.result_queue.popleft() tmp_batch.next_batch_sampling_info = ( self.tp_worker.cur_sampling_info if batch else None ) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 3393a32fb3a..b761ad7ac1d 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -321,6 +321,8 @@ def event_loop_overlap_disagg_prefill(self: Scheduler) -> None: self.result_queue = deque() while True: + self.launch_last_batch_sample_if_needed() + recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) self.waiting_queue.extend( @@ -368,7 +370,6 @@ def process_batch_result_disagg_prefill( self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult, - launch_done: Optional[threading.Event] = None, ) -> None: """ Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue @@ -379,31 +380,30 @@ def process_batch_result_disagg_prefill( next_token_ids, extend_input_len_per_req, extend_logprob_start_len_per_req, + copy_done, ) = ( result.logits_output, result.next_token_ids, result.extend_input_len_per_req, result.extend_logprob_start_len_per_req, + result.copy_done, ) + if copy_done is not None: + copy_done.synchronize() + logprob_pt = 0 # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue - if self.enable_overlap: - # wait - logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result( - launch_done - ) - else: - next_token_ids = result.next_token_ids.tolist() - if batch.return_logprob: - if logits_output.next_token_logprobs is not None: - logits_output.next_token_logprobs = ( - logits_output.next_token_logprobs.tolist() - ) - if 
logits_output.input_token_logprobs is not None: - logits_output.input_token_logprobs = tuple( - logits_output.input_token_logprobs.tolist() - ) + next_token_ids = result.next_token_ids.tolist() + if batch.return_logprob: + if logits_output.next_token_logprobs is not None: + logits_output.next_token_logprobs = ( + logits_output.next_token_logprobs.tolist() + ) + if logits_output.input_token_logprobs is not None: + logits_output.input_token_logprobs = tuple( + logits_output.input_token_logprobs.tolist() + ) hidden_state_offset = 0 for i, (req, next_token_id) in enumerate( diff --git a/python/sglang/srt/managers/overlap_utils.py b/python/sglang/srt/managers/overlap_utils.py index d512ae7ec99..1a011717ac1 100644 --- a/python/sglang/srt/managers/overlap_utils.py +++ b/python/sglang/srt/managers/overlap_utils.py @@ -37,8 +37,7 @@ def update_ct(self, bs: int) -> int: return cur_future_ct def resolve_future(self, model_worker_batch: ModelWorkerBatch): - input_ids = model_worker_batch.input_ids - _resolve_future_token_ids(input_ids, self.token_ids_buf) + _resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf) def update_next_future(self, future_ct: int, bs: int): return torch.arange( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index af14d95d864..2403070e3c7 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -886,9 +886,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # This is an optimization to reduce the overhead of the prefill check. batch_is_full: bool = False - # Events - launch_done: Optional[threading.Event] = None - # For chunked prefill in PP chunked_req: Optional[Req] = None @@ -1877,7 +1874,6 @@ def get_model_worker_batch( ) ), extend_input_logprob_token_ids=self.extend_input_logprob_token_ids, - launch_done=self.launch_done, is_prefill_only=self.is_prefill_only, ) @@ -2018,8 +2014,8 @@ class ModelWorkerBatch: capture_hidden_mode: CaptureHiddenMode = None hicache_consumer_index: int = -1 - # Overlap event - launch_done: Optional[threading.Event] = None + # Overlap scheduler related + delay_sample_launch: bool = False # Whether this batch is prefill-only (no token generation needed) is_prefill_only: bool = False diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e5d6fab7b86..8873524fd39 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -25,12 +25,14 @@ from dataclasses import dataclass from http import HTTPStatus from types import SimpleNamespace -from typing import Dict, List, Optional, Tuple, Union +from typing import Deque, Dict, List, Optional, Tuple, Union import psutil import setproctitle import torch import zmq +from torch.cuda import Stream as CudaStream +from torch.cuda import StreamContext as CudaStreamContext from torch.distributed import barrier from sglang.global_config import global_config @@ -112,8 +114,10 @@ UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.mm_utils import init_embedding_cache +from sglang.srt.managers.overlap_utils import FutureMap from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, + ModelWorkerBatch, MultimodalInputs, Req, RequestStage, @@ -139,15 +143,13 @@ SchedulerUpdateWeightsMixin, ) from sglang.srt.managers.session_controller import Session -from sglang.srt.managers.tp_worker import TpModelWorker -from sglang.srt.managers.tp_worker_overlap_thread import 
TpModelWorkerClient from sglang.srt.managers.utils import validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.model_executor.forward_batch_info import ( - ForwardBatchOutput, + ForwardBatch, ForwardMode, PPProxyTensors, ) @@ -201,40 +203,48 @@ @dataclass class GenerationBatchResult: - logits_output: Optional[LogitsProcessorOutput] - pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] - next_token_ids: Optional[List[int]] - can_run_cuda_graph: bool + logits_output: Optional[LogitsProcessorOutput] = None + pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None + next_token_ids: Optional[torch.Tensor] = None + num_accepted_tokens: Optional[int] = None + can_run_cuda_graph: bool = False # For output processing - extend_input_len_per_req: List[int] - extend_logprob_start_len_per_req: List[int] - - @classmethod - def from_forward_batch_output( - cls, - forward_batch_output: ForwardBatchOutput, - extend_input_len_per_req: List[int], - extend_logprob_start_len_per_req: List[int], - ): - # TODO(lsyin): remove this workaround logic and try to unify output classes - - return cls( - logits_output=forward_batch_output.logits_output, - pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors, - next_token_ids=forward_batch_output.next_token_ids, - extend_input_len_per_req=extend_input_len_per_req, - extend_logprob_start_len_per_req=extend_logprob_start_len_per_req, - can_run_cuda_graph=forward_batch_output.can_run_cuda_graph, - ) + extend_input_len_per_req: Optional[List[int]] = None + extend_logprob_start_len_per_req: Optional[List[int]] = None + + # For overlap scheduling + copy_done: Optional[torch.cuda.Event] = None + delay_sample_launch: bool = False + forward_batch: Optional[ForwardBatch] = None + future_map_ct: Optional[int] = None + + def copy_to_cpu(self, return_logprob: bool = False): + """Copy tensors to CPU in overlap scheduling. 
+ Only the tensors which are needed for processing results are copied, + e.g., next_token_ids, logits outputs + """ + if return_logprob: + if self.logits_output.next_token_logits is not None: + self.logits_output.next_token_logits = ( + self.logits_output.next_token_logits.to("cpu", non_blocking=True) + ) + if self.logits_output.input_token_logprobs is not None: + self.logits_output.input_token_logprobs = ( + self.logits_output.input_token_logprobs.to("cpu", non_blocking=True) + ) + if self.logits_output.hidden_states is not None: + self.logits_output.hidden_states = self.logits_output.hidden_states.to( + "cpu", non_blocking=True + ) + self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True) + self.copy_done.record() @classmethod def from_pp_proxy( cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph ): - # TODO(lsyin): also simplify this logic - # Current PP implementation in scheduler is not compatible with ForwardBatchOutput - # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP + # TODO(lsyin): refactor PP and avoid using dict proxy_dict = next_pp_outputs.tensors return cls( logits_output=logits_output, @@ -388,12 +398,10 @@ def __init__( logger.info("Overlap scheduler is disabled for embedding models.") # Launch a tensor parallel worker - if self.enable_overlap: - TpWorkerClass = TpModelWorkerClient - else: - TpWorkerClass = TpModelWorker - self.tp_worker = TpWorkerClass( + from sglang.srt.managers.tp_worker import TpModelWorker + + self.tp_worker = TpModelWorker( server_args=server_args, gpu_id=gpu_id, tp_rank=tp_rank, @@ -525,9 +533,11 @@ def __init__( self.kv_transfer_speed_gb_s: float = 0.0 self.kv_transfer_latency_ms: float = 0.0 self.sessions: Dict[str, Session] = {} - self.current_stream = torch.get_device_module(self.device).current_stream() + self.default_stream: CudaStream = torch.get_device_module( + self.device + ).current_stream() if self.device == "cpu": - self.current_stream.synchronize = lambda: None # No-op for CPU + self.default_stream.synchronize = lambda: None # No-op for CPU self.forward_sleep_time = None # Init chunked prefill @@ -618,6 +628,9 @@ def __init__( # Init prefill kv split size when deterministic inference is enabled with various attention backends self.init_deterministic_inference_config() + # Init overlap + self.init_overlap() + # Init request dispatcher self._request_dispatcher = TypeBasedDispatcher( [ @@ -931,6 +944,32 @@ def init_disaggregation(self): # The prefill requests that are in the middle of kv sending self.disagg_prefill_inflight_queue: List[Req] = [] + def init_overlap(self): + if not self.enable_overlap: + return + + self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream() + self.forward_stream_ctx: CudaStreamContext = torch.get_device_module( + self.device + ).stream(self.forward_stream) + self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream() + self.copy_stream_ctx: CudaStreamContext = torch.get_device_module( + self.device + ).stream(self.copy_stream) + + self.future_map = FutureMap(self.max_running_requests, self.device) + self.batch_record_buf = [None] * 2 + self.batch_record_ct = 0 + + def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch): + # FIXME(lsyin): hacky way to keep a reference to avoid GPU tensors being freed by torch GC + # NOTE: More Reliable: record all tensors into the forward stream + # NOTE: - for all future tensors, we shall always read from future map + # - for all non-future tensors 
(produced only by schedule stream), + # we shall keep its reference not being release during all the forwarding pass + self.batch_record_ct = (self.batch_record_ct + 1) % 2 + self.batch_record_buf[self.batch_record_ct] = model_worker_batch + def init_moe_config(self): if hasattr(self.model_config.hf_config, "num_experts_per_tok"): initialize_moe_config(self.server_args) @@ -957,9 +996,11 @@ def event_loop_normal(self): @DynamicGradMode() def event_loop_overlap(self): """A scheduler loop that overlaps the CPU processing and GPU computation.""" - self.result_queue = deque() + self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque() while True: + self.launch_last_batch_sample_if_needed() + recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) @@ -967,7 +1008,6 @@ def event_loop_overlap(self): self.cur_batch = batch if batch: - batch.launch_done = threading.Event() result = self.run_batch(batch) self.result_queue.append((batch.copy(), result)) @@ -979,7 +1019,7 @@ def event_loop_overlap(self): forward_mode=ForwardMode.DUMMY_FIRST, next_batch_sampling_info=self.tp_worker.cur_sampling_info, ) - self.process_batch_result(tmp_batch, None, batch.launch_done) + self.process_batch_result(tmp_batch, None) if self.last_batch: # Process the results of the last batch @@ -987,10 +1027,7 @@ def event_loop_overlap(self): tmp_batch.next_batch_sampling_info = ( self.tp_worker.cur_sampling_info if batch else None ) - # NOTE: we should use current launched batch's launch_done event Instead of the last batch's - self.process_batch_result( - tmp_batch, tmp_result, batch.launch_done if batch else None - ) + self.process_batch_result(tmp_batch, tmp_result) elif batch is None: # When the server is idle, do self-check and re-init some states self.self_check_during_idle() @@ -2055,18 +2092,62 @@ def run_batch( # FIXME(lsyin): remove this if and finally unify the abstraction batch_or_worker_batch = batch.get_model_worker_batch() - forward_batch_output = self.model_worker.forward_batch_generation( - batch_or_worker_batch - ) + if self.enable_overlap: + # FIXME: remove this assert + assert isinstance(batch_or_worker_batch, ModelWorkerBatch) + model_worker_batch = batch_or_worker_batch + self.record_batch_in_overlap(model_worker_batch) + + # Sampling info will be modified during forward + model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = ( + model_worker_batch.sampling_info.copy_for_forward() + ) + + bs = len(model_worker_batch.seq_lens) + cur_future_map_ct = self.future_map.update_ct(bs) + + with self.forward_stream_ctx: + self.forward_stream.wait_stream(self.default_stream) + self.future_map.resolve_future(model_worker_batch) + if batch.sampling_info.grammars is not None: + model_worker_batch.delay_sample_launch = True + batch_result = self.model_worker.forward_batch_generation( + batch_or_worker_batch + ) + # FIXME(lsyin): maybe move this to forward_batch_generation + batch_result.copy_done = torch.get_device_module( + self.device + ).Event() + if not model_worker_batch.delay_sample_launch: + self.future_map.store_to_map( + cur_future_map_ct, bs, batch_result.next_token_ids + ) + batch_result.copy_to_cpu() + else: + batch_result.future_map_ct = cur_future_map_ct + + # FIXME(lsyin): move this assignment elsewhere + maybe_future_next_token_ids = self.future_map.update_next_future( + cur_future_map_ct, bs + ) + else: + batch_result = self.model_worker.forward_batch_generation( + batch_or_worker_batch + ) + maybe_future_next_token_ids = 
batch_result.next_token_ids + copy_done = None if not self.spec_algorithm.is_none(): # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing - self.udpate_spec_metrics( - batch.batch_size(), forward_batch_output.num_accepted_tokens + self.update_spec_metrics( + batch.batch_size(), batch_result.num_accepted_tokens ) - # update batch's output ids - batch.output_ids = forward_batch_output.next_token_ids + # NOTE: maybe_future_next_token_ids is used in ScheduleBatch, + # which can probably be replaced by future_indices later [TODO(lsyin)]. + # we shall still keep the original outputs, e.g. next_token_ids + # in the GenerationBatchOutput for processing after copy_done. + batch.output_ids = maybe_future_next_token_ids # These 2 values are needed for processing the output, but the values can be # modified by overlap schedule. So we have to copy them here so that @@ -2083,36 +2164,60 @@ def run_batch( else: extend_logprob_start_len_per_req = None - return GenerationBatchResult.from_forward_batch_output( - forward_batch_output=forward_batch_output, - extend_input_len_per_req=extend_input_len_per_req, - extend_logprob_start_len_per_req=extend_logprob_start_len_per_req, + batch_result.extend_input_len_per_req = extend_input_len_per_req + batch_result.extend_logprob_start_len_per_req = ( + extend_logprob_start_len_per_req ) + return batch_result else: # embedding or reward model model_worker_batch = batch.get_model_worker_batch() embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch) ret = EmbeddingBatchResult(embeddings=embeddings) return ret + def launch_last_batch_sample_if_needed( + self, + ) -> Union[GenerationBatchResult, EmbeddingBatchResult]: + if len(self.result_queue) == 0: + return + + tmp_batch, tmp_result = self.result_queue.popleft() + + tmp_result: GenerationBatchResult + if not tmp_result.delay_sample_launch: + self.result_queue.appendleft((tmp_batch, tmp_result)) + return + + with self.forward_stream_ctx: + self.forward_stream.wait_stream(self.default_stream) + tmp_result.next_token_ids = self.model_worker.model_runner.sample( + tmp_result.logits_output, + tmp_result.forward_batch, + ) + ct, bs = tmp_result.future_map_ct, len(tmp_batch.reqs) + self.future_map.store_to_map(ct, bs, tmp_result.next_token_ids) + tmp_result.copy_to_cpu() + self.result_queue.appendleft((tmp_batch, tmp_result)) + def process_batch_result( self, batch: ScheduleBatch, result: Union[GenerationBatchResult, EmbeddingBatchResult], - launch_done: Optional[threading.Event] = None, ): if batch.forward_mode.is_decode(): - self.process_batch_result_decode(batch, result, launch_done) + self.process_batch_result_decode(batch, result) if self.enable_trace: trace_slice_batch("decode loop", batch.reqs) elif batch.forward_mode.is_extend(): - self.process_batch_result_prefill(batch, result, launch_done) + self.process_batch_result_prefill(batch, result) if self.enable_trace: trace_slice_batch("prefill", batch.reqs) elif batch.forward_mode.is_idle(): if self.enable_overlap: - self.tp_worker.resolve_last_batch_result(launch_done) + if result.copy_done is not None: + result.copy_done.synchronize() self.set_next_batch_sampling_info_done(batch) elif batch.forward_mode.is_dummy_first(): self.set_next_batch_sampling_info_done(batch) @@ -2329,7 +2434,7 @@ def set_next_batch_sampling_info_done(self, batch: ScheduleBatch): if batch.next_batch_sampling_info: if batch.next_batch_sampling_info.grammars is not None: batch.next_batch_sampling_info.update_regex_vocab_mask() - 
self.current_stream.synchronize() + self.default_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() def watchdog_thread(self): diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index 2af5ab5ab9e..4fa4bfee1dc 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -69,7 +69,7 @@ def init_kv_events(self: Scheduler, kv_events_config: Optional[str]): kv_events_config, self.attn_dp_rank ) - def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int): + def update_spec_metrics(self, bs: int, num_accepted_tokens: int): self.spec_num_total_accepted_tokens += num_accepted_tokens + bs self.spec_num_total_forward_ct += bs self.num_generated_tokens += num_accepted_tokens diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index e307a689950..5a14ba4fae1 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -39,7 +39,6 @@ def process_batch_result_prefill( self: Scheduler, batch: ScheduleBatch, result: Union[GenerationBatchResult, EmbeddingBatchResult], - launch_done: Optional[threading.Event] = None, ): skip_stream_req = None @@ -49,29 +48,29 @@ def process_batch_result_prefill( next_token_ids, extend_input_len_per_req, extend_logprob_start_len_per_req, + copy_done, ) = ( result.logits_output, result.next_token_ids, result.extend_input_len_per_req, result.extend_logprob_start_len_per_req, + result.copy_done, ) - if self.enable_overlap: - logits_output, next_token_ids, _ = ( - self.tp_worker.resolve_last_batch_result(launch_done) - ) - else: - # Move next_token_ids and logprobs to cpu - next_token_ids = next_token_ids.tolist() - if batch.return_logprob: - if logits_output.next_token_logprobs is not None: - logits_output.next_token_logprobs = ( - logits_output.next_token_logprobs.tolist() - ) - if logits_output.input_token_logprobs is not None: - logits_output.input_token_logprobs = tuple( - logits_output.input_token_logprobs.tolist() - ) + if copy_done is not None: + copy_done.synchronize() + + # Move next_token_ids and logprobs to cpu + next_token_ids = next_token_ids.tolist() + if batch.return_logprob: + if logits_output.next_token_logprobs is not None: + logits_output.next_token_logprobs = ( + logits_output.next_token_logprobs.tolist() + ) + if logits_output.input_token_logprobs is not None: + logits_output.input_token_logprobs = tuple( + logits_output.input_token_logprobs.tolist() + ) hidden_state_offset = 0 @@ -204,22 +203,19 @@ def process_batch_result_decode( self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult, - launch_done: Optional[threading.Event] = None, ): - logits_output, next_token_ids, can_run_cuda_graph = ( + logits_output, next_token_ids, can_run_cuda_graph, copy_done = ( result.logits_output, result.next_token_ids, result.can_run_cuda_graph, + result.copy_done, ) self.num_generated_tokens += len(batch.reqs) - if self.enable_overlap: - logits_output, next_token_ids, can_run_cuda_graph = ( - self.tp_worker.resolve_last_batch_result(launch_done) - ) - next_token_logprobs = logits_output.next_token_logprobs - elif batch.spec_algorithm.is_none(): - # spec decoding handles output logprobs inside verify process. 
+ if copy_done is not None: + copy_done.synchronize() + + if batch.spec_algorithm.is_none(): next_token_ids = next_token_ids.tolist() if batch.return_logprob: next_token_logprobs = logits_output.next_token_logprobs.tolist() diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 475305a2fc3..051df74d724 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -15,14 +15,12 @@ from __future__ import annotations import logging -import threading -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional import torch from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_pp_group, get_world_group -from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( DestroyWeightsUpdateGroupReqInput, GetWeightsByNameReqInput, @@ -36,13 +34,10 @@ UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict +from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import ReqToTokenPool -from sglang.srt.model_executor.forward_batch_info import ( - ForwardBatch, - ForwardBatchOutput, - PPProxyTensors, -) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed @@ -236,9 +231,8 @@ def get_memory_pool(self): def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch, - launch_done: Optional[threading.Event] = None, is_verify: bool = False, - ) -> ForwardBatchOutput: + ) -> GenerationBatchResult: # update the consumer index of hicache to the running batch self.set_hicache_consumer(model_worker_batch.hicache_consumer_index) @@ -256,32 +250,43 @@ def forward_batch_generation( logits_output, can_run_cuda_graph = self.model_runner.forward( forward_batch, pp_proxy_tensors=pp_proxy_tensors ) - if launch_done is not None: - launch_done.set() - - skip_sample = is_verify or model_worker_batch.is_prefill_only - next_token_ids = None - - if not skip_sample: - next_token_ids = self.model_runner.sample(logits_output, forward_batch) - elif model_worker_batch.return_logprob and not is_verify: - # NOTE: Compute logprobs without full sampling - self.model_runner.compute_logprobs_only( - logits_output, model_worker_batch - ) - - return ForwardBatchOutput( + batch_result = GenerationBatchResult( logits_output=logits_output, - next_token_ids=next_token_ids, can_run_cuda_graph=can_run_cuda_graph, ) + + if is_verify: + # Skip sampling and return logits for target forward + return batch_result + + if model_worker_batch.delay_sample_launch: + batch_result.delay_sample_launch = True + batch_result.forward_batch = forward_batch + return batch_result + + if model_worker_batch.is_prefill_only: + # For prefill-only requests, create dummy token IDs on CPU + batch_result.next_token_ids = torch.zeros_like( + model_worker_batch.input_ids, dtype=torch.long + ) + if model_worker_batch.return_logprob: + # NOTE: Compute logprobs without full sampling + self.model_runner.compute_logprobs_only( + logits_output, model_worker_batch + ) + else: + batch_result.next_token_ids = self.model_runner.sample( + logits_output, 
forward_batch + ) + + return batch_result else: pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward( forward_batch, pp_proxy_tensors=pp_proxy_tensors, ) - return ForwardBatchOutput( - pp_proxy_tensors=pp_proxy_tensors, + return GenerationBatchResult( + pp_hidden_states_proxy_tensors=pp_proxy_tensors, can_run_cuda_graph=can_run_cuda_graph, ) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 3e024036103..3491dce7d5e 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -232,12 +232,8 @@ def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch ) -> ForwardBatchOutput: # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch. - sampling_info = model_worker_batch.sampling_info - sampling_info.update_penalties() - model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace( - sampling_info, - sampling_info_done=threading.Event(), - penalizer_orchestrator=None, + model_worker_batch.sampling_info = self.cur_sampling_info = ( + model_worker_batch.sampling_info.copy_for_forward() ) # A cuda stream sync here to avoid the cuda illegal memory access error. diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 4309f52118e..a9a29fbc734 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -902,17 +902,6 @@ def can_run_tbo(self): return self.tbo_split_seq_index is not None -@dataclass -class ForwardBatchOutput: - # FIXME(lsyin): unify the forward batch output between different spec and parallelism - # need to be more organized - logits_output: Optional[torch.Tensor] = None - next_token_ids: Optional[torch.Tensor] = None - num_accepted_tokens: Optional[int] = None - pp_proxy_tensors: Optional[PPProxyTensors] = None - can_run_cuda_graph: bool = False - - def enable_num_token_non_padded(server_args): return get_moe_expert_parallel_world_size() > 1 diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 8d3e48bc245..d246ac3c34c 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -370,6 +370,15 @@ def merge_batch(self, other: "SamplingBatchInfo"): self.need_top_k_sampling |= other.need_top_k_sampling self.need_min_p_sampling |= other.need_min_p_sampling + def copy_for_forward(self): + # Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later + self.update_penalties() + return dataclasses.replace( + self, + sampling_info_done=threading.Event(), + penalizer_orchestrator=None, + ) + def merge_bias_tensor( lhs: Optional[torch.Tensor], diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 08d659d02d0..d23f9e3b0ee 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -19,11 +19,11 @@ get_last_loc, global_server_args_dict, ) +from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, - ForwardBatchOutput, ForwardMode, ) from sglang.srt.server_args import 
ServerArgs @@ -429,7 +429,7 @@ def init_cuda_graphs(self): def draft_model_runner(self): return self.model_runner - def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput: + def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult: """Run speculative decoding forward. NOTE: Many states of batch is modified as you go through. It is not guaranteed that @@ -449,7 +449,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput: self.forward_draft_extend( batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu ) - return ForwardBatchOutput( + return GenerationBatchResult( logits_output=logits_output, next_token_ids=next_token_ids, num_accepted_tokens=0, @@ -472,7 +472,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput: # decode is not finished self.forward_draft_extend_after_decode(batch) - return ForwardBatchOutput( + return GenerationBatchResult( logits_output=logits_output, next_token_ids=verify_output.verified_id, num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu), @@ -513,12 +513,10 @@ def forward_target_extend( # We need the full hidden states to prefill the KV cache of the draft model. model_worker_batch = batch.get_model_worker_batch() model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL - forward_batch_output = self.target_worker.forward_batch_generation( - model_worker_batch - ) + batch_result = self.target_worker.forward_batch_generation(model_worker_batch) logits_output, next_token_ids = ( - forward_batch_output.logits_output, - forward_batch_output.next_token_ids, + batch_result.logits_output, + batch_result.next_token_ids, ) return ( logits_output, @@ -822,12 +820,12 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): ).cpu() # Forward - forward_batch_output = self.target_worker.forward_batch_generation( + batch_result = self.target_worker.forward_batch_generation( model_worker_batch, is_verify=True ) logits_output, can_run_cuda_graph = ( - forward_batch_output.logits_output, - forward_batch_output.can_run_cuda_graph, + batch_result.logits_output, + batch_result.can_run_cuda_graph, ) vocab_mask = None diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/ngram_worker.py index 97aa620ceb3..d2197023d09 100644 --- a/python/sglang/srt/speculative/ngram_worker.py +++ b/python/sglang/srt/speculative/ngram_worker.py @@ -6,8 +6,9 @@ from sgl_kernel.speculative import reconstruct_indices_from_tree_mask from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.tp_worker import TpModelWorker -from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode +from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache from sglang.srt.speculative.ngram_utils import NgramVerifyInput @@ -207,18 +208,18 @@ def _update_ngram_cache(self, batch: ScheduleBatch): batch_tokens.append(put_ids) self.ngram_cache.batch_put(batch_tokens) - def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput: + def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult: self._prepare_for_speculative_decoding(batch) model_worker_batch = batch.get_model_worker_batch() num_accepted_tokens = 0 if model_worker_batch.forward_mode.is_target_verify(): - 
forward_batch_output = self.target_worker.forward_batch_generation( + batch_result = self.target_worker.forward_batch_generation( model_worker_batch, is_verify=True ) logits_output, can_run_cuda_graph = ( - forward_batch_output.logits_output, - forward_batch_output.can_run_cuda_graph, + batch_result.logits_output, + batch_result.can_run_cuda_graph, ) verify_input = model_worker_batch.spec_info logits_output, next_token_ids, num_accepted_tokens = verify_input.verify( @@ -228,16 +229,16 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput: batch.forward_mode = ForwardMode.DECODE else: - forward_batch_output = self.target_worker.forward_batch_generation( + batch_result = self.target_worker.forward_batch_generation( model_worker_batch ) logits_output, next_token_ids, can_run_cuda_graph = ( - forward_batch_output.logits_output, - forward_batch_output.next_token_ids, - forward_batch_output.can_run_cuda_graph, + batch_result.logits_output, + batch_result.next_token_ids, + batch_result.can_run_cuda_graph, ) - return ForwardBatchOutput( + return GenerationBatchResult( logits_output=logits_output, next_token_ids=next_token_ids, num_accepted_tokens=num_accepted_tokens, diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index e6497a52a96..8f84d63f64f 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1160,7 +1160,7 @@ def run_bench_offline_throughput(model, other_args): *[str(x) for x in other_args], ] - print(f"{command=}") + print(f"command={' '.join(command)}") process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) try:
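
For context on the overlap-scheduling changes above: the core idea is that the scheduler can build batch N+1 before the GPU has finished sampling batch N, by placing placeholder "future" token ids into the next batch's input_ids and resolving them against a shared buffer right before the next forward pass (see FutureMap.resolve_future / store_to_map / update_next_future in overlap_utils.py, and launch_last_batch_sample_if_needed in scheduler.py). The snippet below is a minimal, self-contained sketch of that pattern only; the class name, method names, and buffer layout are simplified stand-ins and do not match the real FutureMap API or the CUDA-stream/copy_done handling in the diff.

# Illustrative sketch (not part of the diff): placeholder "future" token ids
# let the scheduler stage the next batch before sampling has completed.
import torch


class FutureMapSketch:
    def __init__(self, max_slots: int, device: str = "cpu"):
        self.ct = 0
        # Slot 0 is unused so that negative ids map to indices 1..max_slots.
        self.buf = torch.zeros(max_slots + 1, dtype=torch.int64, device=device)

    def reserve(self, bs: int) -> torch.Tensor:
        # Hand out negative "future ids" that stand in for tokens the GPU
        # has not sampled yet.
        start = self.ct + 1
        self.ct += bs
        return -torch.arange(start, start + bs, device=self.buf.device)

    def store(self, future_ids: torch.Tensor, token_ids: torch.Tensor):
        # Called once sampling for the earlier batch has produced real ids.
        self.buf[-future_ids] = token_ids

    def resolve(self, input_ids: torch.Tensor):
        # Replace any negative placeholder in-place with the sampled token id.
        mask = input_ids < 0
        input_ids[mask] = self.buf[-input_ids[mask]]


if __name__ == "__main__":
    fm = FutureMapSketch(max_slots=64)
    bs = 4

    # Batch N is launched; placeholders for its outputs are reserved immediately.
    futures = fm.reserve(bs)

    # Batch N+1 is scheduled using the placeholders as its input ids.
    next_input_ids = futures.clone()

    # ... later, sampling for batch N finishes ...
    sampled = torch.randint(0, 32000, (bs,))
    fm.store(futures, sampled)

    # Just before running batch N+1, the placeholders are resolved in place.
    fm.resolve(next_input_ids)
    assert torch.equal(next_input_ids, sampled)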