                       mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import KvCacheConnectorConfig, PybindMirror, TorchLlmArgs
+from ..llmapi.llm_args import (BaseLlmArgs, KvCacheConnectorConfig,
+                               PybindMirror, TorchLlmArgs)
 from ..llmapi.mpi_session import set_mpi_session_cpp
 from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
@@ -64,7 +65,7 @@ def __init__(
         kv_connector_config: Optional[KvCacheConnectorConfig] = None,
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -107,40 +108,55 @@ def _get_comm_ranks_device_id():
             device_ids = mpi_comm().allgather(device_id)
             return comm_ranks, device_ids

-        def _create_py_executor(executor_config):
-            assert executor_config is None, "expect an empty executor_config is _create_py_executor"
-            executor_config = llm_args.get_executor_config(
-                hf_model_dir, tokenizer)
-            # Persist so downstream code (e.g., default max_tokens deduction) has access
-            self._executor_config = executor_config
-            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-                processor_batched=batched_logits_processor, replicate=False)
-            comm_ranks, device_ids = _get_comm_ranks_device_id()
-            executor_config.parallel_config = tllm.ParallelConfig(
-                participant_ids=comm_ranks, device_ids=device_ids)
-            args = {
-                "executor_config": executor_config,
-                "checkpoint_dir": executor_config.hf_model_dir,
-            }
+        def _create_py_executor():
+            args = {}
             assert hasattr(
-                executor_config, "backend"
-            ), "executor_config should be with backend in _create_py_executor"
-            if executor_config.backend == "pytorch":
+                self.llm_args, "backend"
+            ), "llm_args should be with backend in _create_py_executor"
+            if self.llm_args.backend == "pytorch":
                 from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                     create_py_executor
                 create_executor = create_py_executor
+                args["llm_args"] = self.llm_args
+                args["checkpoint_dir"] = hf_model_dir
+                args["tokenizer"] = tokenizer
                 args["lora_config"] = lora_config
-                args[
-                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
                 args["kv_connector_config"] = kv_connector_config
-            elif executor_config.backend == "_autodeploy":
+                args[
+                    "logits_post_processor_config"] = tllm.LogitsPostProcessorConfig(
+                        processor_batched=batched_logits_processor,
+                        replicate=False)
+                comm_ranks, device_ids = _get_comm_ranks_device_id()
+                args["parallel_config"] = tllm.ParallelConfig(
+                    participant_ids=comm_ranks, device_ids=device_ids)
+            elif self.llm_args.backend == "_autodeploy":
+                from tensorrt_llm._torch.auto_deploy.llm_args import \
+                    LlmArgs as ADLlmArgs
                 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                     create_autodeploy_executor
                 create_executor = create_autodeploy_executor
+                assert isinstance(self.llm_args, ADLlmArgs)
+                args["ad_config"] = self.llm_args.get_pytorch_backend_config()
             else:
                 raise ValueError(
-                    f"Unsupported backend config: {executor_config.backend}")
-            return create_executor(**args)
+                    f"Unsupported backend config: {self.llm_args.backend}")
+
+            # Define additional attributes that can be used later, such as in _deduce_max_tokens
+            self.mapping = self.llm_args.parallel_config.to_mapping()
+            self.checkpoint_loader = None
+            if self.llm_args.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.config import \
+                    _construct_checkpoint_loader
+                self.checkpoint_loader = _construct_checkpoint_loader(
+                    self.llm_args.backend, self.llm_args.checkpoint_loader,
+                    self.llm_args.checkpoint_format)
+
+            _executor = create_executor(**args)
+            self.max_seq_len = self.llm_args.max_seq_len
+            if _executor.max_seq_len is not None:
+                # max_seq_len might be updated by model engine as in create_py_executor
+                self.max_seq_len = _executor.max_seq_len
+            return _executor

         def _create_engine(executor_config):
             if executor_config is None:
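Note: `_create_py_executor` no longer receives a prebuilt `executor_config`; it assembles per-backend keyword arguments directly from `self.llm_args` and dispatches on `llm_args.backend`. A minimal sketch of that dispatch, using hypothetical stand-in factories in place of the imported `create_py_executor` / `create_autodeploy_executor`:

# Sketch only (not the PR's code): the dispatch shape of the new _create_py_executor.
# make_pytorch_executor / make_autodeploy_executor are hypothetical stand-ins for the
# factories imported in the diff above.
from typing import Any, Callable


def build_executor(llm_args: Any,
                   make_pytorch_executor: Callable[..., Any],
                   make_autodeploy_executor: Callable[..., Any],
                   **pytorch_kwargs: Any) -> Any:
    # Every supported llm_args object is expected to expose a `backend` field.
    assert hasattr(llm_args, "backend"), "llm_args should carry a backend"
    if llm_args.backend == "pytorch":
        # The PyTorch path gets llm_args plus checkpoint/tokenizer/parallel settings.
        return make_pytorch_executor(llm_args=llm_args, **pytorch_kwargs)
    if llm_args.backend == "_autodeploy":
        # The AutoDeploy path is driven by a backend config derived from llm_args.
        return make_autodeploy_executor(ad_config=llm_args.get_pytorch_backend_config())
    raise ValueError(f"Unsupported backend config: {llm_args.backend}")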
@@ -164,8 +180,7 @@ def _create_engine(executor_config):
164180                                 executor_config )
165181
166182        self .engine  =  _create_py_executor (
167-             executor_config ) if  llm_args  is  not None  else  _create_engine (
168-                 executor_config )
183+         ) if  self .llm_args  is  not None  else  _create_engine (executor_config )
169184
170185        self ._lora_manager : Optional [LoraManager ] =  None 
171186        self ._prompt_adapter_manager : Optional [PromptAdapterManager ] =  None 
@@ -188,8 +203,9 @@ def _create_engine(executor_config):
             if engine_config.build_config.max_prompt_embedding_table_size > 0:
                 self._prompt_adapter_manager = PromptAdapterManager()

-        if getattr(self._executor_config, "backend",
-                   "") == "pytorch" and lora_config is not None:
+        if self.llm_args and getattr(
+                self.llm_args, "backend",
+                "") == "pytorch" and lora_config is not None:
             from tensorrt_llm._torch.pyexecutor.resource_manager import \
                 ResourceManagerType
             peft_cache_manager = self.engine.resource_manager.resource_managers.get(
@@ -471,26 +487,43 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
         assert request.id is not None

         def _deduce_max_tokens(request: GenerationRequest,
-                               executor_config: tllm.ExecutorConfig) -> int:
+                               executor_config: tllm.ExecutorConfig,
+                               llm_args: Optional[BaseLlmArgs] = None) -> int:
             # deduce max_tokens when it's not set by user
             max_tokens = request.sampling_params.max_tokens
             query_token_len = len(
                 request.query_token_ids) if request.query_token_ids else 0
-            cp_size = 1 if (not hasattr(executor_config, "mapping")
-                            or executor_config.mapping.cp_size
-                            is None) else executor_config.mapping.cp_size
-            if not hasattr(executor_config, "max_seq_len"):
+
+            cp_size = 1
+            max_seq_len = None
+            if llm_args is not None:
+                # deduce max_tokens by llm args
+                assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
+                if hasattr(self,
+                           "mapping") and self.mapping.cp_size is not None:
+                    cp_size = self.mapping.cp_size
+                max_seq_len = getattr(self, "max_seq_len", None)
+            else:
+                # deduce max_tokens by executor config
+                if hasattr(executor_config, "mapping"
+                           ) and executor_config.mapping.cp_size is not None:
+                    cp_size = executor_config.mapping.cp_size
+                max_seq_len = getattr(executor_config, "max_seq_len", None)
+            if max_seq_len is None:
                 logger.warning("`default_max_tokens` cannot be deduced")
                 if max_tokens is None:
                     raise ValueError(
                         "`max_tokens` must be set when `default_max_tokens` cannot be deduced"
                     )
+                else:
+                    # use max_tokens if can't deduce default_max_tokens
+                    return max_tokens
             splited_prompt_len = int(len(prompt_token_ids) / cp_size)
-            default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
+            default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
             if default_max_tokens <= 0:
                 logger.warning(
                     f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, "
-                    f"`default_max_tokens` ({default_max_tokens}) = `max_seq_len` ({executor_config.max_seq_len})"
+                    f"`default_max_tokens` ({default_max_tokens}) = `max_seq_len` ({max_seq_len})"
                     f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
                 )
                 if max_tokens is None:
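Note: the deduction arithmetic is unchanged (`default_max_tokens = max_seq_len - splited_prompt_len - query_token_len`); what changes is where `cp_size` and `max_seq_len` come from (worker attributes when `llm_args` is set, `executor_config` otherwise). A standalone sketch of the fallback, assuming an explicit `max_tokens` is preferred whenever the user provided one:

# Sketch only: the max_tokens fallback mirrored from _deduce_max_tokens.
from typing import List, Optional


def deduce_max_tokens(prompt_token_ids: List[int],
                      max_tokens: Optional[int],
                      max_seq_len: Optional[int],
                      cp_size: int = 1,
                      query_token_len: int = 0) -> int:
    if max_seq_len is None:
        # No sequence-length budget available; the caller must have set max_tokens.
        if max_tokens is None:
            raise ValueError(
                "`max_tokens` must be set when `default_max_tokens` cannot be deduced")
        return max_tokens
    splited_prompt_len = int(len(prompt_token_ids) / cp_size)
    default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
    if default_max_tokens <= 0:
        # Deduced budget is not positive; again fall back to the explicit value.
        if max_tokens is None:
            raise ValueError("prompt is too long for `max_seq_len`; set `max_tokens`")
        return max_tokens
    # Assumption in this sketch: an explicit max_tokens wins over the deduced default.
    return max_tokens if max_tokens is not None else default_max_tokens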
@@ -512,7 +545,8 @@ def _deduce_max_tokens(request: GenerationRequest,
             executor_request = tllm.Request(
                 client_id=request.id,
                 input_token_ids=prompt_token_ids,
-                max_tokens=_deduce_max_tokens(request, self._executor_config),
+                max_tokens=_deduce_max_tokens(request, self._executor_config,
+                                              self.llm_args),
                 streaming=request.streaming,
                 sampling_config=request.sampling_params._get_sampling_config(),
                 end_id=-1 if request.sampling_params.ignore_eos else
@@ -638,11 +672,19 @@ def shutdown(self):
             self.engine.shutdown()
             self.engine = None

-            if hasattr(
-                    self._executor_config, "checkpoint_loader"
-            ) and self._executor_config.checkpoint_loader is not None:
-                self._executor_config.checkpoint_loader.cleanup()
-                self._executor_config.checkpoint_loader = None
+            if self.llm_args is not None:
+                assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined."
+                if (self.llm_args.backend == "pytorch"
+                        and hasattr(self, "checkpoint_loader")
+                        and self.checkpoint_loader is not None):
+                    self.checkpoint_loader.cleanup()
+                    self.checkpoint_loader = None
+            else:
+                if hasattr(
+                        self._executor_config, "checkpoint_loader"
+                ) and self._executor_config.checkpoint_loader is not None:
+                    self._executor_config.checkpoint_loader.cleanup()
+                    self._executor_config.checkpoint_loader = None

         # Check if there are any errors from the threads before shutdown.
         self._handle_background_error()
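Note: `shutdown` now has two symmetric cleanup paths, releasing the worker-owned `self.checkpoint_loader` when the worker was built from `llm_args` and falling back to `executor_config.checkpoint_loader` otherwise. Both branches follow the same guarded, idempotent pattern, sketched here with an illustrative helper name that is not part of the PR:

# Sketch only: guarded, idempotent cleanup as applied to checkpoint_loader in shutdown().
from typing import Any, Optional


def release_checkpoint_loader(owner: Any, attr: str = "checkpoint_loader") -> None:
    loader: Optional[Any] = getattr(owner, attr, None)
    if loader is not None:
        loader.cleanup()            # release resources exactly once
        setattr(owner, attr, None)  # a repeated shutdown() becomes a no-op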
@@ -689,7 +731,7 @@ def worker_main(
     kv_connector_config: Optional[KvCacheConnectorConfig] = None,
     hf_model_dir: Optional[Path] = None,
     tokenizer: Optional[TokenizerBase] = None,
-    llm_args: Optional[TorchLlmArgs] = None,
+    llm_args: Optional[BaseLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()}\n",