Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions paddle/phi/api/include/compat/c10/cuda/CUDAStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <mutex>
#include <vector>

#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
Copy link

Copilot AI Apr 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CUDAStream.cpp now stores/constructs phi::CUDAStream objects (std::unique_ptr<phi::CUDAStream> and std::make_unique<phi::CUDAStream>(...)), but this TU only includes context_pool.h and gpu_context.h which forward-declare phi::CUDAStream; it does not include the definition from paddle/phi/core/cuda_stream.h. This will fail to compile due to incomplete type usage (make_unique / unique_ptr destructor). Add the proper include for the phi::CUDAStream definition.

Suggested change
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/cuda_stream.h"

Copilot uses AI. Check for mistakes.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/phi/backends/gpu/gpu_info.h"
#endif
Expand Down Expand Up @@ -49,13 +51,6 @@ struct DevicePools {

std::vector<std::unique_ptr<DevicePools>> g_pools;

#ifdef PADDLE_WITH_HIP
thread_local std::vector<hipStream_t> tls_current_streams;
#else
thread_local std::vector<cudaStream_t> tls_current_streams;
#endif
thread_local bool tls_streams_initialized = false;

void initGlobalState() {
std::call_once(g_init_once, []() {
g_num_gpus =
Expand Down Expand Up @@ -104,12 +99,25 @@ inline void check_gpu(c10::DeviceIndex device_index) {
")");
}

inline void initTLSCurrentStreams() {
if (!tls_streams_initialized) {
tls_current_streams.resize(g_num_gpus, nullptr);
tls_streams_initialized = true;
}
// Fetches the mutable phi::GPUContext for `device_index` from Paddle's
// global DeviceContextPool.
//
// NOTE(review): the pool hands out the shared per-device context, so
// SetStream() calls made through the returned pointer are not thread-local —
// confirm callers expect process-wide current-stream semantics.
inline phi::GPUContext* getMutableGPUContext(c10::DeviceIndex device_index) {
  return static_cast<phi::GPUContext*>(
      paddle::experimental::DeviceContextPool::Instance().GetMutable(
          phi::GPUPlace(device_index)));
}

// Returns the raw driver stream handle that Paddle currently considers
// "current" for `device_index`, or nullptr when Paddle has no current
// stream recorded for that device (the caller falls back to the default
// stream in that case). The two branches are identical except for the
// handle type: hipStream_t on ROCm builds, cudaStream_t otherwise.
#ifdef PADDLE_WITH_HIP
inline hipStream_t getPaddleCurrentStream(c10::DeviceIndex device_index) {
  auto* current_stream =
      paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index));
  return current_stream == nullptr ? nullptr : current_stream->raw_stream();
}
#else
inline cudaStream_t getPaddleCurrentStream(c10::DeviceIndex device_index) {
  auto* current_stream =
      paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index));
  return current_stream == nullptr ? nullptr : current_stream->raw_stream();
}
#endif

#endif // defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

Expand Down Expand Up @@ -192,12 +200,7 @@ CUDAStream getCurrentCUDAStream(c10::DeviceIndex device_index) {
static_cast<c10::DeviceIndex>(phi::backends::gpu::GetCurrentDeviceId());
}
check_gpu(device_index);
initTLSCurrentStreams();
#ifdef PADDLE_WITH_HIP
hipStream_t raw = tls_current_streams[device_index];
#else
cudaStream_t raw = tls_current_streams[device_index];
#endif
auto raw = getPaddleCurrentStream(device_index);
if (raw == nullptr) {
return getDefaultCUDAStream(device_index);
}
Expand All @@ -212,8 +215,7 @@ void setCurrentCUDAStream(CUDAStream stream) {
initGlobalState();
c10::DeviceIndex idx = stream.unwrap().device_index();
check_gpu(idx);
initTLSCurrentStreams();
tls_current_streams[idx] = stream.stream();
getMutableGPUContext(idx)->SetStream(stream.stream());
#else
(void)stream;
#endif
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/api/include/compat/c10/cuda/CUDAStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ CUDAStream getStreamFromExternal(cudaStream_t ext_stream,
#endif

/**
* Set the current CUDA stream for the device of the given stream in the
* calling thread.
* Set the current CUDA stream for the device of the given stream.
*
* Implements per-thread, per-device current stream semantics.
* Keeps the compat c10 stream state aligned with Paddle's GPUContext so
* Paddle stream guards and c10 callers observe the same current stream.
Comment on lines +196 to +197
Copy link

Copilot AI Apr 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The updated docstring removes the previous “per-thread” guarantee, but it still doesn’t clearly state the new semantics. The implementation updates Paddle’s global GPUContext in DeviceContextPool, so this can affect other threads on the same device; please document whether setCurrentCUDAStream is intended to be process-wide/per-device (and any thread-safety expectations) to avoid callers assuming PyTorch-like per-thread behavior.

Suggested change
* Keeps the compat c10 stream state aligned with Paddle's GPUContext so
* Paddle stream guards and c10 callers observe the same current stream.
* This updates Paddle's current stream state through the shared GPUContext
* stored in DeviceContextPool for the target device so Paddle stream guards
* and c10 callers observe the same current stream.
*
* Semantics: this is not a PyTorch-style per-thread current-stream setting.
* The change is effectively process-wide for the given device because other
* threads using the same device may observe the updated current stream.
*
* Thread-safety: callers must not assume thread-local isolation. Concurrent
* calls that change the current stream for the same device can affect one
* another, so external synchronization may be required.

Copilot uses AI. Check for mistakes.
*/
void setCurrentCUDAStream(CUDAStream stream);

Expand Down
35 changes: 12 additions & 23 deletions test/cpp/compat/c10_Stream_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <thread>

#include "gtest/gtest.h"
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_context.h"

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
namespace {
Expand Down Expand Up @@ -167,25 +169,6 @@ TEST(StreamTest, QueryCudaStreamNotReadyReturnsFalse) {
EXPECT_NO_THROW(s.synchronize());
}

// NOTE(review): this test is deleted by the PR under review. It exercised
// query() on a c10 stream wrapping an externally created raw stream that
// had already been destroyed, expecting an exception; presumably that
// guarantee no longer holds once current-stream state is delegated to
// Paddle — confirm the removal rationale with the PR author.
TEST(StreamTest, QueryCudaStreamInvalidHandleThrows) {
  // Skip silently on builds without a usable GPU device.
  if (!at::cuda::is_available()) {
    return;
  }

  auto device_index = c10::cuda::getCurrentCUDAStream().device_index();
  // Raw driver handle; the type differs between ROCm and CUDA builds.
#ifdef PADDLE_WITH_HIP
  hipStream_t raw_stream = nullptr;
#else
  cudaStream_t raw_stream = nullptr;
#endif
  ASSERT_NO_THROW(CreateRawStream(&raw_stream));

  // Wrap the raw handle, then destroy the underlying stream so the wrapper
  // is left holding a dangling handle.
  auto cuda_stream = c10::cuda::getStreamFromExternal(raw_stream, device_index);
  ASSERT_NO_THROW(DestroyRawStream(raw_stream));

  // Querying through the dangling handle is expected to surface an error.
  EXPECT_THROW(cuda_stream.query(), std::exception);
  ClearLastStreamError();
}
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP

// ==================== synchronize ====================
Expand Down Expand Up @@ -255,30 +238,36 @@ TEST(CUDAStreamTest, GetStreamFromPoolBoolOverloadPreservesHighPriority) {
EXPECT_NE(high_priority, low_priority);
}

// After setCurrentCUDAStream redirects the per-thread current stream,
// After setCurrentCUDAStream redirects the current stream,
// getDefaultCUDAStream must still return the null stream.
// Verifies that redirecting the current stream via setCurrentCUDAStream
// does not change what getDefaultCUDAStream reports, and that the c10
// current stream stays in sync with Paddle's GetCurrentCUDAStream.
TEST(CUDAStreamTest, DefaultStreamUnaffectedBySetCurrentCUDAStream) {
  // Skip silently on builds without a usable GPU device.
  if (!at::cuda::is_available()) {
    return;
  }
  // Snapshot the current stream before we touch it so we can
  // restore it afterward and avoid polluting subsequent tests.
  auto original_stream = c10::cuda::getCurrentCUDAStream();

  // Obtain a non-default stream from the pool.
  auto pool_stream = c10::cuda::getStreamFromPool(/*isHighPriority=*/false);

  // Redirect the current stream.
  c10::cuda::setCurrentCUDAStream(pool_stream);

  auto default_stream = c10::cuda::getDefaultCUDAStream();
  auto current_stream = c10::cuda::getCurrentCUDAStream();
  auto place = phi::GPUPlace(current_stream.device_index());

  // Default stream is still null; current stream has changed. The last
  // check asserts Paddle's notion of the current stream matches c10's.
  EXPECT_EQ(default_stream.id(), static_cast<c10::StreamId>(0));
  EXPECT_NE(default_stream, current_stream);
  EXPECT_EQ(paddle::GetCurrentCUDAStream(place)->raw_stream(),
            current_stream.stream());

  // Restore the original current stream, and confirm Paddle observed the
  // restoration as well.
  c10::cuda::setCurrentCUDAStream(original_stream);
  EXPECT_EQ(paddle::GetCurrentCUDAStream(place)->raw_stream(),
            original_stream.stream());
}

#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
Loading