@@ -67,6 +67,15 @@ def has_int_squareroot(num):

 # tensor helpers

+def pad_or_curtail_to_length(t, length):
+    if t.shape[-1] == length:
+        return t
+
+    if t.shape[-1] > length:
+        return t[..., :length]
+
+    return F.pad(t, (0, length - t.shape[-1]))
+
 def prob_mask_like(shape, prob, device):
     if prob == 1:
         return torch.ones(shape, device = device, dtype = torch.bool)
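
A quick usage sketch of the new `pad_or_curtail_to_length` helper (the function is restated here so the snippet runs standalone, and the tensor shapes are made up for illustration): it truncates or right-pads the last dimension to a target length.

    import torch
    import torch.nn.functional as F

    def pad_or_curtail_to_length(t, length):
        # same logic as the helper added above: match the last dim to `length`
        if t.shape[-1] == length:
            return t
        if t.shape[-1] > length:
            return t[..., :length]
        return F.pad(t, (0, length - t.shape[-1]))  # zero-pad on the right

    cond = torch.randn(2, 512, 100)                                      # (batch, channels, frames)
    assert pad_or_curtail_to_length(cond, 80).shape  == (2, 512, 80)     # curtailed
    assert pad_or_curtail_to_length(cond, 120).shape == (2, 512, 120)    # zero-padded
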
@@ -834,6 +843,7 @@ def __init__(
         )

         # prompt condition
+
         self.cond_drop_prob = cond_drop_prob # for classifier free guidance
         self.condition_on_prompt = condition_on_prompt
         self.to_prompt_cond = None
@@ -861,6 +871,15 @@ def __init__(
             use_flash_attn = use_flash_attn
         )

+        # aligned conditioning from aligner + duration module
+
+        self.null_cond = None
+        self.cond_to_model_dim = None
+
+        if self.condition_on_prompt:
+            self.cond_to_model_dim = nn.Conv1d(dim_prompt, dim, 1)
+            self.null_cond = nn.Parameter(torch.zeros(dim, 1))
+
         # conditioning includes time and optionally prompt

         dim_cond_mult = dim_cond_mult * (2 if condition_on_prompt else 1)
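
A rough sketch of what the two new attributes are for (the sizes `dim_prompt = 512` and `dim = 256` are placeholders, not values from this diff): `cond_to_model_dim` is a 1x1 convolution that projects the aligned condition into the model dimension, and `null_cond` is a learned vector that stands in for the condition when it is dropped for classifier-free guidance.

    import torch
    from torch import nn

    dim_prompt, dim = 512, 256                            # placeholder sizes

    cond_to_model_dim = nn.Conv1d(dim_prompt, dim, 1)     # per-frame linear projection
    null_cond = nn.Parameter(torch.zeros(dim, 1))         # learned "no condition" vector

    cond = torch.randn(2, dim_prompt, 100)                # (batch, dim_prompt, frames)
    cond = cond_to_model_dim(cond)                        # -> (batch, dim, frames)

    # when a batch element is dropped, its whole (dim, frames) slice
    # is replaced by null_cond, which broadcasts across the frame axis
    drop = torch.tensor([True, False]).view(2, 1, 1)
    cond = torch.where(drop, null_cond, cond)
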
@@ -913,23 +932,27 @@ def forward(
         times,
         prompt = None,
         prompt_mask = None,
-        cond = None,
+        cond = None,
         cond_drop_prob = None
     ):
         b = x.shape[0]
         cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

-        drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
+        # prepare prompt condition
+        # probably should be removed going forward

         t = self.to_time_cond(times)
         c = None

         if exists(self.to_prompt_cond):
             assert exists(prompt)
+
+            prompt_cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
+
             prompt_cond = self.to_prompt_cond(prompt)

             prompt_cond = torch.where(
-                rearrange(drop_mask, 'b -> b 1'),
+                rearrange(prompt_cond_drop_mask, 'b -> b 1'),
                 self.null_prompt_cond,
                 prompt_cond,
             )
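
The behavioral change in this hunk: instead of one shared `drop_mask`, the prompt condition now draws its own drop mask (and, further down, the aligned condition draws another). A minimal sketch using a stand-in for `prob_mask_like` (simplified from the helper near the top of the file, not its exact implementation):

    import torch

    def prob_mask_like(shape, prob, device):
        # stand-in: True with probability `prob`, independently per element
        return torch.rand(shape, device = device) < prob

    b, cond_drop_prob, device = 4, 0.25, 'cpu'

    # two independent masks, so the prompt conditioning and the aligned
    # conditioning can be dropped separately for classifier-free guidance
    prompt_cond_drop_mask = prob_mask_like((b,), cond_drop_prob, device)
    cond_drop_mask        = prob_mask_like((b,), cond_drop_prob, device)
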
@@ -939,12 +962,37 @@ def forward(
             resampled_prompt_tokens = self.perceiver_resampler(prompt, mask = prompt_mask)

             c = torch.where(
-                rearrange(drop_mask, 'b -> b 1 1'),
+                rearrange(prompt_cond_drop_mask, 'b -> b 1 1'),
                 self.null_prompt_tokens,
                 resampled_prompt_tokens
             )

+        # rearrange to channel first
+
         x = rearrange(x, 'b n d -> b d n')
+
+        # sum aligned condition to input sequence
+
+        if exists(self.cond_to_model_dim):
+            assert exists(cond)
+            cond = self.cond_to_model_dim(cond)
+
+            cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
+
+            cond = torch.where(
+                rearrange(cond_drop_mask, 'b -> b 1 1'),
+                self.null_cond,
+                cond
+            )
+
+            # for now, conform the condition to the length of the latent features
+
+            cond = pad_or_curtail_to_length(cond, x.shape[-1])
+
+            x = x + cond
+
+        # main wavenet body
+
         x = self.wavenet(x, t)
         x = rearrange(x, 'b d n -> b n d')

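
Putting the new forward-pass pieces together, a condensed standalone sketch of the aligned-conditioning path (shapes and the drop probability are made up; the real module uses `prob_mask_like`, `pad_or_curtail_to_length`, and its registered `null_cond`): project the condition to the model dimension, randomly null it out per sample, conform its length to the latent sequence, and sum it into the input before the wavenet body.

    import torch
    import torch.nn.functional as F
    from torch import nn

    b, n, dim, dim_prompt = 2, 100, 256, 512                 # placeholder sizes

    x    = torch.randn(b, n, dim)                            # latents, (batch, frames, dim)
    cond = torch.randn(b, dim_prompt, 120)                   # aligned condition, different length

    cond_to_model_dim = nn.Conv1d(dim_prompt, dim, 1)
    null_cond = torch.zeros(dim, 1)

    x = x.transpose(1, 2)                                    # channel first: (b, dim, n)
    cond = cond_to_model_dim(cond)                           # -> (b, dim, 120)

    drop = torch.rand(b) < 0.25                              # per-sample condition drop
    cond = torch.where(drop.view(b, 1, 1), null_cond, cond)

    # conform the condition length to the latent length, then sum
    cond = cond[..., :n] if cond.shape[-1] >= n else F.pad(cond, (0, n - cond.shape[-1]))
    x = x + cond                                             # conditioned input to the wavenet body
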
@@ -1527,6 +1575,7 @@ def forward(
         duration_pred, pitch_pred = self.duration_pitch(phoneme_enc, prompt_enc)

         pitch = average_over_durations(pitch, aln_hard)
+
         cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mask, 'b n c -> b 1 n c'), pitch)

         # pitch and duration loss
@@ -1536,6 +1585,7 @@ def forward(
         pitch = rearrange(pitch, 'b 1 d -> b d')
         pitch_loss = F.l1_loss(pitch, pitch_pred)
         align_loss = self.aligner_loss(aln_log, text_lens, mel_lens)
+
         # weigh the losses

         aux_loss = (duration_loss * self.duration_loss_weight) \