
[Cpp API Compatibility] Sync c10 CUDA stream state with Paddle's GPUContext stream#78652

Open
youge325 wants to merge 4 commits into PaddlePaddle:develop from youge325:cSyncStream

Conversation

@youge325
Contributor

PR Category

Execute Infrastructure

PR Types

Bug fixes

Description

Fixes the issue reported in PaddlePaddle/FastDeploy#7344.

  • getCurrentCUDAStream() now stays consistent with Paddle's current stream: it reads the current stream from GPUContext instead of relying only on the compat layer's own TLS.
  • setCurrentCUDAStream() updates the compat state and also syncs the current stream back into GPUContext.
  • A TLS-held phi::CUDAStream wrapper is used to write back to GPUContext, avoiding direct mutation of external stream objects.
  • Regression tests are added covering:
    • After setCurrentCUDAStream(), compat/c10 and Paddle see the same stream.
    • When only the Paddle side switches streams, getCurrentCUDAStream() still returns the correct stream.

Does this introduce precision changes

Copilot AI review requested due to automatic review settings April 12, 2026 06:30
@paddle-bot

paddle-bot bot commented Apr 12, 2026

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Apr 12, 2026
Contributor

Copilot AI left a comment


Pull request overview

This PR updates the c10 CUDA stream compatibility layer to keep c10’s “current stream” in sync with Paddle’s phi::GPUContext stream, addressing stream mismatches reported by downstream users (e.g., FastDeploy).

Changes:

  • Make c10::cuda::getCurrentCUDAStream() prefer Paddle’s current GPUContext stream instead of compat-only TLS state.
  • Make c10::cuda::setCurrentCUDAStream() also write the selected stream back into Paddle’s GPUContext.
  • Add regression tests validating bidirectional stream synchronization behavior.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

Changed files:

  • test/cpp/compat/c10_Stream_test.cc: Adds tests to ensure c10 stream APIs observe the same stream as Paddle's GPUContext, including when Paddle changes the stream directly.
  • paddle/phi/api/include/compat/c10/cuda/CUDAStream.h: Updates API documentation for setCurrentCUDAStream() to describe the new alignment intent.
  • paddle/phi/api/include/compat/c10/cuda/CUDAStream.cpp: Implements reading the current stream from Paddle and syncing setCurrentCUDAStream() into the Paddle GPUContext.


Comment on lines +236 to +237
getMutableGPUContext(idx)->SetCUDAStream(current_stream.get(),
/*clear=*/false);

Copilot AI Apr 12, 2026


setCurrentCUDAStream() stores a thread_local phi::CUDAStream (tls_current_streams[idx]) into the global DeviceContextPool GPUContext via SetCUDAStream(..., clear=false). This can (1) leak the previously-owned GPUContext stream object (because clear=false skips deleting an owned stream_ before overwriting the pointer) and (2) leave GPUContext holding a dangling pointer when the calling thread exits and its thread_local unique_ptr is destroyed. Prefer updating GPUContext via SetStream(stream.stream()) (keeps ownership/lifetime inside GPUContext) or otherwise ensure the CUDAStream object you pass to SetCUDAStream has a lifetime at least as long as the GPUContext and that the old owned stream is not leaked.

Suggested change
getMutableGPUContext(idx)->SetCUDAStream(current_stream.get(),
/*clear=*/false);
getMutableGPUContext(idx)->SetStream(stream.stream());

Comment on lines +196 to +197
* Keeps the compat c10 stream state aligned with Paddle's GPUContext so
* Paddle stream guards and c10 callers observe the same current stream.

Copilot AI Apr 12, 2026


The updated docstring removes the previous “per-thread” guarantee, but it still doesn’t clearly state the new semantics. The implementation updates Paddle’s global GPUContext in DeviceContextPool, so this can affect other threads on the same device; please document whether setCurrentCUDAStream is intended to be process-wide/per-device (and any thread-safety expectations) to avoid callers assuming PyTorch-like per-thread behavior.

Suggested change
* Keeps the compat c10 stream state aligned with Paddle's GPUContext so
* Paddle stream guards and c10 callers observe the same current stream.
* This updates Paddle's current stream state through the shared GPUContext
* stored in DeviceContextPool for the target device so Paddle stream guards
* and c10 callers observe the same current stream.
*
* Semantics: this is not a PyTorch-style per-thread current-stream setting.
* The change is effectively process-wide for the given device because other
* threads using the same device may observe the updated current stream.
*
* Thread-safety: callers must not assume thread-local isolation. Concurrent
* calls that change the current stream for the same device can affect one
* another, so external synchronization may be required.

#include <vector>

#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_context.h"

Copilot AI Apr 12, 2026


CUDAStream.cpp now stores/constructs phi::CUDAStream objects (std::unique_ptr<phi::CUDAStream> and std::make_unique<phi::CUDAStream>(...)), but this TU only includes context_pool.h and gpu_context.h, which forward-declare phi::CUDAStream; it does not include the definition from paddle/phi/core/cuda_stream.h. This will fail to compile due to incomplete-type usage (make_unique / the unique_ptr destructor). Add the proper include for the phi::CUDAStream definition.

Suggested change
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/cuda_stream.h"

@youge325
Contributor Author

/re-run all-failed

@lizexu123
Contributor

I'll give it a try locally. Thanks for the work!

@SigureMo
Member

Clearly explain why it is acceptable to drop the per-thread semantics here, and how that relates to the expectations of torch / c10;

Given that we currently have essentially no multi-threaded usage scenarios, only multi-process ones, I don't consider this a blocker.

Having the Paddle / PyTorch compat layers each maintain their own state carried plenty of risk; this was already discussed in PFCCLab/DeepEP#11 (review), @ShigureNyako, did you forget?

SigureMo previously approved these changes Apr 12, 2026
Member

@SigureMo SigureMo left a comment


LGTMeow 🐾

@codecov-commenter

codecov-commenter commented Apr 12, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@49365ae). Learn more about missing BASE report.

Additional details and impacted files
@@             Coverage Diff             @@
##             develop    #78652   +/-   ##
===========================================
  Coverage           ?   100.00%           
===========================================
  Files              ?         1           
  Lines              ?        10           
  Branches           ?         0           
===========================================
  Hits               ?        10           
  Misses             ?         0           
  Partials           ?         0           

☔ View full report in Codecov by Sentry.

@lizexu123
Contributor

Multi-threading still doesn't seem to work; could you write a unit test to verify it?

@SigureMo
Member

Multi-threading still doesn't seem to work; could you write a unit test to verify it?

Please provide a minimal reproduction.

@youge325
Contributor Author

Multi-threading still doesn't seem to work; could you write a unit test to verify it?

I wrote a unit test locally and it indeed fails: setCurrentCUDAStream writes the stream back to the global GPUContext, and getCurrentCUDAStream reads from that same global object, so thread isolation is impossible.

#include <future>
#include <mutex>
#include <thread>

TEST(CUDAStreamTest, CurrentStreamIsThreadLocal) {
  if (!at::cuda::is_available()) {
    return;
  }

  auto main_original_stream = c10::cuda::getCurrentCUDAStream();
  auto device_index = main_original_stream.device_index();
  auto thread_stream_a = c10::cuda::getStreamFromPool(
      /*isHighPriority=*/false, device_index);
  auto thread_stream_b = c10::cuda::getStreamFromPool(
      /*isHighPriority=*/true, device_index);
  ASSERT_NE(thread_stream_a, thread_stream_b);

  std::promise<void> thread_a_set_promise;
  std::shared_future<void> thread_a_set_future =
      thread_a_set_promise.get_future().share();
  std::promise<void> thread_b_set_promise;
  std::shared_future<void> thread_b_set_future =
      thread_b_set_promise.get_future().share();

  c10::cuda::CUDAStream observed_stream_a = main_original_stream;
  c10::cuda::CUDAStream observed_stream_b = main_original_stream;
  c10::cuda::CUDAStream main_observed_stream = main_original_stream;

  std::mutex error_mutex;
  std::exception_ptr thread_error = nullptr;
  auto record_thread_error = [&](std::exception_ptr e) {
    std::lock_guard<std::mutex> guard(error_mutex);
    if (!thread_error) {
      thread_error = e;
    }
  };

  std::thread thread_a([&] {
    try {
      c10::cuda::setCurrentCUDAStream(thread_stream_a);
      thread_a_set_promise.set_value();
      thread_b_set_future.wait();
      observed_stream_a = c10::cuda::getCurrentCUDAStream(device_index);
    } catch (...) {
      try {
        thread_a_set_promise.set_value();
      } catch (...) {
      }
      try {
        thread_b_set_promise.set_value();
      } catch (...) {
      }
      record_thread_error(std::current_exception());
    }
  });

  std::thread thread_b([&] {
    try {
      thread_a_set_future.wait();
      c10::cuda::setCurrentCUDAStream(thread_stream_b);
      observed_stream_b = c10::cuda::getCurrentCUDAStream(device_index);
      thread_b_set_promise.set_value();
    } catch (...) {
      try {
        thread_b_set_promise.set_value();
      } catch (...) {
      }
      record_thread_error(std::current_exception());
    }
  });

  thread_b_set_future.wait();
  main_observed_stream = c10::cuda::getCurrentCUDAStream(device_index);
  thread_a.join();
  thread_b.join();
  c10::cuda::setCurrentCUDAStream(main_original_stream);

  if (thread_error) {
    try {
      std::rethrow_exception(thread_error);
    } catch (const std::exception& e) {
      FAIL() << "Unexpected exception in worker thread: " << e.what();
    } catch (...) {
      FAIL() << "Unexpected unknown exception in worker thread.";
    }
  }

  EXPECT_EQ(observed_stream_a, thread_stream_a);
  EXPECT_EQ(observed_stream_b, thread_stream_b);
  EXPECT_EQ(main_observed_stream, main_original_stream);
}
[ RUN      ] CUDAStreamTest.CurrentStreamIsThreadLocal
/home/may/Paddle/test/cpp/compat/c10_Stream_test.cc:380: Failure
Expected equality of these values:
  observed_stream_a
    Which is: stream 98955431905504 on device cuda:0
  thread_stream_a
    Which is: stream 98955431906080 on device cuda:0
/home/may/Paddle/test/cpp/compat/c10_Stream_test.cc:382: Failure
Expected equality of these values:
  main_observed_stream
    Which is: stream 98955431905504 on device cuda:0
  main_original_stream
    Which is: stream 98955431858256 on device cuda:0
[  FAILED  ] CUDAStreamTest.CurrentStreamIsThreadLocal (0 ms)

Let me figure out the best way to fix this.

@youge325
Contributor Author

Are there still any multi-threaded use cases? None of the compat layer's stream-retrieval methods use TLS anymore, so if it's not needed I'll remove it.

@lizexu123
Contributor

Are there still any multi-threaded use cases? None of the compat layer's stream-retrieval methods use TLS anymore, so if it's not needed I'll remove it.

Let me confirm; you can go ahead with the fix in the meantime.

@SigureMo
Member

Are there still any multi-threaded use cases? None of the compat layer's stream-retrieval methods use TLS anymore, so if it's not needed I'll remove it.

His problem isn't a multi-threaded scenario but a multi-process one; there's no need to consider multi-threading for now, so it can be cleaned up.

@youge325
Contributor Author

/re-run all-failed

@youge325
Contributor Author

/re-run all-failed


Labels

contributor External developers
