11#pragma once
22
3+ #include < libtorchaudio/shim_temporary.h>
34#include < torch/csrc/stable/c/shim.h>
45#include < torch/csrc/stable/device.h>
56
@@ -20,24 +21,24 @@ inline cudaStream_t getCurrentCUDAStream(
2021inline void setCurrentCUDAStream (
2122 cudaStream_t stream,
2223 torch::stable::DeviceIndex device_index = -1 ) {
23- TORCH_ERROR_CODE_CHECK (
24- torch_set_current_cuda_stream ( static_cast <void *>(stream), device_index));
24+ TORCH_ERROR_CODE_CHECK (tmp_torch_set_current_cuda_stream (
25+ static_cast <void *>(stream), device_index));
2526}
2627
2728inline cudaStream_t getStreamFromPool (
2829 const bool isHighPriority = false ,
2930 torch::stable::DeviceIndex device_index = -1 ) {
3031 void * stream_ptr = nullptr ;
31- TORCH_ERROR_CODE_CHECK (torch_get_cuda_stream_from_pool (
32+ TORCH_ERROR_CODE_CHECK (tmp_torch_get_cuda_stream_from_pool (
3233 isHighPriority, device_index, &stream_ptr));
3334 return static_cast <cudaStream_t>(stream_ptr);
3435}
3536
3637inline void synchronize (
3738 cudaStream_t stream,
3839 torch::stable::DeviceIndex device_index = -1 ) {
39- TORCH_ERROR_CODE_CHECK (
40- torch_cuda_stream_synchronize ( static_cast <void *>(stream), device_index));
40+ TORCH_ERROR_CODE_CHECK (tmp_torch_cuda_stream_synchronize (
41+ static_cast <void *>(stream), device_index));
4142}
4243
4344} // namespace libtorchaudio::cuda
0 commit comments