@@ -85,22 +85,18 @@ def call_with_cache(
         logits = self.backbone.token_embedding(x, reverse=True)
         return logits, hidden_states, cache

-    def _build_cache(self, token_ids, cache_max_length):
-        """Build an empty cache for use with `call_with_cache()`.
-
-        Args:
-            token_ids: Prompt tokens to seed the cache with
-            cache_max_length: Maximum length for the cache (should be max generation length)
-        """
+    def _build_cache(self, token_ids):
+        """Build an empty cache for use with `call_with_cache()`."""
         batch_size = ops.shape(token_ids)[0]
+        max_length = ops.shape(token_ids)[1]
         num_layers = self.backbone.num_layers
         num_key_value_heads = self.backbone.num_key_value_heads
         head_dim = self.backbone.hidden_dim // self.backbone.num_attention_heads
         shape = [
             batch_size,
             num_layers,
             2,
-            cache_max_length,
+            max_length,
             num_key_value_heads,
             head_dim,
         ]
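# A minimal sketch (not the file's exact code) of the cache the reworked
# _build_cache() allocates, assuming `ops` refers to `keras.ops` and using
# made-up example values for the backbone configuration. The cache length now
# comes from the padded token_ids shape instead of a separate cache_max_length
# argument.
from keras import ops

batch_size, padded_len = 2, 16           # token_ids would have shape (2, 16)
num_layers, num_key_value_heads = 4, 8   # hypothetical backbone settings
hidden_dim, num_attention_heads = 64, 8
head_dim = hidden_dim // num_attention_heads

cache = ops.zeros(
    [batch_size, num_layers, 2, padded_len, num_key_value_heads, head_dim],
    dtype="float32",
)
# Axis 2 holds the key/value pair for each layer; axis 3 is the padded
# sequence length, which now also serves as the maximum generation length.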
@@ -130,17 +126,12 @@ def generate_step(
         """
         token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]

+        hidden_states, cache = self._build_cache(token_ids)
         # Compute the lengths of all user inputted tokens ids.
         row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
         # Start at the first index that has no user inputted id.
         index = ops.min(row_lengths)

-        # Only pass actual prompt tokens to _build_cache, not padding
-        # But cache must be sized for the full max_length
-        max_length = ops.shape(token_ids)[1]
-        prompt_token_ids = token_ids[:, :index]
-        hidden_states, cache = self._build_cache(prompt_token_ids, max_length)
-
         def next(prompt, cache, index):
             # The cache index is the index of our previous token.
             cache_update_index = index - 1
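# A hedged sketch of the reordered seeding logic in generate_step(): the cache
# is now built from the full padded token_ids before `index` is computed, so
# no prompt slicing or separate max length is needed. Names mirror the diff;
# `generate_prefix` is a hypothetical helper used here purely for illustration.
def generate_prefix(self, inputs):
    token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
    # Seed the full-length cache up front from the padded prompt.
    hidden_states, cache = self._build_cache(token_ids)
    # Length of each prompt, ignoring padding.
    row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
    # Start generating at the first index past the shortest prompt.
    index = ops.min(row_lengths)
    return hidden_states, cache, index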