1313use crate :: error:: { CudaResult , DropResult , ToResult } ;
1414use crate :: event:: Event ;
1515use crate :: function:: { BlockSize , Function , GridSize } ;
16- use crate :: sys:: { self as cuda, CUstream } ;
16+ use crate :: sys:: { self as cuda, cudaError_enum , CUstream } ;
1717use std:: ffi:: c_void;
1818use std:: mem;
1919use std:: panic;
@@ -151,6 +151,9 @@ impl Stream {
151151 ///
152152 /// Callbacks must not make any CUDA API calls.
153153 ///
154+ /// The callback will be passed a `CudaResult<()>` indicating the
155+ /// current state of the device with `Ok(())` denoting normal operation.
156+ ///
154157 /// # Examples
155158 ///
156159 /// ```
@@ -164,22 +167,23 @@ impl Stream {
164167 ///
165168 /// // ... queue up some work on the stream
166169 ///
167- /// stream.add_callback(Box::new(|| {
168- /// println!("Work is done!" );
170+ /// stream.add_callback(Box::new(|status | {
171+ /// println!("Device status is {:?}", status );
169172 /// }));
170173 ///
171174 /// // ... queue up some more work on the stream
172175 /// # Ok(())
173176 /// # }
174177 pub fn add_callback < T > ( & self , callback : Box < T > ) -> CudaResult < ( ) >
175178 where
176- T : FnOnce ( ) + Send ,
179+ T : FnOnce ( CudaResult < ( ) > ) + Send ,
177180 {
178181 unsafe {
179- cuda:: cuLaunchHostFunc (
182+ cuda:: cuStreamAddCallback (
180183 self . inner ,
181184 Some ( callback_wrapper :: < T > ) ,
182185 Box :: into_raw ( callback) as * mut c_void ,
186+ 0 ,
183187 )
184188 . to_result ( )
185189 }
@@ -339,13 +343,16 @@ impl Drop for Stream {
339343 }
340344 }
341345}
342- unsafe extern "C" fn callback_wrapper < T > ( callback : * mut c_void )
343- where
344- T : FnOnce ( ) + Send ,
346+ unsafe extern "C" fn callback_wrapper < T > (
347+ _stream : CUstream ,
348+ status : cudaError_enum ,
349+ callback : * mut c_void ,
350+ ) where
351+ T : FnOnce ( CudaResult < ( ) > ) + Send ,
345352{
346353 // Stop panics from unwinding across the FFI
347354 let _ = panic:: catch_unwind ( || {
348355 let callback: Box < T > = Box :: from_raw ( callback as * mut T ) ;
349- callback ( ) ;
356+ callback ( status . to_result ( ) ) ;
350357 } ) ;
351358}
0 commit comments