When a unet is passed as an argument to pre_post_transformer_enc_dec, the modality_shape parsed out by maybe_transition_to_modality_decoding should be multiplied by 2:
modality_shape = tuple(map(lambda x: x * 2, modality_shape))
The shape flow:
modality_shape[7,7] -> noise[7,7,4] -> denoised[7,7,4] -> latent modality_tensor[7,7,4] -> model modality_tensor[4,4,dim]
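
The proposed fix can be sketched in isolation as follows. This is only an illustration, assuming the unet halves each spatial dimension on the way in; `double_modality_shape` is a hypothetical helper name, not part of the library, and the `(4, 4)` input is taken from the last step of the shape flow above.

```python
def double_modality_shape(modality_shape):
    # Hypothetical helper: undo an assumed 2x spatial downsampling
    # by doubling every dimension of the parsed shape.
    return tuple(dim * 2 for dim in modality_shape)


# Shape as parsed on the model side (after the assumed downsampling).
parsed_shape = (4, 4)

restored_shape = double_modality_shape(parsed_shape)
print(restored_shape)  # (8, 8)
```

Note that doubling (4, 4) gives (8, 8) rather than the (7, 7) at the start of the flow, so for odd sizes the doubled shape may not match the latent shape exactly.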