1+ #include < libtorchaudio/cuda_utils.h>
12#include < libtorchaudio/utils.h>
23#include < torch/csrc/stable/library.h>
34#include < torch/headeronly/core/Dispatch_v2.h>
45#include < torch/headeronly/core/ScalarType.h>
5- #include < ATen /cuda/CUDAContext .h>
6+ #include < c10 /cuda/CUDAException .h>
67
78#include < cub/cub.cuh>
89#include < limits.h>
@@ -120,8 +121,9 @@ void forced_align_impl(
120121 const Tensor& targets,
121122 const int64_t blank,
122123 Tensor& paths) {
123- auto defaultStream = at::cuda::getCurrentCUDAStream ();
124- auto cpuDataTranferStream = at::cuda::getStreamFromPool ();
124+ auto device_index = logProbs.get_device_index ();
125+ auto defaultStream = libtorchaudio::cuda::getCurrentCUDAStream (device_index);
126+ auto cpuDataTranferStream = libtorchaudio::cuda::getStreamFromPool (false , device_index);
125127 const scalar_t kNegInfinity = -std::numeric_limits<scalar_t >::infinity ();
126128 using target_t = typename std::
127129 conditional<target_scalar_type == ScalarType::Int, int , int64_t >::type;
@@ -208,12 +210,14 @@ void forced_align_impl(
208210 C10_CUDA_KERNEL_LAUNCH_CHECK ();
209211 ++backPtrBufferLen;
210212 if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1 ) {
211- cpuDataTranferStream.synchronize ();
213+ // cpuDataTranferStream.synchronize();
214+ libtorchaudio::cuda::synchronize (cpuDataTranferStream, device_index);
212215 // GPU -> GPU copy
213216 bufferCopy = torch::stable::clone (backPtrBuffer);
214217 STD_TORCH_CHECK (bufferCopy.is_contiguous (), " unexpected fail, need to implement stable::Tensor::contiguous()" )
215- defaultStream.synchronize ();
216- at::cuda::setCurrentCUDAStream (cpuDataTranferStream);
218+ // defaultStream.synchronize();
219+ libtorchaudio::cuda::synchronize (defaultStream, device_index);
220+ libtorchaudio::cuda::setCurrentCUDAStream (cpuDataTranferStream, device_index);
217221 // Copy ASYNC from GPU to CPU
218222 int64_t offset =
219223 static_cast <int64_t >(t + 1 - backPtrBufferLen) * S * sizeof (int8_t );
@@ -223,11 +227,12 @@ void forced_align_impl(
223227 backPtrBufferLen * S * sizeof (int8_t ),
224228 cudaMemcpyDeviceToHost,
225229 cpuDataTranferStream));
226- at ::cuda::setCurrentCUDAStream (defaultStream);
230+ libtorchaudio ::cuda::setCurrentCUDAStream (defaultStream, device_index );
227231 backPtrBufferLen = 0 ;
228232 }
229233 }
230- cpuDataTranferStream.synchronize ();
234+ // cpuDataTranferStream.synchronize();
235+ libtorchaudio::cuda::synchronize (cpuDataTranferStream, device_index);
231236 auto alphasCpu = torchaudio::stable::cpu (alphas);
232237 auto alphasCpu_a = torchaudio::accessor<scalar_t , 2 >(alphasCpu);
233238 int curIdxOffset = ((T - 1 ) % 2 );
0 commit comments