36 changes: 15 additions & 21 deletions paddleformers/transformers/deepseek_v2/modeling.py
@@ -51,13 +51,9 @@
 from ...nn.norm import RMSNorm
 from ...nn.pp_model import EmbeddingPipe, GeneralModelForCausalLMPipe, parse_args
 from ...utils.log import logger
-from ...utils.masking_utils import (
-    _expand_2d_mask,
-    _make_causal_mask,
-    get_use_casual_mask,
-    is_casual_mask,
-)
+from ...utils.masking_utils import _expand_2d_mask, _make_causal_mask
 from ..conversion_utils import StateDictNameMapping, init_name_mappings
+from ..masking_utils import create_causal_masks_and_row_indices
 from ..model_outputs import (
     BaseModelOutputWithPastAndMTP,
     CausalLMOutputWithPast,
@@ -1477,21 +1473,19 @@ def forward(
         if position_embeddings is None:
             position_embeddings = paddle.stack(self.rotary_emb(inputs_embeds, position_ids=position_ids))
 
-        # embed positions
-        if attn_mask_startend_row_indices is not None or get_use_casual_mask():
-            attention_mask = None
-        else:
-            # [bs, seq_len]
-            attention_mask = (
-                paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
-                if attention_mask is None
-                else attention_mask
-            )
-            attention_mask = self._prepare_decoder_attention_mask(
-                attention_mask, (batch_size, seq_length), past_key_values_length, inputs_embeds.dtype
-            )  # [bs, 1, seq_len, seq_len]
-            if self.config.use_flash_attention:
-                attention_mask = None if is_casual_mask(attention_mask) else attention_mask
+        mask_kwargs = {
+            "config": self.config,
+            "inputs_embeds": inputs_embeds,
+            "batch_size": batch_size,
+            "seq_length": seq_length,
+            "cache_length": past_key_values_length,
+            "attention_mask": attention_mask,
+            "attn_mask_startend_row_indices": attn_mask_startend_row_indices,
+            "prepare_decoder_attention_mask": self._prepare_decoder_attention_mask,
+            "return_mapping": False,
+        }
+
+        attention_mask, attn_mask_startend_row_indices = create_causal_masks_and_row_indices(**mask_kwargs)
 
         if self.config.num_nextn_predict_layers > 0:
             inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :]  # [B, S, D]
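For context, the block removed above assembled the dense mask by hand: a boolean [bs, seq_len_with_past] padding mask expanded into the [bs, 1, seq_len, seq_len] causal layout, then dropped when flash attention could rely on a purely causal pattern. Below is a minimal paddle sketch of that shape logic, written independently of the repo's _make_causal_mask/_expand_2d_mask helpers and using made-up sizes; the new create_causal_masks_and_row_indices call centralizes this construction (plus the flashmask row-index path) so each model's forward no longer repeats it.

import paddle

batch_size, seq_length, past_key_values_length = 2, 4, 0
seq_length_with_past = seq_length + past_key_values_length

# [bs, seq_len_with_past] padding mask: every position valid.
padding_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)

# [seq_len, seq_len_with_past] causal pattern: query i sees keys up to i + past length.
causal = paddle.tril(
    paddle.ones((seq_length, seq_length_with_past), dtype="int32"),
    diagonal=past_key_values_length,
).astype("bool")

# Combine into the [bs, 1, seq_len, seq_len_with_past] layout the old code produced.
mask_4d = paddle.logical_and(causal.unsqueeze(0).unsqueeze(0), padding_mask.unsqueeze(1).unsqueeze(2))
print(mask_4d.shape)  # [2, 1, 4, 4]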
20 changes: 14 additions & 6 deletions paddleformers/transformers/ernie4_5/modeling.py
@@ -36,6 +36,7 @@
 from ...nn.norm import Norm as GeneralNorm
 from ...nn.pp_model import GeneralModelForCausalLMPipe
 from ...utils.log import logger
+from ..masking_utils import create_causal_masks_and_row_indices
 from ..model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -615,12 +616,19 @@ def forward(
 
         hidden_states = inputs_embeds
 
-        if attention_mask is not None:
-            causal_attention_mask = self._prepare_decoder_attention_mask(
-                attention_mask, hidden_states.shape[:2], kv_seq_len, hidden_states.dtype
-            )
-        else:
-            causal_attention_mask = None
+        mask_kwargs = {
+            "config": self.config,
+            "inputs_embeds": inputs_embeds,
+            "batch_size": bsz,
+            "seq_length": seq_length,
+            "cache_length": kv_seq_len,
+            "attention_mask": attention_mask,
+            "attn_mask_startend_row_indices": attn_mask_startend_row_indices,
+            "prepare_decoder_attention_mask": self._prepare_decoder_attention_mask,
+            "return_mapping": False,
+        }
+
+        causal_attention_mask, attn_mask_startend_row_indices = create_causal_masks_and_row_indices(**mask_kwargs)
 
         if position_ids is None:
             position_ids = paddle.arange(kv_seq_len, seq_length).unsqueeze(0).tile((bsz, 1))
19 changes: 15 additions & 4 deletions paddleformers/transformers/ernie4_5_moe/modeling.py
@@ -47,6 +47,7 @@
 from ...nn.pp_model import GeneralModelForCausalLMPipe
 from ...utils.log import logger
 from ..ernie4_5.modeling import Ernie4_5Attention
+from ..masking_utils import create_causal_masks_and_row_indices
 from ..model_outputs import MoECausalLMOutputWithPast, MoECausalLMOutputWithPastAndMTP
 from ..model_utils import PretrainedModel, register_base_model
 from ..tensor_parallel_utils import model_parallel_dropout
@@ -800,6 +801,7 @@ def forward(
             bsz, seq_length, _ = inputs_embeds.shape
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+        full_seq_length = seq_length
 
         if past_key_values is None:
             past_key_values = tuple([None] * len(self.layers))
@@ -815,10 +817,19 @@
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if attention_mask is not None:
-            attention_mask = self._prepare_decoder_attention_mask(
-                attention_mask, inputs_embeds.shape[:2], kv_seq_len, inputs_embeds.dtype
-            )
+        mask_kwargs = {
+            "config": self.config,
+            "inputs_embeds": inputs_embeds,
+            "batch_size": bsz,
+            "seq_length": full_seq_length,
+            "cache_length": kv_seq_len,
+            "attention_mask": attention_mask,
+            "attn_mask_startend_row_indices": attn_mask_startend_row_indices,
+            "prepare_decoder_attention_mask": self._prepare_decoder_attention_mask,
+            "return_mapping": False,
+        }
+
+        attention_mask, attn_mask_startend_row_indices = create_causal_masks_and_row_indices(**mask_kwargs)
 
         if self.training and self.config.num_nextn_predict_layers > 0:
             inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :]
18 changes: 17 additions & 1 deletion paddleformers/transformers/masking_utils.py
@@ -99,7 +99,10 @@ def create_causal_masks_and_row_indices(
             Start/end row indices mapping for full and sliding attention.
     """
 
-    has_sliding_layers = config.sliding_window is not None and "sliding_attention" in config.layer_types
+    sliding_window_val = getattr(config, "sliding_window", None)
+    layer_types_val = getattr(config, "layer_types", [])
+
+    has_sliding_layers = (sliding_window_val is not None) and ("sliding_attention" in layer_types_val)
 
     if attn_mask_startend_row_indices is not None:
         attention_mask = None
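The switch from direct attribute access to getattr with defaults lets the helper handle configs that never define sliding_window or layer_types. A small self-contained illustration of the difference, using a throwaway SimpleNamespace in place of a real config object:

from types import SimpleNamespace

# A config that predates sliding-window support: neither attribute is defined.
config = SimpleNamespace()

# Old form: raises AttributeError on such configs.
try:
    has_sliding_layers = config.sliding_window is not None and "sliding_attention" in config.layer_types
except AttributeError as exc:
    print(f"direct access fails: {exc}")

# New form: falls back to None / [] and simply reports no sliding layers.
sliding_window_val = getattr(config, "sliding_window", None)
layer_types_val = getattr(config, "layer_types", [])
has_sliding_layers = (sliding_window_val is not None) and ("sliding_attention" in layer_types_val)
print(has_sliding_layers)  # False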
@@ -125,6 +128,19 @@
         )
         return causal_mask, attn_mask_startend_row_indices
 
+    # Enables the efficient built-in causal mode (is_causal=True)
+    # for FA backends (sdpa/flashmask), bypassing manual mask generation.
+    FLASH_BACKENDS = {"sdpa", "flashmask"}
+    attn_impl = getattr(config, "_attn_implementation", "eager")
+    is_flash_backend = attn_impl in FLASH_BACKENDS
+    if attention_mask is None and attn_mask_startend_row_indices is None and is_flash_backend:
+        if return_mapping:
+            causal_mask_mapping = {"full_attention": None, "sliding_attention": None}
+            attn_mask_startend_row_indices_mapping = {"full_attention": None, "sliding_attention": None}
+            return causal_mask_mapping, attn_mask_startend_row_indices_mapping
+        else:
+            return None, None
+
     seq_length_with_past = seq_length + cache_length
     attention_mask = (
         paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
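The early return added above means that when the caller passes neither a dense mask nor flashmask row indices and the configured backend is sdpa or flashmask, the helper hands back None masks so the kernel can use its built-in causal handling. A self-contained stand-in for just that branch (not the library function itself), exercised with SimpleNamespace configs:

from types import SimpleNamespace

FLASH_BACKENDS = {"sdpa", "flashmask"}

def flash_fast_path(config, attention_mask=None, attn_mask_startend_row_indices=None, return_mapping=False):
    # Mirrors only the early-return branch added in masking_utils.py.
    attn_impl = getattr(config, "_attn_implementation", "eager")
    if attention_mask is None and attn_mask_startend_row_indices is None and attn_impl in FLASH_BACKENDS:
        if return_mapping:
            return (
                {"full_attention": None, "sliding_attention": None},
                {"full_attention": None, "sliding_attention": None},
            )
        return None, None
    return "fall through"  # the real helper would go on to build dense masks / row indices

print(flash_fast_path(SimpleNamespace(_attn_implementation="sdpa")))                       # (None, None)
print(flash_fast_path(SimpleNamespace(_attn_implementation="sdpa"), return_mapping=True))  # mappings of None
print(flash_fast_path(SimpleNamespace(_attn_implementation="eager")))                      # "fall through"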