Commit a94acc9

Remove unnecessary trimming of cache padding
1 parent 4d14120 commit a94acc9

1 file changed: 5 additions, 14 deletions


keras_hub/src/models/smollm3/smollm3_causal_lm.py

Lines changed: 5 additions & 14 deletions
```diff
@@ -85,22 +85,18 @@ def call_with_cache(
         logits = self.backbone.token_embedding(x, reverse=True)
         return logits, hidden_states, cache
 
-    def _build_cache(self, token_ids, cache_max_length):
-        """Build an empty cache for use with `call_with_cache()`.
-
-        Args:
-            token_ids: Prompt tokens to seed the cache with
-            cache_max_length: Maximum length for the cache (should be max generation length)
-        """
+    def _build_cache(self, token_ids):
+        """Build an empty cache for use with `call_with_cache()`."""
         batch_size = ops.shape(token_ids)[0]
+        max_length = ops.shape(token_ids)[1]
         num_layers = self.backbone.num_layers
         num_key_value_heads = self.backbone.num_key_value_heads
         head_dim = self.backbone.hidden_dim // self.backbone.num_attention_heads
         shape = [
             batch_size,
             num_layers,
             2,
-            cache_max_length,
+            max_length,
             num_key_value_heads,
             head_dim,
         ]
```
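For reference, here is a minimal sketch of the cache shape that `_build_cache` now allocates, with NumPy standing in for `keras.ops`. After this change the cache is sized directly from the padded `token_ids` tensor, so the separate `cache_max_length` argument is no longer needed. All concrete dimensions below are hypothetical.

```python
import numpy as np

# Minimal sketch (NumPy stand-in for keras.ops): the cache is sized straight
# from the padded token_ids, so no separate cache_max_length is passed in.
# All concrete dimensions here are hypothetical.
batch_size, max_length = 2, 16  # ops.shape(token_ids)[0], ops.shape(token_ids)[1]
num_layers = 4                  # self.backbone.num_layers (assumed)
num_key_value_heads = 2         # self.backbone.num_key_value_heads (assumed)
head_dim = 64                   # hidden_dim // num_attention_heads (assumed)

# Axis 2 holds the key/value pair for each layer.
cache = np.zeros(
    [batch_size, num_layers, 2, max_length, num_key_value_heads, head_dim],
    dtype="float32",
)
print(cache.shape)  # (2, 4, 2, 16, 2, 64)
```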
```diff
@@ -130,17 +126,12 @@ def generate_step(
         """
         token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
 
+        hidden_states, cache = self._build_cache(token_ids)
         # Compute the lengths of all user inputted tokens ids.
         row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
         # Start at the first index that has no user inputted id.
         index = ops.min(row_lengths)
 
-        # Only pass actual prompt tokens to _build_cache, not padding
-        # But cache must be sized for the full max_length
-        max_length = ops.shape(token_ids)[1]
-        prompt_token_ids = token_ids[:, :index]
-        hidden_states, cache = self._build_cache(prompt_token_ids, max_length)
-
        def next(prompt, cache, index):
            # The cache index is the index of our previous token.
            cache_update_index = index - 1
```
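To see why building the cache from the full padded `token_ids` is safe, here is a minimal sketch (again with NumPy standing in for `keras.ops`, and an invented sample mask) of how `generate_step` derives its starting index: padding positions sit at or after `index`, so their cache slots are simply overwritten as generation proceeds.

```python
import numpy as np

# Minimal sketch (NumPy stand-in for keras.ops) of the index computation in
# generate_step. The padding mask below is hypothetical.
padding_mask = np.array([
    [1, 1, 1, 0, 0, 0],  # prompt of length 3, padded to max_length 6
    [1, 1, 1, 1, 1, 0],  # prompt of length 5, padded to max_length 6
])
# Length of each user-provided prompt.
row_lengths = padding_mask.astype("int32").sum(axis=-1)  # [3 5]
# First column index with no user-provided token in any row.
index = row_lengths.min()  # 3
print(row_lengths, index)
```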
