@@ -45,15 +45,16 @@ at::Tensor BatchPrefillWithKVCachePlan(
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
     at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
     int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal, int64_t cuda_stream) {
+    int64_t head_dim_vo, bool causal) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
   size_t int_workspace_size_in_bytes =
       int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
 
   PrefillPlanInfo plan_info;
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
+  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
   cudaError_t status = PrefillPlan<IdType>(
       float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
       int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
@@ -72,8 +73,7 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
                                       at::Tensor q, at::Tensor k, at::Tensor v,
                                       at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
                                       std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
-                                      int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS,
-                                      int64_t cuda_stream) {
+                                      int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS) {
   PrefillPlanInfo plan_info;
   plan_info.FromVector(tensor_to_vec(plan_info_vec));
   QKVLayout kv_layout = static_cast<QKVLayout>(layout);
@@ -109,7 +109,8 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
   auto q_scalar_type = q.scalar_type();
   auto kv_scalar_type = k.scalar_type();
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
+  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
 
   DISPATCH_context(
       DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
@@ -193,12 +194,14 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
   });
 }
 
-void BatchPrefillWithPagedKVCacheRun(
-    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
-    at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
-    at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
-    at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
-    int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
+void BatchPrefillWithPagedKVCacheRun(at::Tensor float_workspace_buffer,
+                                     at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
+                                     at::Tensor q, at::Tensor paged_k_cache,
+                                     at::Tensor paged_v_cache, at::Tensor qo_indptr,
+                                     at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices,
+                                     at::Tensor paged_kv_last_page_len, at::Tensor o,
+                                     std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
+                                     int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS) {
   PrefillPlanInfo plan_info;
   plan_info.FromVector(tensor_to_vec(plan_info_vec));
   QKVLayout kv_layout = static_cast<QKVLayout>(layout);
@@ -239,7 +242,8 @@ void BatchPrefillWithPagedKVCacheRun(
   TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
   kv_cache_strides = k_strides.data();
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
+  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
 
   DISPATCH_context(
       DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
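
For reference, the stream-handling pattern this diff switches to (dropping the int64_t cuda_stream argument and instead deriving the device and stream from the current PyTorch CUDA context) can be sketched in isolation as follows. This is a minimal standalone sketch, not part of the patched file; the function name RunOnCurrentStream and its single workspace argument are made up for illustration.

#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

// Hypothetical helper illustrating the pattern used throughout the diff.
void RunOnCurrentStream(at::Tensor workspace) {
  // Make the tensor's device current for the scope of this call
  // (a no-op when it already is).
  const c10::cuda::OptionalCUDAGuard device_guard(workspace.device());
  // Use the stream PyTorch is currently enqueueing work on; the returned
  // c10::cuda::CUDAStream converts implicitly to cudaStream_t.
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  // ... launch kernels on `stream` here ...
  (void)stream;
}

With this pattern the caller no longer passes a raw stream handle across the Python/C++ boundary; the extension picks up whatever stream PyTorch has active for the tensor's device (the one reported by torch.cuda.current_stream()).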