4 changes: 3 additions & 1 deletion DeBERTa/deberta/disentangled_attention.py
@@ -174,7 +174,9 @@ def linear(w,b,x):
if self.talking_head:
attention_scores = self.head_logits_proj(attention_scores.permute(0,2,3,1)).permute(0,3,1,2)

attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
# Replace the custom XSoftmax op with a plain softmax.
# Note: unlike XSoftmax, this does not mask out padded positions via attention_mask.
#attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if self.talking_head:
attention_probs = self.head_weights_proj(attention_probs.permute(0,2,3,1)).permute(0,3,1,2)
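Note: XSoftmax is the mask-aware softmax used throughout DeBERTa, so replacing it with a plain softmax also drops the attention_mask handling. If masking still matters on this code path, a plain-PyTorch equivalent might look like the sketch below (masked_softmax is a hypothetical helper written for illustration, not part of this change):

import torch

def masked_softmax(scores, mask, dim=-1):
    # Positions where mask == 0 are pushed to the dtype minimum before the softmax
    # and zeroed afterwards, approximating what XSoftmax.apply(scores, mask, dim) computes.
    rmask = ~mask.bool()
    scores = scores.masked_fill(rmask, torch.finfo(scores.dtype).min)
    probs = torch.softmax(scores, dim=dim)
    return probs.masked_fill(rmask, 0.0)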
6 changes: 5 additions & 1 deletion DeBERTa/deberta/ops.py
@@ -115,7 +115,11 @@ def backward(ctx, grad_output):
else:
return grad_output, None

class StableDropout(torch.nn.Module):
class StableDropout(torch.nn.Dropout):
def __init__(self, drop_prob):
# Forward the configured probability; otherwise torch.nn.Dropout's default p=0.5 is silently used.
super().__init__(drop_prob)

class StableDropout1(torch.nn.Module):
""" Optimized dropout module for stabilizing the training

Args:
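With the simplified class above, StableDropout becomes a thin wrapper around torch.nn.Dropout. A minimal sanity check, assuming the patched DeBERTa/deberta/ops.py is importable, might look like:

import torch
from DeBERTa.deberta.ops import StableDropout

drop = StableDropout(0.1)           # behaves like torch.nn.Dropout(p=0.1)
x = torch.randn(2, 4)
drop.train()
y = drop(x)                         # some elements zeroed, survivors scaled by 1/(1-0.1)
drop.eval()
assert torch.equal(drop(x), x)      # dropout is a no-op at inference time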