
Commit 084412a

mamtsing authored and ochougul committed

Enable CB for GptOssModel

Signed-off-by: Mamta Singh <[email protected]>
Signed-off-by: Onkar Chougule <[email protected]>

1 parent 56a616e

File tree (3 files changed: +39 -20 lines)

QEfficient/transformers/cache_utils.py
QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
QEfficient/utils/generate_inputs.py

QEfficient/transformers/cache_utils.py

Lines changed: 23 additions & 6 deletions

@@ -609,16 +609,28 @@ def update(
         position_ids = cache_kwargs.get("position_ids")
         is_sliding_layer = cache_kwargs.get("is_sliding")
         sliding_window = cache_kwargs.get("sliding_window")
+        batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch batch index value from the kwargs

         if is_sliding_layer:
             kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window)
         else:
             kv_position_ids = position_ids

-        self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
-        self.value_cache[layer_idx] = CtxScatterFunc.apply(
-            self.value_cache[layer_idx], kv_position_ids, value_states
-        )
+        if batch_index is not None:
+            invalid_scatter_index = torch.iinfo(torch.int32).max
+            scatter_position_ids = torch.where(kv_position_ids < 0, invalid_scatter_index, kv_position_ids)
+            self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+                self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+            )
+            self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+                self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+            )
+        else:
+            self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
+            self.value_cache[layer_idx] = CtxScatterFunc.apply(
+                self.value_cache[layer_idx], kv_position_ids, value_states
+            )
+
         k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

         # Original Gather
@@ -632,7 +644,12 @@
             invalid_idx_value = 0
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

-        k_out = CtxGatherFunc.apply(k_out, ctx_indices)
-        v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+        if batch_index is not None:
+            k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
+            v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+        else:
+            k_out = CtxGatherFunc.apply(k_out, ctx_indices)
+            v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+
         v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
         return k_out, v_out
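For readers unfamiliar with the CB (continuous batching) path: the KV cache is allocated for full_batch_size requests, a forward pass may carry only a subset of them, and batch_index tells the scatter/gather ops which cache rows those requests own. Below is a minimal PyTorch sketch of the semantics the CtxScatterFuncCB/CtxGatherFuncCB calls above provide; the helper names, shapes, and the eager masking of invalid positions are illustrative assumptions, not the repo's actual autograd functions (which instead rely on the int32-max sentinel so the compiled path drops out-of-range writes).

import torch

def scatter_cb(cache, batch_index, position_ids, new_states):
    """Illustrative stand-in for CtxScatterFuncCB.apply (assumed shapes).

    cache:        [full_batch_size, n_heads, ctx_len, head_dim]
    batch_index:  [bs, 1]  -- cache row owned by each active request
    position_ids: [bs, seq_len], with negative values marking padding
    new_states:   [bs, n_heads, seq_len, head_dim]
    """
    for b in range(new_states.shape[0]):
        row = int(batch_index[b, 0])
        valid = position_ids[b] >= 0          # eager analogue of the int32-max sentinel trick
        cache[row][:, position_ids[b, valid]] = new_states[b][:, valid]
    return cache

def gather_cb(cache, batch_index, ctx_indices):
    """Illustrative stand-in for CtxGatherFuncCB.apply (assumed shapes).

    ctx_indices: [bs, n_idx] positions to read back for each active request.
    """
    rows = cache[batch_index[:, 0]]           # [bs, n_heads, ctx_len, head_dim]
    idx = ctx_indices[:, None, :, None].expand(-1, rows.shape[1], -1, rows.shape[-1])
    return torch.gather(rows, 2, idx)         # [bs, n_heads, n_idx, head_dim]

# Request #2 of a full batch of 4 writes 3 valid tokens, then reads them back.
cache = torch.zeros(4, 2, 8, 16)
batch_index = torch.tensor([[2]])
pos = torch.tensor([[0, 1, 2, -1]])           # last slot is padding
cache = scatter_cb(cache, batch_index, pos, torch.randn(1, 2, 4, 16))
k = gather_cb(cache, batch_index, torch.tensor([[0, 1, 2]]))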

QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 1 addition & 3 deletions

@@ -428,9 +428,6 @@ def forward(
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

-        # kv_seq_len = key_states.shape[-2]
-
-        # kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

@@ -508,6 +505,7 @@ def forward(
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states, _ = self.mlp(hidden_states)  # diff with llama: router scores
         # alth, _ = self.mlp.alt_forward(hidden_states)
+        hidden_states = hidden_states.reshape(residual.shape)
         hidden_states = residual + hidden_states
         outputs = (hidden_states,)
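The one-line addition in the second hunk guards the residual add: if the routed-experts MLP returns token-flattened output while residual is still (batch, seq, hidden), the add would fail to broadcast. A minimal sketch of that failure mode, with made-up shapes (that the MLP flattens tokens here is an assumption inferred from the fix):

import torch

residual = torch.randn(2, 5, 64)   # (batch, seq, hidden)
mlp_out = torch.randn(2 * 5, 64)   # assumed token-flattened MoE output: (batch * seq, hidden)

# residual + mlp_out would raise: (2, 5, 64) does not broadcast with (10, 64)
hidden_states = mlp_out.reshape(residual.shape)   # the added line restores (batch, seq, hidden)
out = residual + hidden_states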

QEfficient/utils/generate_inputs.py

Lines changed: 15 additions & 11 deletions

@@ -87,8 +87,8 @@ def prepare_pytorch_inputs(self):

         if self.full_batch_size:
             inputs["input_ids"] = input_ids
-            inputs["position_ids"] = torch.arange(input_len).view(1, input_len)
-            inputs["batch_index"] = torch.arange(1).view(-1, 1)
+            inputs["position_ids"] = position_ids
+            inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1)

         past_key_values = []
         sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]]
@@ -117,18 +117,15 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
         """
         updated_inputs = {}
         if self.full_batch_size:
-            batch_index = torch.arange(1).view(-1, 1)
-
             input_ids = pt_outputs.logits.detach().argmax(2)
             updated_inputs["input_ids"] = torch.full((self.full_batch_size, 1), self.tokenizer.pad_token_id)
-            updated_inputs["input_ids"][batch_index.view(-1)] = input_ids
+            updated_inputs["input_ids"][inputs["batch_index"].view(-1)] = input_ids

             position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1
             updated_inputs["position_ids"] = torch.full((self.full_batch_size, 1), 0)
-            updated_inputs["position_ids"][batch_index.view(-1)] = position_ids
-
-            updated_inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1)
+            updated_inputs["position_ids"][inputs["batch_index"].view(-1)] = position_ids

+            updated_inputs["batch_index"] = inputs["batch_index"]
         else:
             updated_inputs["input_ids"] = pt_outputs["logits"].argmax(-1).reshape(-1, 1)
             updated_inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1
@@ -172,9 +169,15 @@ def prepare_ort_inputs(self):
                 inputs["past_key." + str(i)] = np.zeros((cache_shape), dtype=np.float32)
                 inputs["past_value." + str(i)] = np.zeros((cache_shape), dtype=np.float32)
         else:
+            sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]]
             for i in range(self.n_layer):
-                inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
-                inputs["past_value." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
+                pad_shape = (
+                    sliding_padding_shape if self.config.layer_types[i] == "sliding_attention" else self.padding_shape
+                )
+                inputs["past_key." + str(i)] = np.zeros((pad_shape), dtype=np.float32)
+                inputs["past_value." + str(i)] = np.zeros((pad_shape), dtype=np.float32)
+        if self.full_batch_size:
+            inputs["batch_index"] = np.arange(self.full_batch_size).reshape(-1, 1)
         return inputs

     def update_ort_inputs(self, inputs, ort_outputs):
@@ -195,7 +198,8 @@ def update_ort_inputs(self, inputs, ort_outputs):
         for i in range(self.n_layer):
             updated_inputs["past_key." + str(i)] = ort_outputs["past_key_values"][i * 2]
             updated_inputs["past_value." + str(i)] = ort_outputs["past_key_values"][i * 2 + 1]
-
+        if self.full_batch_size:
+            updated_inputs["batch_index"] = inputs["batch_index"]
         return updated_inputs

     def update_ort_outputs(self, ort_outputs):
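Taken together, the generate_inputs.py changes make batch_index a first-class input that is created once at prefill and then carried through each decode step unchanged, instead of being re-derived as torch.arange(1). A condensed sketch of one decode-step update under these changes, assuming CB prefill returns logits only for each request's last position (vocab size and shapes are stand-ins):

import torch

full_batch_size, pad_token_id = 4, 0

# State carried from the previous step (as prepare_pytorch_inputs now builds it).
inputs = {
    "position_ids": torch.tensor([[0, 1, 2, -1]] * full_batch_size),   # -1 marks padding
    "batch_index": torch.arange(full_batch_size).view(-1, 1),
}
logits = torch.randn(full_batch_size, 1, 32)     # stand-in model output: [bs, 1, vocab]

updated = {}
input_ids = logits.detach().argmax(2)            # greedy next token per request: [bs, 1]
updated["input_ids"] = torch.full((full_batch_size, 1), pad_token_id)
updated["input_ids"][inputs["batch_index"].view(-1)] = input_ids

# max() skips the -1 padding, so +1 is each request's next free cache slot.
position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1
updated["position_ids"] = torch.full((full_batch_size, 1), 0)
updated["position_ids"][inputs["batch_index"].view(-1)] = position_ids

updated["batch_index"] = inputs["batch_index"]   # carried through, not re-created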
