Closed

Changes from all commits
23 commits
485308a
[QEff]: Add gpt_oss
vbaddi Aug 6, 2025
59e2115
nit: update transforms
vbaddi Aug 6, 2025
a6c2812
nit: add header to __init__
vbaddi Aug 6, 2025
8e5783e
apirunner change
ochougul Aug 7, 2025
5c3c971
added test along with simplified Hybridcache
ochougul Aug 7, 2025
ce53d3c
added test assert
ochougul Aug 7, 2025
e0bd90f
nit: update modeling and make transform uniform
vbaddi Aug 7, 2025
18795f2
nit: add changes from gpt_oss_swa branch
vbaddi Aug 7, 2025
30ed222
nit: update test gpt file
vbaddi Aug 8, 2025
14afedb
MOE optimized
ochougul Aug 8, 2025
99f4795
nit: update modeling with new decode moe forward
vbaddi Aug 11, 2025
8637f1f
simplified slidingwindow KV gather and attention is permutation invar…
ochougul Aug 19, 2025
87414a2
nit: seperate gate, up projections for MoE
vbaddi Aug 20, 2025
3adccf6
added MXFP4 quantizer support to directly load GPT-OSS models via QEF…
ochougul Oct 8, 2025
890eed7
nit: rebase to mainline and resolve conflicts
Oct 14, 2025
402f8cb
nit: add license details to mxfp4 quantizer
Oct 14, 2025
e557472
nit: rebase to mainline
vbaddi Oct 15, 2025
05c599d
nit: remove test file and add sample test in config
Oct 15, 2025
dfddd55
nit: remove streamer from .generate() api in example file
vbaddi Oct 15, 2025
04ae7d8
nit: device_ids typo in example script
vbaddi Oct 15, 2025
3cadfba
nit: fix model_name in tests
vbaddi Oct 15, 2025
21a6620
Enable CB for GptOssModel
mamtsing Nov 3, 2025
1bd5d83
Merge branch 'main' into add_gpt_oss
quic-mamta Nov 3, 2025
1 change: 0 additions & 1 deletion QEfficient/__init__.py
@@ -18,7 +18,6 @@
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Placeholder for all non-transformer models registered in QEfficient


# custom warning for the better logging experience
warnings.formatwarning = custom_format_warning

78 changes: 63 additions & 15 deletions QEfficient/base/pytorch_transforms.py
@@ -120,61 +120,109 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:

class SplitGateUpWeightsTransform(PytorchTransform):
"""
Split fused Gate+Up weights and copy into the model.
Handles both standard MoE models and GptOss models.

For every transformer layer inside `model`:
• expects <PREFIX>.experts.gate_up_proj in the *source* `sd`
• copies halves into
    <PREFIX>.experts.gate_proj <-- Gate [E,H,I]
    <PREFIX>.experts.up_proj   <-- Up   [E,H,I]

Handles both interleaved weights (GptOss) and concatenated weights (standard MoE).
Also handles bias terms when present.
"""

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = False
model_class = model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__

if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS:
return model, transformed

model_tmp = model.language_model if hasattr(model, "language_model") else model

num_layers = len(model_tmp.model.layers)
delete_fused_key = True
sd = model_tmp.state_dict()

for layer_idx in range(num_layers):
# Determine if this is a GptOss model or standard MoE model
is_gpt_oss = hasattr(model_tmp.model.layers[layer_idx], "mlp")

# ---- build the textual prefix once per layer ----------
if is_gpt_oss:
prefix = f"model.layers.{layer_idx}.mlp.experts."
experts = model_tmp.model.layers[layer_idx].mlp.experts
else:
prefix = f"model.layers.{layer_idx}.feed_forward.experts."
experts = model_tmp.model.layers[layer_idx].feed_forward.experts

fused_key = prefix + "gate_up_proj"
gate_key = prefix + "gate_proj"
up_key = prefix + "up_proj"

# Check if we have bias terms (GptOss case)
has_bias = fused_key + "_bias" in sd
if has_bias:
fused_bias_key = fused_key + "_bias"
gate_bias_key = gate_key + "_bias"
up_bias_key = up_key + "_bias"

# ---- split weights based on model type ----------------------
fused = sd[fused_key] # [E, H, 2I]
E, H, two_I = fused.shape

if is_gpt_oss:
# For GptOss, gate/up are interleaved: [gate0, up0, gate1, up1, ...]
gate = fused[..., ::2] # [E, H, I] - even indices
up = fused[..., 1::2] # [E, H, I] - odd indices
else:
# For standard MoE, gate/up are concatenated: [gate, up]
ffn_dim = two_I // 2
gate, up = fused.split(ffn_dim, dim=-1) # views – no copy

# Copy weights to model
experts.gate_proj.data.copy_(gate)
experts.up_proj.data.copy_(up)

# Handle bias if present
if has_bias:
fused_bias = sd[fused_bias_key] # [E, 2I]

if is_gpt_oss:
gate_bias = fused_bias[..., ::2] # [E, I] - even indices
up_bias = fused_bias[..., 1::2] # [E, I] - odd indices
else:
ffn_dim = fused_bias.shape[-1] // 2
gate_bias, up_bias = fused_bias.split(ffn_dim, dim=-1)

experts.gate_proj_bias.data.copy_(gate_bias)
experts.up_proj_bias.data.copy_(up_bias)

# ---- update the state-dict so load_state_dict sees the right keys
sd[gate_key] = gate
sd[up_key] = up

if has_bias:
sd[gate_bias_key] = gate_bias
sd[up_bias_key] = up_bias

# Delete fused keys
if delete_fused_key:
del sd[fused_key]
if has_bias:
del sd[fused_bias_key]

logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
transformed = True

if hasattr(model, "language_model"):
model.language_model = model_tmp
else:
model = model_tmp

return model, transformed


VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"}
# Keep the existing list of supported models
VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM", "QEffGptOssForCausalLM"}
116 changes: 116 additions & 0 deletions QEfficient/transformers/cache_utils.py
@@ -537,3 +537,119 @@ def update(
ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out)
return k_out, v_out


# This is a hack for now, until we get to merging this code with the HybridCache class.
# We don't really need to inherit from the transformers cache classes, as those are made to work with PyTorch
# while ours are made to work with AIC.
class QEffHybridCacheForGPTOSS:
def __init__(self, config, batch_size, max_cache_len, sliding_window_len):
self.max_cache_len = max_cache_len
self.batch_size = batch_size
self.sliding_window_len = sliding_window_len
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []

@classmethod
def from_legacy_cache(
cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
) -> "HybridCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
backward compatibility."""
cache = cls(
config,
batch_size=past_key_values[0][0].shape[0],
max_cache_len=past_key_values[1][0].shape[2],
sliding_window_len=past_key_values[0][0].shape[2],
)
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache

def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.key_cache)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
is_empty_layer = (
len(self.key_cache) == 0 # no cache in any layer
or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
)
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
return layer_seq_length

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
backward compatibility."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
k_out, v_out = key_states, value_states
else:
position_ids = cache_kwargs.get("position_ids")
is_sliding_layer = cache_kwargs.get("is_sliding")
sliding_window = cache_kwargs.get("sliding_window")
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value from the kwargs

if is_sliding_layer:
kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window)
else:
kv_position_ids = position_ids

if batch_index is not None:
invalid_scatter_index = torch.iinfo(torch.int32).max
scatter_position_ids = torch.where(kv_position_ids < 0, invalid_scatter_index, kv_position_ids)
self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
)
self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
)
else:
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
)

k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

# Original Gather
ctx_len = self.key_cache[layer_idx].shape[2]
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

if batch_index is not None:
k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
else:
k_out = CtxGatherFunc.apply(k_out, ctx_indices)
v_out = CtxGatherFunc.apply(v_out, ctx_indices)

v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out
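
The sliding-window branch above relies on wrap-around indexing: for sliding layers, position_ids are reduced modulo sliding_window before scattering, so new tokens overwrite the oldest slots of a fixed-size cache. A minimal sketch of that idea (illustrative only, using plain tensor indexing instead of the CtxScatterFunc/CtxGatherFunc ops above):

import torch

sliding_window = 4
cache = torch.zeros(1, 1, sliding_window, 2)   # [batch, heads, window, head_dim]

# write six tokens with absolute positions 0..5, one at a time
for pos in range(6):
    slot = pos % sliding_window                # same modulo mapping as kv_position_ids above
    cache[:, :, slot, :] = float(pos)

# after 6 tokens the window holds positions [4, 5, 2, 3]: the most recent sliding_window tokens
print(cache[0, 0, :, 0])                       # tensor([4., 5., 2., 3.])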
1 change: 1 addition & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -185,6 +185,7 @@
]
)

# This is for supporting a different seq_len for different layers, e.g. sliding window attention, chunked attention, etc.
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}

# Define a transformers layers to QEff layers dictionary
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/gpt_oss/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------