@@ -331,9 +331,8 @@ def build(
         padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
         seq_lens = torch.cat([
             seq_lens,
-            torch.ones(padded_num_tokens,
-                       dtype=seq_lens.dtype,
-                       device=seq_lens.device)
+            torch.tensor([padded_num_tokens
+                          ]).to(seq_lens.device).to(seq_lens.dtype)
         ])
         block_table_padding = torch.zeros(
             (padded_num_tokens, ) + block_table.shape[1:],
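For reference, a minimal sketch of the padded `seq_lens` this hunk builds; the tensor values below are hypothetical stand-ins, only the names `seq_lens` and `padded_num_tokens` come from the diff:

```python
import torch

# Hypothetical: two real sequences (lengths 5 and 7) and 3 tokens of graph padding.
seq_lens = torch.tensor([5, 7], dtype=torch.int32)
padded_num_tokens = 3

# The padding is appended as a single pseudo-sequence of length padded_num_tokens.
seq_lens = torch.cat([
    seq_lens,
    torch.tensor([padded_num_tokens]).to(seq_lens.device).to(seq_lens.dtype)
])
print(seq_lens)  # tensor([5, 7, 3], dtype=torch.int32)
```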
@@ -342,10 +341,8 @@ def build(
         block_table = torch.cat([block_table, block_table_padding], dim=0)
         query_start_loc_cpu = torch.cat([
             query_start_loc_cpu,
-            torch.arange(query_start_loc_cpu[-1] + 1,
-                         query_start_loc_cpu[-1] + padded_num_tokens,
-                         dtype=query_start_loc_cpu.dtype,
-                         device=query_start_loc_cpu.device)
+            torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
+                query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
         ])

         query_start_loc = query_start_loc_cpu.to(self.device,
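Likewise, a minimal sketch of the padded `query_start_loc_cpu`, using the same hypothetical sizes (5 + 7 real tokens, 3 padded):

```python
import torch

# Hypothetical cumulative query starts for the two real sequences above.
query_start_loc_cpu = torch.tensor([0, 5, 12], dtype=torch.int32)
padded_num_tokens = 3

# A single extra boundary at last_start + padded_num_tokens covers all the padding.
query_start_loc_cpu = torch.cat([
    query_start_loc_cpu,
    torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
        query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
])
print(query_start_loc_cpu)  # tensor([ 0,  5, 12, 15], dtype=torch.int32)
```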
@@ -621,7 +618,6 @@ def full_graph_attention(self,
         actual_seq_lengths_kv = attn_metadata.seq_lens_list

         num_tokens = attn_metadata.query_start_loc_list[-1]
-        query = query[:num_tokens]
         graph_params = get_graph_params()
         query_start_loc = attn_metadata.query_start_loc_list
         # Prepare tensors for attention output
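For reference, if the last entry of `query_start_loc` now covers the padded tokens (as the earlier hunks suggest), slicing `query` to `query_start_loc_list[-1]` would keep every row anyway; a quick check under those assumed, hypothetical shapes:

```python
import torch

# Hypothetical padded batch: 12 real tokens + 3 padding tokens = 15 rows.
query = torch.randn(15, 8, 64)  # (tokens, heads, head_dim), all sizes assumed

# Last boundary includes the padded pseudo-sequence, as sketched above.
query_start_loc_list = [0, 5, 12, 15]
num_tokens = query_start_loc_list[-1]

assert query[:num_tokens].shape == query.shape  # the slice keeps all 15 rows
```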