@@ -9,8 +9,9 @@ use crate::sys as cuda;
99use bytemuck:: { Pod , Zeroable } ;
1010use std:: fmt:: { self , Debug , Formatter } ;
1111use std:: marker:: PhantomData ;
12- use std:: mem:: { self , size_of} ;
13- use std:: ops:: { Range , RangeFrom , RangeFull , RangeInclusive , RangeTo , RangeToInclusive } ;
12+ use std:: ops:: {
13+ Index , IndexMut , Range , RangeFrom , RangeFull , RangeInclusive , RangeTo , RangeToInclusive ,
14+ } ;
1415use std:: os:: raw:: c_void;
1516use std:: ptr:: { slice_from_raw_parts, slice_from_raw_parts_mut} ;
1617
@@ -236,16 +237,13 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
236237 /// In total it will set `sizeof<T> * len` values of `value` contiguously.
237238 #[ cfg_attr( docsrs, doc( cfg( feature = "bytemuck" ) ) ) ]
238239 pub fn set_8 ( & mut self , value : u8 ) -> CudaResult < ( ) > {
240+ if self . size_in_bytes ( ) == 0 {
241+ return Ok ( ( ) ) ;
242+ }
243+
239244 // SAFETY: We know T can hold any value because it is `Pod`, and
240245 // sub-byte alignment isn't a thing so we know the alignment is right.
241- unsafe {
242- cuda:: cuMemsetD8_v2 (
243- self . as_device_ptr ( ) . as_raw ( ) ,
244- value,
245- size_of :: < T > ( ) * self . len ( ) ,
246- )
247- . to_result ( )
248- }
246+ unsafe { cuda:: cuMemsetD8_v2 ( self . as_raw_ptr ( ) , value, self . size_in_bytes ( ) ) . to_result ( ) }
249247 }
250248
251249 /// Sets the memory range of this buffer to contiguous `8-bit` values of `value` asynchronously.
@@ -258,10 +256,14 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
258256 /// Therefore you should not read/write from/to the memory range until the operation is complete.
259257 #[ cfg_attr( docsrs, doc( cfg( feature = "bytemuck" ) ) ) ]
260258 pub unsafe fn set_8_async ( & mut self , value : u8 , stream : & Stream ) -> CudaResult < ( ) > {
259+ if self . size_in_bytes ( ) == 0 {
260+ return Ok ( ( ) ) ;
261+ }
262+
261263 cuda:: cuMemsetD8Async (
262- self . as_device_ptr ( ) . as_raw ( ) ,
264+ self . as_raw_ptr ( ) ,
263265 value,
264- size_of :: < T > ( ) * self . len ( ) ,
266+ self . size_in_bytes ( ) ,
265267 stream. as_inner ( ) ,
266268 )
267269 . to_result ( )
@@ -279,20 +281,18 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
279281 #[ track_caller]
280282 #[ cfg_attr( docsrs, doc( cfg( feature = "bytemuck" ) ) ) ]
281283 pub fn set_16 ( & mut self , value : u16 ) -> CudaResult < ( ) > {
282- let data_len = size_of :: < T > ( ) * self . len ( ) ;
284+ let data_len = self . size_in_bytes ( ) ;
283285 assert_eq ! (
284286 data_len % 2 ,
285287 0 ,
286288 "Buffer length is not a multiple of 2 bytes!"
287289 ) ;
288290 assert_eq ! (
289- self . as_device_ptr ( ) . as_raw ( ) % 2 ,
291+ self . as_raw_ptr ( ) % 2 ,
290292 0 ,
291293 "Buffer pointer is not aligned to at least 2 bytes!"
292294 ) ;
293- unsafe {
294- cuda:: cuMemsetD16_v2 ( self . as_device_ptr ( ) . as_raw ( ) , value, data_len / 2 ) . to_result ( )
295- }
295+ unsafe { cuda:: cuMemsetD16_v2 ( self . as_raw_ptr ( ) , value, data_len / 2 ) . to_result ( ) }
296296 }
297297
298298 /// Sets the memory range of this buffer to contiguous `16-bit` values of `value` asynchronously.
@@ -312,24 +312,19 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
312312 #[ track_caller]
313313 #[ cfg_attr( docsrs, doc( cfg( feature = "bytemuck" ) ) ) ]
314314 pub unsafe fn set_16_async ( & mut self , value : u16 , stream : & Stream ) -> CudaResult < ( ) > {
315- let data_len = size_of :: < T > ( ) * self . len ( ) ;
315+ let data_len = self . size_in_bytes ( ) ;
316316 assert_eq ! (
317317 data_len % 2 ,
318318 0 ,
319319 "Buffer length is not a multiple of 2 bytes!"
320320 ) ;
321321 assert_eq ! (
322- self . as_device_ptr ( ) . as_raw ( ) % 2 ,
322+ self . as_raw_ptr ( ) % 2 ,
323323 0 ,
324324 "Buffer pointer is not aligned to at least 2 bytes!"
325325 ) ;
326- cuda:: cuMemsetD16Async (
327- self . as_device_ptr ( ) . as_raw ( ) ,
328- value,
329- data_len / 2 ,
330- stream. as_inner ( ) ,
331- )
332- . to_result ( )
326+ cuda:: cuMemsetD16Async ( self . as_raw_ptr ( ) , value, data_len / 2 , stream. as_inner ( ) )
327+ . to_result ( )
333328 }
334329
335330 /// Sets the memory range of this buffer to contiguous `32-bit` values of `value`.
@@ -344,20 +339,18 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
344339 #[ track_caller]
345340 #[ cfg_attr( docsrs, doc( cfg( feature = "bytemuck" ) ) ) ]
346341 pub fn set_32 ( & mut self , value : u32 ) -> CudaResult < ( ) > {
347- let data_len = size_of :: < T > ( ) * self . len ( ) ;
342+ let data_len = self . size_in_bytes ( ) ;
348343 assert_eq ! (
349344 data_len % 4 ,
350345 0 ,
351346 "Buffer length is not a multiple of 4 bytes!"
352347 ) ;
353348 assert_eq ! (
354- self . as_device_ptr ( ) . as_raw ( ) % 4 ,
349+ self . as_raw_ptr ( ) % 4 ,
355350 0 ,
356351 "Buffer pointer is not aligned to at least 4 bytes!"
357352 ) ;
358- unsafe {
359- cuda:: cuMemsetD32_v2 ( self . as_device_ptr ( ) . as_raw ( ) , value, data_len / 4 ) . to_result ( )
360- }
353+ unsafe { cuda:: cuMemsetD32_v2 ( self . as_raw_ptr ( ) , value, data_len / 4 ) . to_result ( ) }
361354 }
362355
363356 /// Sets the memory range of this buffer to contiguous `32-bit` values of `value` asynchronously.
@@ -377,24 +370,19 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
377370 #[ track_caller]
378371 #[ cfg_attr( docsrs, doc( cfg( feature = "bytemuck" ) ) ) ]
379372 pub unsafe fn set_32_async ( & mut self , value : u32 , stream : & Stream ) -> CudaResult < ( ) > {
380- let data_len = size_of :: < T > ( ) * self . len ( ) ;
373+ let data_len = self . size_in_bytes ( ) ;
381374 assert_eq ! (
382375 data_len % 4 ,
383376 0 ,
384377 "Buffer length is not a multiple of 4 bytes!"
385378 ) ;
386379 assert_eq ! (
387- self . as_device_ptr ( ) . as_raw ( ) % 4 ,
380+ self . as_raw_ptr ( ) % 4 ,
388381 0 ,
389382 "Buffer pointer is not aligned to at least 4 bytes!"
390383 ) ;
391- cuda:: cuMemsetD32Async (
392- self . as_device_ptr ( ) . as_raw ( ) ,
393- value,
394- data_len / 4 ,
395- stream. as_inner ( ) ,
396- )
397- . to_result ( )
384+ cuda:: cuMemsetD32Async ( self . as_raw_ptr ( ) , value, data_len / 4 , stream. as_inner ( ) )
385+ . to_result ( )
398386 }
399387}
400388
@@ -405,10 +393,7 @@ impl<T: DeviceCopy + Zeroable> DeviceSlice<T> {
405393 // SAFETY: this is fine because Zeroable guarantees a zero byte-pattern is safe
406394 // for this type. And a slice of bytes can represent any type.
407395 let erased = unsafe {
408- DeviceSlice :: from_raw_parts_mut (
409- self . as_device_ptr ( ) . cast :: < u8 > ( ) ,
410- size_of :: < T > ( ) * self . len ( ) ,
411- )
396+ DeviceSlice :: from_raw_parts_mut ( self . as_device_ptr ( ) . cast :: < u8 > ( ) , self . size_in_bytes ( ) )
412397 } ;
413398 erased. set_8 ( 0 )
414399 }
@@ -420,14 +405,11 @@ impl<T: DeviceCopy + Zeroable> DeviceSlice<T> {
420405 /// This operation is async so it does not complete immediately, it uses stream-ordering semantics.
421406 /// Therefore you should not read/write from/to the memory range until the operation is complete.
422407 pub unsafe fn set_zero_async ( & mut self , stream : & Stream ) -> CudaResult < ( ) > {
423- if self . as_device_ptr ( ) . is_null ( ) {
424- return Ok ( ( ) ) ;
425- }
426408 // SAFETY: this is fine because Zeroable guarantees a zero byte-pattern is safe
427409 // for this type. And a slice of bytes can represent any type.
428410 let erased = DeviceSlice :: from_raw_parts_mut (
429411 self . as_device_ptr ( ) . cast :: < u8 > ( ) ,
430- size_of :: < T > ( ) * self . len ( ) ,
412+ self . size_in_bytes ( ) ,
431413 ) ;
432414 erased. set_8_async ( 0 , stream)
433415 }
@@ -636,13 +618,17 @@ impl<T: DeviceCopy> DeviceSliceIndex<T> for RangeToInclusive<usize> {
636618 }
637619}
638620
639- impl < T : DeviceCopy > DeviceSlice < T > {
640- pub fn index < Idx : DeviceSliceIndex < T > > ( & self , idx : Idx ) -> & DeviceSlice < T > {
641- idx. index ( self )
621+ impl < T : DeviceCopy , Idx : DeviceSliceIndex < T > > Index < Idx > for DeviceSlice < T > {
622+ type Output = DeviceSlice < T > ;
623+
624+ fn index ( & self , index : Idx ) -> & DeviceSlice < T > {
625+ index. index ( self )
642626 }
627+ }
643628
644- pub fn index_mut < Idx : DeviceSliceIndex < T > > ( & mut self , idx : Idx ) -> & mut DeviceSlice < T > {
645- idx. index_mut ( self )
629+ impl < T : DeviceCopy , Idx : DeviceSliceIndex < T > > IndexMut < Idx > for DeviceSlice < T > {
630+ fn index_mut ( & mut self , index : Idx ) -> & mut DeviceSlice < T > {
631+ index. index_mut ( self )
646632 }
647633}
648634
@@ -654,15 +640,11 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> CopyDestination<I> for
654640 self . len( ) == val. len( ) ,
655641 "destination and source slices have different lengths"
656642 ) ;
657- let size = mem :: size_of :: < T > ( ) * self . len ( ) ;
643+ let size = self . size_in_bytes ( ) ;
658644 if size != 0 {
659645 unsafe {
660- cuda:: cuMemcpyHtoD_v2 (
661- self . as_device_ptr ( ) . as_raw ( ) ,
662- val. as_ptr ( ) as * const c_void ,
663- size,
664- )
665- . to_result ( ) ?
646+ cuda:: cuMemcpyHtoD_v2 ( self . as_raw_ptr ( ) , val. as_ptr ( ) as * const c_void , size)
647+ . to_result ( ) ?
666648 }
667649 }
668650 Ok ( ( ) )
@@ -674,15 +656,11 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> CopyDestination<I> for
674656 self . len( ) == val. len( ) ,
675657 "destination and source slices have different lengths"
676658 ) ;
677- let size = mem :: size_of :: < T > ( ) * self . len ( ) ;
659+ let size = self . size_in_bytes ( ) ;
678660 if size != 0 {
679661 unsafe {
680- cuda:: cuMemcpyDtoH_v2 (
681- val. as_mut_ptr ( ) as * mut c_void ,
682- self . as_device_ptr ( ) . as_raw ( ) ,
683- size,
684- )
685- . to_result ( ) ?
662+ cuda:: cuMemcpyDtoH_v2 ( val. as_mut_ptr ( ) as * mut c_void , self . as_raw_ptr ( ) , size)
663+ . to_result ( ) ?
686664 }
687665 }
688666 Ok ( ( ) )
@@ -694,16 +672,9 @@ impl<T: DeviceCopy> CopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
694672 self . len( ) == val. len( ) ,
695673 "destination and source slices have different lengths"
696674 ) ;
697- let size = mem :: size_of :: < T > ( ) * self . len ( ) ;
675+ let size = self . size_in_bytes ( ) ;
698676 if size != 0 {
699- unsafe {
700- cuda:: cuMemcpyDtoD_v2 (
701- self . as_device_ptr ( ) . as_raw ( ) ,
702- val. as_device_ptr ( ) . as_raw ( ) ,
703- size,
704- )
705- . to_result ( ) ?
706- }
677+ unsafe { cuda:: cuMemcpyDtoD_v2 ( self . as_raw_ptr ( ) , val. as_raw_ptr ( ) , size) . to_result ( ) ? }
707678 }
708679 Ok ( ( ) )
709680 }
@@ -713,16 +684,9 @@ impl<T: DeviceCopy> CopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
713684 self . len( ) == val. len( ) ,
714685 "destination and source slices have different lengths"
715686 ) ;
716- let size = mem :: size_of :: < T > ( ) * self . len ( ) ;
687+ let size = self . size_in_bytes ( ) ;
717688 if size != 0 {
718- unsafe {
719- cuda:: cuMemcpyDtoD_v2 (
720- val. as_device_ptr ( ) . as_raw ( ) ,
721- self . as_device_ptr ( ) . as_raw ( ) ,
722- size,
723- )
724- . to_result ( ) ?
725- }
689+ unsafe { cuda:: cuMemcpyDtoD_v2 ( val. as_raw_ptr ( ) , self . as_raw_ptr ( ) , size) . to_result ( ) ? }
726690 }
727691 Ok ( ( ) )
728692 }
@@ -745,10 +709,10 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> AsyncCopyDestination<I>
745709 self . len( ) == val. len( ) ,
746710 "destination and source slices have different lengths"
747711 ) ;
748- let size = mem :: size_of :: < T > ( ) * self . len ( ) ;
712+ let size = self . size_in_bytes ( ) ;
749713 if size != 0 {
750714 cuda:: cuMemcpyHtoDAsync_v2 (
751- self . as_device_ptr ( ) . as_raw ( ) ,
715+ self . as_raw_ptr ( ) ,
752716 val. as_ptr ( ) as * const c_void ,
753717 size,
754718 stream. as_inner ( ) ,
@@ -764,11 +728,11 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> AsyncCopyDestination<I>
764728 self . len( ) == val. len( ) ,
765729 "destination and source slices have different lengths"
766730 ) ;
767- let size = mem :: size_of :: < T > ( ) * self . len ( ) ;
731+ let size = self . size_in_bytes ( ) ;
768732 if size != 0 {
769733 cuda:: cuMemcpyDtoHAsync_v2 (
770734 val. as_mut_ptr ( ) as * mut c_void ,
771- self . as_device_ptr ( ) . as_raw ( ) ,
735+ self . as_raw_ptr ( ) ,
772736 size,
773737 stream. as_inner ( ) ,
774738 )
@@ -783,15 +747,10 @@ impl<T: DeviceCopy> AsyncCopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
783747 self . len( ) == val. len( ) ,
784748 "destination and source slices have different lengths"
785749 ) ;
786- let size = mem :: size_of :: < T > ( ) * self . len ( ) ;
750+ let size = self . size_in_bytes ( ) ;
787751 if size != 0 {
788- cuda:: cuMemcpyDtoDAsync_v2 (
789- self . as_device_ptr ( ) . as_raw ( ) ,
790- val. as_device_ptr ( ) . as_raw ( ) ,
791- size,
792- stream. as_inner ( ) ,
793- )
794- . to_result ( ) ?
752+ cuda:: cuMemcpyDtoDAsync_v2 ( self . as_raw_ptr ( ) , val. as_raw_ptr ( ) , size, stream. as_inner ( ) )
753+ . to_result ( ) ?
795754 }
796755 Ok ( ( ) )
797756 }
@@ -801,15 +760,10 @@ impl<T: DeviceCopy> AsyncCopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
801760 self . len( ) == val. len( ) ,
802761 "destination and source slices have different lengths"
803762 ) ;
804- let size = mem :: size_of :: < T > ( ) * self . len ( ) ;
763+ let size = self . size_in_bytes ( ) ;
805764 if size != 0 {
806- cuda:: cuMemcpyDtoDAsync_v2 (
807- val. as_device_ptr ( ) . as_raw ( ) ,
808- self . as_device_ptr ( ) . as_raw ( ) ,
809- size,
810- stream. as_inner ( ) ,
811- )
812- . to_result ( ) ?
765+ cuda:: cuMemcpyDtoDAsync_v2 ( val. as_raw_ptr ( ) , self . as_raw_ptr ( ) , size, stream. as_inner ( ) )
766+ . to_result ( ) ?
813767 }
814768 Ok ( ( ) )
815769 }
0 commit comments