diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index dfe7f8ab0a..3e66427920 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -187,7 +187,6 @@ def __init__(
         self.redundant_experts_num = 0
         self.seed = 0
         self.quantization = None
-        self.reasoning_parser = None
         self.pad_token_id: int = -1
         self.eos_tokens_lens: int = 2
         self.lm_head_fp32: bool = False
@@ -555,10 +554,6 @@ def __init__(
         # Do profile or not
         self.do_profile: bool = False
 
-        # guided decoding backend
-        self.guided_decoding_backend: str = None
-        # disable any whitespace for guided decoding
-        self.disable_any_whitespace: bool = True
         self.pod_ip: str = None
         # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
         self.disable_custom_all_reduce: bool = False
@@ -1143,12 +1138,6 @@ class PoolerConfig:
     """
 
 
-class LoRAConfig:
-    """LoRA Config"""
-
-    pass
-
-
 class CacheConfig:
     """
     Configuration for the KV cache.
@@ -1379,6 +1368,25 @@ def print(self):
         logger.info("=============================================================")
 
 
+class StructuredOutputsConfig:
+    """
+    Configuration for structured outputs
+    """
+
+    def __init__(
+        self,
+        args,
+    ) -> None:
+        self.reasoning_parser: Optional[str] = None
+        self.guided_decoding_backend: Optional[str] = None
+        # disable any whitespace for guided decoding
+        self.disable_any_whitespace: bool = True
+
+        for key, value in args.items():
+            if hasattr(self, key) and value != "None":
+                setattr(self, key, value)
+
+
 class FDConfig:
     """
     The configuration class which contains all fastdeploy-related configuration. This
@@ -1399,6 +1407,7 @@ def __init__(
         graph_opt_config: GraphOptimizationConfig = None,
         plas_attention_config: PlasAttentionConfig = None,
         speculative_config: SpeculativeConfig = None,
+        structured_outputs_config: StructuredOutputsConfig = None,
         tokenizer: str = None,
         ips: str = None,
         use_warmup: bool = False,
@@ -1408,9 +1417,6 @@ def __init__(
         max_num_partial_prefills: int = 1,
         max_long_partial_prefills: int = 1,
         long_prefill_token_threshold: int = 0,
-        reasoning_parser: str = None,
-        guided_decoding_backend: Optional[str] = None,
-        disable_any_whitespace: bool = False,
         early_stop_config: Optional[Dict[str, Any]] = None,
         tool_parser: str = None,
         test_mode=False,
@@ -1428,6 +1434,7 @@ def __init__(
         self.decoding_config: DecodingConfig = decoding_config  # type: ignore
         self.cache_config: CacheConfig = cache_config  # type: ignore
         self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
+        self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
         # Initialize cuda graph capture list
         if self.graph_opt_config.cudagraph_capture_sizes is None:
             self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.scheduler_config.max_num_seqs)
@@ -1474,9 +1481,7 @@ def __init__(
         self.max_num_partial_prefills = max_num_partial_prefills
         self.max_long_partial_prefills = max_long_partial_prefills
         self.long_prefill_token_threshold = long_prefill_token_threshold
-        self.reasoning_parser = reasoning_parser
-        self.guided_decoding_backend = guided_decoding_backend
-        self.disable_any_whitespace = disable_any_whitespace
+
         self._str_to_list("innode_prefill_ports", int)
 
         if envs.FD_FOR_TORCH_MODEL_FORMAT:
@@ -1498,12 +1503,12 @@ def __init__(
         else:
             self.worker_num_per_node = num_ranks
 
-        self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
-        self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
+        self.parallel_config.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
+        self.parallel_config.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.parallel_config.device_ids)
         if current_platform.is_xpu():
-            self.device_ids = os.getenv("XPU_VISIBLE_DEVICES", self.device_ids)
+            self.parallel_config.device_ids = os.getenv("XPU_VISIBLE_DEVICES", self.parallel_config.device_ids)
         if current_platform.is_intel_hpu():
-            self.device_ids = os.getenv("HPU_VISIBLE_DEVICES", self.device_ids)
+            self.parallel_config.device_ids = os.getenv("HPU_VISIBLE_DEVICES", self.parallel_config.device_ids)
 
         self.read_from_config()
         self.postprocess()
@@ -1516,7 +1521,7 @@ def postprocess(self):
         """
         calculate some parameters
         """
-        self.local_device_ids = self.device_ids.split(",")[: self.parallel_config.tensor_parallel_size]
+        self.local_device_ids = self.parallel_config.device_ids.split(",")[: self.parallel_config.tensor_parallel_size]
 
         if self.parallel_config.tensor_parallel_size <= self.worker_num_per_node or self.node_rank == 0:
             self.is_master = True
@@ -1547,12 +1552,15 @@ def postprocess(self):
         if self.model_config is not None and self.model_config.enable_mm:
             self.cache_config.enable_prefix_caching = False
 
-        if self.guided_decoding_backend == "auto":
+        if (
+            self.structured_outputs_config is not None
+            and self.structured_outputs_config.guided_decoding_backend == "auto"
+        ):
             if current_platform.is_xpu() or self.speculative_config.method is not None:
                 logger.warning("Speculative Decoding and XPU currently do not support Guided decoding, set off.")
-                self.guided_decoding_backend = "off"
+                self.structured_outputs_config.guided_decoding_backend = "off"
             else:
-                self.guided_decoding_backend = "xgrammar"
+                self.structured_outputs_config.guided_decoding_backend = "xgrammar"
 
         if self.scheduler_config.splitwise_role == "mixed":
             self.model_config.moe_phase = MoEPhase(phase="prefill")
@@ -1627,15 +1635,18 @@ def check(self):
             f" max_model_len: {self.model_config.max_model_len}"
         )
 
-        if self.guided_decoding_backend is not None:
-            assert self.guided_decoding_backend in [
+        if (
+            self.structured_outputs_config is not None
+            and self.structured_outputs_config.guided_decoding_backend is not None
+        ):
+            assert self.structured_outputs_config.guided_decoding_backend in [
                 "xgrammar",
                 "XGrammar",
                 "auto",
                 "off",
-            ], f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."
+            ], f"Only support xgrammar、auto guided decoding backend, but got {self.structured_outputs_config.guided_decoding_backend}."
 
-            if self.guided_decoding_backend != "off":
+            if self.structured_outputs_config.guided_decoding_backend != "off":
                 # TODO: speculative decoding support guided_decoding
                 assert (
                     self.speculative_config.method is None
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 123b33cfde..49d7394345 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -36,6 +36,7 @@
     PoolerConfig,
     RunnerOption,
     SpeculativeConfig,
+    StructuredOutputsConfig,
     TaskOption,
 )
 from fastdeploy.platforms import current_platform
@@ -1063,7 +1064,7 @@ def create_engine_config(self, port_availability_check=True) -> FDConfig:
 
         early_stop_cfg = self.create_early_stop_config()
         early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
-
+        structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=all_dict)
         if port_availability_check:
             assert is_port_available(
                 "0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id])
@@ -1077,11 +1078,11 @@ def create_engine_config(self, port_availability_check=True) -> FDConfig:
             load_config=load_cfg,
             parallel_config=parallel_cfg,
             speculative_config=speculative_cfg,
+            structured_outputs_config=structured_outputs_config,
             ips=self.ips,
             use_warmup=self.use_warmup,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             mm_processor_kwargs=self.mm_processor_kwargs,
-            reasoning_parser=self.reasoning_parser,
             tool_parser=self.tool_call_parser,
             innode_prefill_ports=self.innode_prefill_ports,
             max_num_partial_prefills=self.max_num_partial_prefills,
@@ -1089,7 +1090,5 @@ def create_engine_config(self, port_availability_check=True) -> FDConfig:
             long_prefill_token_threshold=self.long_prefill_token_threshold,
             graph_opt_config=graph_opt_cfg,
             plas_attention_config=plas_attention_config,
-            guided_decoding_backend=self.guided_decoding_backend,
-            disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
             early_stop_config=early_stop_cfg,
         )
diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index 3be61af36f..2555925196 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -127,10 +127,10 @@ def __init__(self, cfg, start_queue=True):
         )
 
         self.guided_decoding_checker = None
-        if self.cfg.guided_decoding_backend != "off":
+        if self.cfg.structured_outputs_config.guided_decoding_backend != "off":
             self.guided_decoding_checker = schema_checker(
-                self.cfg.guided_decoding_backend,
-                disable_any_whitespace=self.cfg.disable_any_whitespace,
+                self.cfg.structured_outputs_config.guided_decoding_backend,
+                disable_any_whitespace=self.cfg.structured_outputs_config.disable_any_whitespace,
             )
 
         self._init_worker_monitor_signals()
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index ca75456670..4dec5b982b 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -90,7 +90,7 @@ def __init__(self, cfg):
 
         self.input_processor = InputPreprocessor(
             cfg.tokenizer,
-            cfg.reasoning_parser,
+            cfg.structured_outputs_config.reasoning_parser,
             cfg.limit_mm_per_prompt,
             cfg.mm_processor_kwargs,
             cfg.model_config.enable_mm,
@@ -128,7 +128,7 @@ def start(self, api_server_pid=None):
 
         # If block numer is specified and model is deployed in mixed mode, start cache manager first
         if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed":
-            device_ids = self.cfg.device_ids.split(",")
+            device_ids = self.cfg.parallel_config.device_ids.split(",")
             self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix, True)
 
         # Start workers
@@ -162,7 +162,7 @@ def check_worker_initialize_status_func(res: dict):
             if self.do_profile:
                 self._stop_profile()
             elif self.cfg.cache_config.enable_prefix_caching:
-                device_ids = self.cfg.device_ids.split(",")
+                device_ids = self.cfg.parallel_config.device_ids.split(",")
                 self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix, False)
 
         # Launch components: scheduler, cache_manager, expert_service et.al.
@@ -426,7 +426,7 @@ def _setting_environ_variables(self):
         """
         variables = {
             "ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY": 0,
-            "LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.device_ids.split(",")),
+            "LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.parallel_config.device_ids.split(",")),
             "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
             "FLAGS_use_append_attn": 1,
             "NCCL_ALGO": "Ring",
@@ -503,11 +503,11 @@ def _start_worker_service(self):
         if self.cfg.ips is not None:
             ips = ",".join(self.cfg.ips)
         arguments = (
-            f" --devices {self.cfg.device_ids} {py_script}"
+            f" --devices {self.cfg.parallel_config.device_ids} {py_script}"
             f" --max_num_seqs {self.cfg.scheduler_config.max_num_seqs} --max_model_len {self.cfg.model_config.max_model_len}"
             f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}"
             f" --model {self.cfg.model_config.model!s}"
-            f" --device_ids {self.cfg.device_ids}"
+            f" --device_ids {self.cfg.parallel_config.device_ids}"
             f" --tensor_parallel_size {self.cfg.parallel_config.tensor_parallel_size}"
             f" --engine_worker_queue_port {ports}"
             f" --pod_ip {self.cfg.master_ip}"
@@ -527,10 +527,10 @@ def _start_worker_service(self):
             f" --think_end_id {self.cfg.model_config.think_end_id}"
             f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
             f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
-            f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
+            f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
             f" --load_strategy {self.cfg.load_config.load_strategy}"
             f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
-            f" --reasoning_parser {self.cfg.reasoning_parser}"
+            f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
             f" --load_choices {self.cfg.load_config.load_choices}"
             f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
             f" --ips {ips}"
@@ -546,7 +546,7 @@ def _start_worker_service(self):
             "enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill,
             "do_profile": self.do_profile,
             "dynamic_load_weight": self.cfg.load_config.dynamic_load_weight,
-            "disable_any_whitespace": self.cfg.disable_any_whitespace,
+            "disable_any_whitespace": self.cfg.structured_outputs_config.disable_any_whitespace,
             "disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce,
             "enable_logprob": self.cfg.model_config.enable_logprob,
             "lm_head_fp32": self.cfg.model_config.lm_head_fp32,
@@ -643,7 +643,7 @@ def _stop_profile(self):
         self.cfg.cache_config.reset(num_gpu_blocks)
         self.engine.resource_manager.reset_cache_config(self.cfg.cache_config)
         if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
-            device_ids = self.cfg.device_ids.split(",")
+            device_ids = self.cfg.parallel_config.device_ids.split(",")
             self.cache_manager_processes = self.engine.start_cache_service(
                 device_ids, self.ipc_signal_suffix, self.cfg.scheduler_config.splitwise_role != "mixed"
             )
diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py
index f3573d2237..dee8a4323b 100644
--- a/fastdeploy/engine/expert_service.py
+++ b/fastdeploy/engine/expert_service.py
@@ -53,7 +53,7 @@ def __init__(self, cfg, local_data_parallel_id, start_queue=True):
         end_pos = start_pos + self.cfg.parallel_config.tensor_parallel_size
         if cfg.scheduler_config.splitwise_role != "mixed":
             self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
-        self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
+        self.cfg.local_device_ids = self.cfg.parallel_config.device_ids.split(",")[start_pos:end_pos]
 
         llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")
         self.cfg.disaggregate_info = None
diff --git a/fastdeploy/model_executor/guided_decoding/__init__.py b/fastdeploy/model_executor/guided_decoding/__init__.py
index 9336f4a04e..dbfc70215d 100644
--- a/fastdeploy/model_executor/guided_decoding/__init__.py
+++ b/fastdeploy/model_executor/guided_decoding/__init__.py
@@ -41,7 +41,7 @@ def get_guided_backend(
     Raises:
         ValueError: If the specified backend is not supported
     """
-    if fd_config.parallel_config.guided_decoding_backend.lower() == "xgrammar":
+    if fd_config.structured_outputs_config.guided_decoding_backend.lower() == "xgrammar":
         from fastdeploy.model_executor.guided_decoding.xgrammar_backend import (
             XGrammarBackend,
         )
@@ -52,7 +52,7 @@ def get_guided_backend(
         )
     else:
         raise ValueError(
-            f"Get unsupported backend {fd_config.parallel_config.guided_decoding_backend},"
+            f"Get unsupported backend {fd_config.structured_outputs_config.guided_decoding_backend},"
             f" please check your configuration."
         )
 
diff --git a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py
index b9a879e32d..57fccc3fe8 100644
--- a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py
+++ b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py
@@ -142,9 +142,9 @@ def __init__(self, fd_config: FDConfig):
         self.reasoning_parser = None
         self.hf_tokenizer = self._get_tokenizer_hf()
 
-        if self.fd_config.model_config.reasoning_parser:
+        if self.fd_config.structured_outputs_config.reasoning_parser:
             reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(
-                self.fd_config.model_config.reasoning_parser
+                self.fd_config.structured_outputs_config.reasoning_parser
             )
             self.reasoning_parser = reasoning_parser_obj(self.hf_tokenizer)
 
diff --git a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
index 4a72ccf3e7..6681bf95f8 100644
--- a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
+++ b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
@@ -212,7 +212,7 @@ def __init__(
 
         self.vocab_size = fd_config.model_config.vocab_size
         self.batch_size = fd_config.scheduler_config.max_num_seqs
-        self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace
+        self.any_whitespace = not fd_config.structured_outputs_config.disable_any_whitespace
 
         try:
             tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py
index 62d33f433f..01d1c50c0f 100644
--- a/fastdeploy/splitwise/splitwise_connector.py
+++ b/fastdeploy/splitwise/splitwise_connector.py
@@ -375,7 +375,7 @@ def send_cache_infos(self, tasks, current_id):
                 if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
                     cache_info = {
                         "request_id": tasks[i].request_id,
-                        "device_ids": self.cfg.device_ids.split(","),
+                        "device_ids": self.cfg.parallel_config.device_ids.split(","),
                         "transfer_protocol": "ipc",
                         "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
                     }
@@ -395,7 +395,7 @@ def send_cache_infos(self, tasks, current_id):
                 else:
                     cache_info = {
                         "request_id": tasks[i].request_id,
-                        "device_ids": self.cfg.device_ids.split(","),
+                        "device_ids": self.cfg.parallel_config.device_ids.split(","),
                         "ip": self.cfg.host_ip,
                         "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
                         "transfer_protocol": "rdma",
diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py
index b8351625a1..dd77fcbbf7 100644
--- a/fastdeploy/worker/gcu_model_runner.py
+++ b/fastdeploy/worker/gcu_model_runner.py
@@ -73,7 +73,7 @@ def __init__(
         self.enable_logprob = fd_config.model_config.enable_logprob
 
         self.guided_backend = None
-        if self.fd_config.parallel_config.guided_decoding_backend != "off":
+        if self.fd_config.structured_outputs_config.guided_decoding_backend != "off":
             self.guided_backend = get_guided_backend(fd_config=self.fd_config)
 
         # Sampler
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index bd39d6efb5..d872ff63c3 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -132,7 +132,7 @@ def __init__(
             self.sampler = SpeculativeSampler(fd_config)
 
         self.guided_backend = None
-        if self.fd_config.parallel_config.guided_decoding_backend != "off":
+        if self.fd_config.structured_outputs_config.guided_decoding_backend != "off":
             self.guided_backend = get_guided_backend(fd_config=self.fd_config)
             self.sampler.set_reasoning_parser(self.guided_backend.get_reasoning_parser())
 
diff --git a/fastdeploy/worker/hpu_model_runner.py b/fastdeploy/worker/hpu_model_runner.py
index 56f84fd86d..c3cc776584 100644
--- a/fastdeploy/worker/hpu_model_runner.py
+++ b/fastdeploy/worker/hpu_model_runner.py
@@ -318,7 +318,7 @@ def __init__(
         self.speculative_decoding = self.speculative_method is not None
 
         self.guided_backend = None
-        if self.fd_config.parallel_config.guided_decoding_backend != "off":
+        if self.fd_config.structured_outputs_config.guided_decoding_backend != "off":
             self.guided_backend = get_guided_backend(fd_config=self.fd_config)
 
         # Sampler
diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py
index dcce154ea5..67ac10047d 100644
--- a/fastdeploy/worker/metax_model_runner.py
+++ b/fastdeploy/worker/metax_model_runner.py
@@ -81,7 +81,7 @@ def __init__(
         self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
 
         self.guided_backend = None
-        if self.fd_config.parallel_config.guided_decoding_backend != "off":
+        if self.fd_config.structured_outputs_config.guided_decoding_backend != "off":
             self.guided_backend = get_guided_backend(fd_config=self.fd_config)
 
         # VL model config:
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 0f27fde5cb..19b9bed3ff 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -39,6 +39,7 @@
     ParallelConfig,
     PlasAttentionConfig,
     SpeculativeConfig,
+    StructuredOutputsConfig,
 )
 from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
 from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
@@ -744,6 +745,8 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
 
     early_stop_config = EarlyStopConfig(args.early_stop_config)
 
+    structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=vars(args))
+
     # Note(tangbinhan): used for load_checkpoint
     model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
     model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size
@@ -792,7 +795,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     if not current_platform.is_cuda() and not current_platform.is_xpu():
         logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.")
         envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
-    if parallel_config.guided_decoding_backend != "off":
+    if structured_outputs_config.guided_decoding_backend != "off":
         logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported guided_decoding.")
         envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
 
@@ -813,6 +816,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         scheduler_config=scheduler_config,
         ips=args.ips,
         plas_attention_config=plas_attention_config,
+        structured_outputs_config=structured_outputs_config,
     )
     update_fd_config_for_mm(fd_config)
    if fd_config.load_config.load_choices == "default_v1" and not v1_loader_support(fd_config):
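Usage note (illustrative, not part of the patch): the new StructuredOutputsConfig only copies keys that already exist as attributes, so both construction paths above — StructuredOutputsConfig(args=all_dict) in the engine and StructuredOutputsConfig(args=vars(args)) in the worker — end up with the same three fields. A minimal, self-contained sketch that mirrors the constructor added in fastdeploy/config.py; the argument values below are hypothetical.

# Sketch only: mirrors the StructuredOutputsConfig constructor introduced by this patch.
from typing import Optional


class StructuredOutputsConfig:
    """Configuration for structured outputs (fields as introduced by the patch)."""

    def __init__(self, args) -> None:
        self.reasoning_parser: Optional[str] = None
        self.guided_decoding_backend: Optional[str] = None
        # disable any whitespace for guided decoding
        self.disable_any_whitespace: bool = True

        # Copy only keys that already exist as attributes; the literal string "None"
        # (e.g. an unset CLI option rendered as text) is ignored.
        for key, value in args.items():
            if hasattr(self, key) and value != "None":
                setattr(self, key, value)


# Hypothetical args dict; unrelated keys such as "max_num_seqs" are silently skipped.
cfg = StructuredOutputsConfig(args={"guided_decoding_backend": "auto", "max_num_seqs": 256})
assert cfg.guided_decoding_backend == "auto"
assert cfg.reasoning_parser is None and cfg.disable_any_whitespace is True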