 import gc
 import itertools
 import math
+from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, cast
 
 import habana_frameworks.torch as htorch
@@ -42,6 +43,13 @@ class HpuModelAdapterEncoderDecoder(HpuModelAdapter):
     def __init__(self, model, vllm_config, layer_names, is_causal, sampler):
         super().__init__(model, vllm_config, layer_names, is_causal, sampler)
 
+        # We only wrap the language model in HPU graph because some ops in the
+        # vision model fall back to CPU and cause graph building to fail.
+        if htorch.utils.internal.is_lazy() and hasattr(self.model,
+                                                       "language_model"):
+            self.model.language_model = htorch.hpu.wrap_in_hpu_graph(
+                self.model.language_model, disable_tensor_cache=True)
+
     def _set_cross_block_mapping(self, metadata, batch_size, device, dtype):
         mask = torch.arange(0,
                             self.block_size,
@@ -110,6 +118,13 @@ def forward(self, *args, **kwargs):
         kwargs['attn_metadata'] = self._update_cross_metadata(
             kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
             input_ids.device, self.dtype)
+        if htorch.utils.internal.is_lazy() and hasattr(self.model,
+                                                       "language_model"):
+            bypass_hpu_graphs = kwargs.get('bypass_hpu_graphs', False)
+            self.model.language_model.forward = partial(
+                self.model.language_model.forward,
+                attn_metadata=kwargs['attn_metadata'],
+                bypass_hpu_graphs=bypass_hpu_graphs)
         # TODO: Change the input_ids to 1D to match the public vllm
         # implementation and avoid shape mismatch issues with some
         # models(i.e. Mllama). But currently this will cause graph
@@ -118,9 +133,9 @@ def forward(self, *args, **kwargs):
         virtual_engine = 0
         if 'virtual_engine' in kwargs:
             virtual_engine = kwargs.pop('virtual_engine')
+        attn_metadata = kwargs.pop('attn_metadata')
         if 'kv_caches' in kwargs:
             kwargs.pop('kv_caches')
-        attn_metadata = kwargs.pop("attn_metadata")
         with set_forward_context(attn_metadata, self.vllm_config,
                                  virtual_engine):
             hidden_states = self.model(*args, **kwargs)
@@ -193,11 +208,7 @@ def _flatten(self, in_list):
         return list(itertools.chain(*in_list))
 
     def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
-        return htorch.hpu.wrap_in_hpu_graph(
-            HpuModelAdapterEncoderDecoder(*args, **kwargs),
-            disable_tensor_cache=True,
-        ) if htorch.utils.internal.is_lazy(
-        ) else HpuModelAdapterEncoderDecoder(*args, **kwargs)
+        return HpuModelAdapterEncoderDecoder(*args, **kwargs)
 
     def prepare_model_input(
         self,
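
For context, a minimal standalone sketch of the functools.partial trick used in the forward hunk above: the per-step keyword arguments are pre-bound onto the wrapped submodule's forward so the outer call signature stays fixed. DummyLanguageModel and the bound values are hypothetical, not part of this diff or of vLLM.

from functools import partial


class DummyLanguageModel:
    # Hypothetical stand-in for self.model.language_model.
    def forward(self, input_ids, attn_metadata=None, bypass_hpu_graphs=False):
        return {
            "input_ids": input_ids,
            "attn_metadata": attn_metadata,
            "bypass_hpu_graphs": bypass_hpu_graphs,
        }


model = DummyLanguageModel()
# Pre-bind the per-step keyword arguments, mirroring what the adapter does
# before calling into the HPU-graph-wrapped language model.
model.forward = partial(model.forward,
                        attn_metadata={"num_prefill_tokens": 128},
                        bypass_hpu_graphs=False)
# Callers now pass only the positional inputs; the bound kwargs ride along.
print(model.forward([1, 2, 3]))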