Commit 4d14120

Fix rope and caching indexing
1 parent adb05d9 commit 4d14120

4 files changed: +71 additions, -25 deletions


keras_hub/src/models/smollm3/smollm3_backbone.py

Lines changed: 3 additions & 1 deletion

@@ -7,7 +7,6 @@
 )
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
-from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding


 @keras_hub_export(
@@ -91,6 +90,9 @@ def __init__(
                 intermediate_size=intermediate_dim,
                 mlp_bias=mlp_bias,
                 layer_norm_epsilon=layer_norm_epsilon,
+                max_position_embeddings=max_position_embeddings,
+                rope_theta=rope_theta,
+                partial_rotary_factor=partial_rotary_factor,
                 name=f"transformer_layer_{i}",
             )
             self.transformer_layers.append(layer)
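For context on why threading `max_position_embeddings`, `rope_theta`, and `partial_rotary_factor` into every decoder layer matters: these values fix the per-channel rotation frequencies the layer's rotary embedding uses (the layers file previously fell back to a hard-coded `max_wavelength`). The actual `rope_init` helper in `smollm3_utils` is not part of this diff, so the sketch below only shows the standard RoPE inverse-frequency schedule these parameters usually control.

```python
import numpy as np

def sketch_rope_inv_freq(head_dim, rope_theta, partial_rotary_factor):
    """Standard RoPE schedule (illustrative; not the smollm3_utils code)."""
    # Only the first `rotary_dim` channels of each head are rotated; the
    # frequency of channel pair i decays as rope_theta ** (-2i / rotary_dim).
    rotary_dim = int(head_dim * partial_rotary_factor)
    exponents = np.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim
    return 1.0 / (rope_theta**exponents)  # shape: (rotary_dim // 2,)

# e.g. head_dim=64 with the defaults added in this commit:
print(sketch_rope_inv_freq(64, 10000.0, 1.0).shape)  # (32,)
```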

keras_hub/src/models/smollm3/smollm3_causal_lm.py

Lines changed: 15 additions & 5 deletions

@@ -85,18 +85,22 @@ def call_with_cache(
         logits = self.backbone.token_embedding(x, reverse=True)
         return logits, hidden_states, cache

-    def _build_cache(self, token_ids):
-        """Build an empty cache for use with `call_with_cache()`."""
+    def _build_cache(self, token_ids, cache_max_length):
+        """Build an empty cache for use with `call_with_cache()`.
+
+        Args:
+            token_ids: Prompt tokens to seed the cache with.
+            cache_max_length: Maximum length for the cache (the max generation length).
+        """
         batch_size = ops.shape(token_ids)[0]
-        max_length = ops.shape(token_ids)[1]
         num_layers = self.backbone.num_layers
         num_key_value_heads = self.backbone.num_key_value_heads
         head_dim = self.backbone.hidden_dim // self.backbone.num_attention_heads
         shape = [
             batch_size,
             num_layers,
             2,
-            max_length,
+            cache_max_length,
             num_key_value_heads,
             head_dim,
         ]
@@ -126,17 +130,23 @@ def generate_step(
         """
         token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]

-        hidden_states, cache = self._build_cache(token_ids)
         # Compute the lengths of all user inputted token ids.
         row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
         # Start at the first index that has no user inputted id.
         index = ops.min(row_lengths)

+        # Only pass actual prompt tokens to `_build_cache`, not padding;
+        # the cache itself must still be sized for the full max_length.
+        max_length = ops.shape(token_ids)[1]
+        prompt_token_ids = token_ids[:, :index]
+        hidden_states, cache = self._build_cache(prompt_token_ids, max_length)
+
        def next(prompt, cache, index):
            # The cache index is the index of our previous token.
            cache_update_index = index - 1
            batch_size = ops.shape(prompt)[0]
            prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
+
            logits, hidden_states, cache = self.call_with_cache(
                prompt,
                cache,
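To make the resized cache concrete, here is a small NumPy sketch of the shapes this change produces. The sizes are made up for illustration; only the layout `[batch, num_layers, 2, cache_max_length, num_kv_heads, head_dim]` comes from the diff above.

```python
import numpy as np

# Illustrative sizes only.
batch_size, num_layers, num_kv_heads, head_dim = 2, 4, 4, 64
prompt_length, max_generation_length = 5, 32

# The cache is now allocated for the full generation budget up front...
cache = np.zeros(
    (batch_size, num_layers, 2, max_generation_length, num_kv_heads, head_dim),
    dtype="float32",
)

# ...while only the leading `prompt_length` positions are filled by the
# prefill pass; each decode step then writes a single new position.
seeded_region = cache[:, :, :, :prompt_length]
print(seeded_region.shape)  # (2, 4, 2, 5, 4, 64)
```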

keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 49 additions & 18 deletions

@@ -10,7 +10,7 @@
     merge_padding_and_attention_mask,
 )
 from keras_hub.src.models.smollm3.smollm3_utils import rope_init
-from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
 import math

@@ -39,6 +39,9 @@ def __init__(
         rope_layer_enabled_list: list[bool],
         layer_types: list[str],
         layer_idx: int,
+        max_position_embeddings: int = 2048,
+        rope_theta: float = 10000.0,
+        partial_rotary_factor: float = 1.0,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -50,19 +53,17 @@
         self.attention_dropout = attention_dropout
         self.rope_layer_enabled_list = rope_layer_enabled_list
         self.layer_types = layer_types
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.partial_rotary_factor = partial_rotary_factor
+
         self._dot_product_equation = "bquh,bkuh->buqk"
         self._combine_equation = "buqk,bkuh->bquh"

         self.head_dim = hidden_size // self.num_attention_heads
         self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim)

-        self.rotary_embedding = RotaryEmbedding(
-            max_wavelength=5000000.0,
-        )
-
         self.layer_idx = layer_idx
-
-        self.head_dim = self.hidden_size // self.num_attention_heads
         self.num_key_value_groups = (
             self.num_attention_heads // self.num_key_value_heads
         )
@@ -97,6 +98,15 @@ def __init__(
             else True
         )  # Default to True if index out of bounds

+        self.rotary_embedding = SmolLM3RotaryEmbedding(
+            hidden_size=self.hidden_size,
+            num_attention_heads=self.num_attention_heads,
+            max_position_embeddings=self.max_position_embeddings,
+            rope_theta=self.rope_theta,
+            partial_rotary_factor=self.partial_rotary_factor,
+            name="rotary_emb",
+        )
+
         self._softmax = layers.Softmax(
             axis=-1,
             dtype="float32",
@@ -172,7 +182,15 @@ def _compute_kv_values(x_input):
                 value = value_cache
             else:
                 key_update, value_update = _compute_kv_values(hidden_states)
-                start = [0, self_attention_cache_update_index, 0, 0]
+
+                # Apply RoPE to key_update BEFORE caching
+                if self.use_rope:
+                    cos, sin = self.rotary_embedding(query, start_index=start_index)
+                    query_rope, key_update = apply_rotary_pos_emb(query, key_update, cos, sin, expansion_axis=2)
+                    query = query_rope
+
+                start = (0, self_attention_cache_update_index, 0, 0)
+
                 key = ops.slice_update(key_cache, start, key_update)
                 value = ops.slice_update(
                     value_cache, start, value_update
@@ -189,14 +207,13 @@ def _compute_kv_values(x_input):
             )
             key, value = _compute_kv_values(hidden_states)

-            if self.use_rope:
-                query = self.rotary_embedding(query, start_index=start_index)
-                key = self.rotary_embedding(key, start_index=start_index)
+            # Apply RoPE when not using cache
+            if self.use_rope:
+                cos, sin = self.rotary_embedding(query, start_index=start_index)
+                query, key = apply_rotary_pos_emb(query, key, cos, sin, expansion_axis=2)

-        print('pre', key.shape, value.shape)
         key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
         value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
-        print('post', key.shape, value.shape)

         attn_output = self._compute_attention(
             query,
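The `apply_rotary_pos_emb` helper lives in `smollm3_utils` and is not shown in this commit, so the following is only a sketch of the standard rotate-half formulation it presumably implements. The point of the cached branch above is that keys are rotated with their absolute positions before being written into the cache, so cached entries never need to be re-rotated on later decode steps.

```python
import numpy as np

def rotate_half(x):
    # Split the last dimension in half and rotate the pairs: (x1, x2) -> (-x2, x1).
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate((-x2, x1), axis=-1)

def sketch_apply_rotary_pos_emb(q, k, cos, sin, expansion_axis=2):
    # q, k: (batch, seq, num_heads, head_dim); cos, sin: (batch, seq, head_dim).
    # Insert a heads axis so cos/sin broadcast across all heads -- presumably
    # what `expansion_axis=2` selects in the real helper.
    cos = np.expand_dims(cos, axis=expansion_axis)
    sin = np.expand_dims(sin, axis=expansion_axis)
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot
```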
@@ -400,6 +417,9 @@ def __init__(
         intermediate_size: int,
         mlp_bias: bool,
         layer_norm_epsilon: float,
+        max_position_embeddings: int = 2048,
+        rope_theta: float = 10000.0,
+        partial_rotary_factor: float = 1.0,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -415,6 +435,9 @@ def __init__(
             rope_layer_enabled_list=rope_layer_enabled_list,
             layer_types=layer_types,
             layer_idx=layer_idx,
+            max_position_embeddings=max_position_embeddings,
+            rope_theta=rope_theta,
+            partial_rotary_factor=partial_rotary_factor,
             name="self_attn",
         )

@@ -641,26 +664,34 @@ def call(
                 Shape can vary, but the last dimension is head_dim.
             position_ids: Tensor of position IDs of shape (batch_size, seq_len).
         """
-        inv_freq_expanded = ops.expand_dims(
-            ops.expand_dims(self.inv_freq, axis=0), axis=-1
-        )
-
         batch_size = ops.shape(x)[0]
         seq_len = ops.shape(x)[1]
         positions = ops.arange(seq_len, dtype="float32")
         positions = positions + ops.cast(start_index, dtype="float32")

+        # inv_freq: (inv_freq_dim,) -> (1, inv_freq_dim, 1) -> (batch, inv_freq_dim, 1)
+        inv_freq_expanded = ops.expand_dims(
+            ops.expand_dims(self.inv_freq, axis=0), axis=-1
+        )
         inv_freq_expanded = ops.broadcast_to(
             inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1)
         )

-        position_ids_expanded = ops.expand_dims(positions, axis=1).T
+        # positions: (seq_len,) -> (1, 1, seq_len) -> (batch, 1, seq_len)
+        position_ids_expanded = ops.expand_dims(
+            ops.expand_dims(positions, axis=0), axis=0
+        )
+        position_ids_expanded = ops.broadcast_to(
+            position_ids_expanded, (batch_size, 1, seq_len)
+        )

+        # matmul: (batch, inv_freq_dim, 1) @ (batch, 1, seq_len) -> (batch, inv_freq_dim, seq_len)
         freqs = ops.matmul(
             ops.cast(inv_freq_expanded, "float32"),
             ops.cast(position_ids_expanded, "float32"),
         )

+        # transpose: (batch, inv_freq_dim, seq_len) -> (batch, seq_len, inv_freq_dim)
         freqs = ops.transpose(freqs, axes=(0, 2, 1))

         emb = ops.concatenate((freqs, freqs), axis=-1)
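A quick NumPy re-run of the broadcasting above, to confirm the shape comments end to end. The inverse-frequency schedule and the sizes are illustrative, not taken from the SmolLM3 config.

```python
import numpy as np

batch_size, seq_len, inv_freq_dim = 2, 7, 32  # illustrative sizes
inv_freq = 1.0 / (10000.0 ** (np.arange(0, 2 * inv_freq_dim, 2) / (2 * inv_freq_dim)))
positions = np.arange(seq_len, dtype="float32")

inv_freq_expanded = np.broadcast_to(
    inv_freq[None, :, None], (batch_size, inv_freq_dim, 1)
)
position_ids_expanded = np.broadcast_to(
    positions[None, None, :], (batch_size, 1, seq_len)
)

freqs = inv_freq_expanded @ position_ids_expanded  # (batch, inv_freq_dim, seq_len)
freqs = np.transpose(freqs, (0, 2, 1))             # (batch, seq_len, inv_freq_dim)
emb = np.concatenate((freqs, freqs), axis=-1)      # (batch, seq_len, 2 * inv_freq_dim)
cos, sin = np.cos(emb), np.sin(emb)
print(freqs.shape, emb.shape)  # (2, 7, 32) (2, 7, 64)
```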

keras_hub/src/utils/transformers/convert_smollm3.py

Lines changed: 4 additions & 1 deletion

@@ -27,7 +27,10 @@ def convert_backbone_config(transformers_config):
         "partial_rotary_factor": 1.0,
         "attention_bias": transformers_config["attention_bias"],
         "attention_dropout": transformers_config["attention_dropout"],
-        "rope_layer_enabled_list": transformers_config["no_rope_layers"],
+        # Despite the name, no_rope_layers: 1 = HAS RoPE, 0 = NO RoPE
+        "rope_layer_enabled_list": [
+            bool(x) for x in transformers_config["no_rope_layers"]
+        ],
         "layer_types": transformers_config["layer_types"],
         "mlp_bias": transformers_config["mlp_bias"]
     }
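A worked example of the conversion above; the config excerpt and list values are hypothetical.

```python
# Per the comment in the diff, SmolLM3's `no_rope_layers` stores 1 for layers
# that DO use RoPE and 0 for layers that skip it, despite the name.
transformers_config = {"no_rope_layers": [1, 1, 1, 0]}

rope_layer_enabled_list = [
    bool(x) for x in transformers_config["no_rope_layers"]
]
print(rope_layer_enabled_list)  # [True, True, True, False]
```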
