@@ -28,7 +28,7 @@ def __init__(self, n_head, n_feat, dropout_rate):
28
28
torch .nn .init .xavier_uniform_ (self .pos_bias_v )
29
29
30
30
def rel_shift (self , x , left_context_size : int = 0 , right_context_size : int = 0 ):
31
- """Compute relative positional encoding. The position should capture both
31
+ """Compute relative positional encoding. The position should capture both
32
32
left and right context.
33
33
34
34
Args:
@@ -88,8 +88,8 @@ def forward(self, query: torch.Tensor,
88
88
q , k , v = self .forward_qkv (query , key , value )
89
89
q = q .transpose (1 , 2 ) # (batch, time1, head, d_k)
90
90
91
- limited_context_attn = (chunk_size > 0
92
- and left_context_size > 0
91
+ limited_context_attn = (chunk_size > 0
92
+ and left_context_size > 0
93
93
and right_context_size > 0 )
94
94
95
95
# NOTE(xcsong):
@@ -121,7 +121,7 @@ def forward(self, query: torch.Tensor,
121
121
# chunking query
122
122
# [B, time1, head, d_k]
123
123
q_size = q .size (1 )
124
- n_frames_pad = (chunk_size - ((q_size - chunk_size ) % chunk_size ))
124
+ n_frames_pad = (chunk_size - ((q_size - chunk_size ) % chunk_size ))
125
125
n_frames_pad = n_frames_pad % chunk_size
126
126
q = torch .nn .functional .pad (q , (0 , 0 , 0 , 0 , 0 , n_frames_pad ))
127
127
# [B, n_chunks, head, d_k, q_size]
@@ -135,12 +135,12 @@ def forward(self, query: torch.Tensor,
135
135
# (batch, head, time1, d_k * 2)
136
136
kv = torch .cat ([k , v ], dim = - 1 )
137
137
kv = torch .nn .functional .pad (
138
- kv ,
138
+ kv ,
139
139
(0 , 0 , left_context_size , n_frames_pad + right_context_size ))
140
140
# [B, head, n_chunks, d_k * 2, l + c + r]
141
141
kv = kv .unfold (
142
- 2 ,
143
- size = left_context_size + chunk_size + right_context_size ,
142
+ 2 ,
143
+ size = left_context_size + chunk_size + right_context_size ,
144
144
step = chunk_size )
145
145
# [B, n_chunks, head, l + c + r, d_k * 2]
146
146
kv = kv .permute (0 , 2 , 1 , 4 , 3 )
@@ -158,12 +158,12 @@ def forward(self, query: torch.Tensor,
158
158
159
159
# Chunking mask for key and value
160
160
mask_kv = torch .nn .functional .pad (
161
- mask ,
161
+ mask ,
162
162
(left_context_size , n_frames_pad + right_context_size ))
163
163
# [B, 1, n_chunks, chunk_size]
164
164
mask_kv = mask_kv .unfold (
165
- - 1 ,
166
- size = left_context_size + chunk_size + right_context_size ,
165
+ - 1 ,
166
+ size = left_context_size + chunk_size + right_context_size ,
167
167
step = chunk_size )
168
168
# [B, * n_chunks, chunk_size]
169
169
mask_kv = mask_kv .reshape (- 1 , mask_kv .size (3 ))
0 commit comments