Skip to content

Commit f1242df

Browse files
committed
impl Index traits and add size checks for copying
1 parent f2811f3 commit f1242df

File tree

1 file changed

+59
-105
lines changed

1 file changed

+59
-105
lines changed

crates/cust/src/memory/device/device_slice.rs

Lines changed: 59 additions & 105 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,9 @@ use crate::sys as cuda;
99
use bytemuck::{Pod, Zeroable};
1010
use std::fmt::{self, Debug, Formatter};
1111
use std::marker::PhantomData;
12-
use std::mem::{self, size_of};
13-
use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
12+
use std::ops::{
13+
Index, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
14+
};
1415
use std::os::raw::c_void;
1516
use std::ptr::{slice_from_raw_parts, slice_from_raw_parts_mut};
1617

@@ -236,16 +237,13 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
236237
/// In total it will set `sizeof<T> * len` values of `value` contiguously.
237238
#[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
238239
pub fn set_8(&mut self, value: u8) -> CudaResult<()> {
240+
if self.size_in_bytes() == 0 {
241+
return Ok(());
242+
}
243+
239244
// SAFETY: We know T can hold any value because it is `Pod`, and
240245
// sub-byte alignment isn't a thing so we know the alignment is right.
241-
unsafe {
242-
cuda::cuMemsetD8_v2(
243-
self.as_device_ptr().as_raw(),
244-
value,
245-
size_of::<T>() * self.len(),
246-
)
247-
.to_result()
248-
}
246+
unsafe { cuda::cuMemsetD8_v2(self.as_raw_ptr(), value, self.size_in_bytes()).to_result() }
249247
}
250248

251249
/// Sets the memory range of this buffer to contiguous `8-bit` values of `value` asynchronously.
@@ -258,10 +256,14 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
258256
/// Therefore you should not read/write from/to the memory range until the operation is complete.
259257
#[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
260258
pub unsafe fn set_8_async(&mut self, value: u8, stream: &Stream) -> CudaResult<()> {
259+
if self.size_in_bytes() == 0 {
260+
return Ok(());
261+
}
262+
261263
cuda::cuMemsetD8Async(
262-
self.as_device_ptr().as_raw(),
264+
self.as_raw_ptr(),
263265
value,
264-
size_of::<T>() * self.len(),
266+
self.size_in_bytes(),
265267
stream.as_inner(),
266268
)
267269
.to_result()
@@ -279,20 +281,18 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
279281
#[track_caller]
280282
#[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
281283
pub fn set_16(&mut self, value: u16) -> CudaResult<()> {
282-
let data_len = size_of::<T>() * self.len();
284+
let data_len = self.size_in_bytes();
283285
assert_eq!(
284286
data_len % 2,
285287
0,
286288
"Buffer length is not a multiple of 2 bytes!"
287289
);
288290
assert_eq!(
289-
self.as_device_ptr().as_raw() % 2,
291+
self.as_raw_ptr() % 2,
290292
0,
291293
"Buffer pointer is not aligned to at least 2 bytes!"
292294
);
293-
unsafe {
294-
cuda::cuMemsetD16_v2(self.as_device_ptr().as_raw(), value, data_len / 2).to_result()
295-
}
295+
unsafe { cuda::cuMemsetD16_v2(self.as_raw_ptr(), value, data_len / 2).to_result() }
296296
}
297297

298298
/// Sets the memory range of this buffer to contiguous `16-bit` values of `value` asynchronously.
@@ -312,24 +312,19 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
312312
#[track_caller]
313313
#[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
314314
pub unsafe fn set_16_async(&mut self, value: u16, stream: &Stream) -> CudaResult<()> {
315-
let data_len = size_of::<T>() * self.len();
315+
let data_len = self.size_in_bytes();
316316
assert_eq!(
317317
data_len % 2,
318318
0,
319319
"Buffer length is not a multiple of 2 bytes!"
320320
);
321321
assert_eq!(
322-
self.as_device_ptr().as_raw() % 2,
322+
self.as_raw_ptr() % 2,
323323
0,
324324
"Buffer pointer is not aligned to at least 2 bytes!"
325325
);
326-
cuda::cuMemsetD16Async(
327-
self.as_device_ptr().as_raw(),
328-
value,
329-
data_len / 2,
330-
stream.as_inner(),
331-
)
332-
.to_result()
326+
cuda::cuMemsetD16Async(self.as_raw_ptr(), value, data_len / 2, stream.as_inner())
327+
.to_result()
333328
}
334329

335330
/// Sets the memory range of this buffer to contiguous `32-bit` values of `value`.
@@ -344,20 +339,18 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
344339
#[track_caller]
345340
#[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
346341
pub fn set_32(&mut self, value: u32) -> CudaResult<()> {
347-
let data_len = size_of::<T>() * self.len();
342+
let data_len = self.size_in_bytes();
348343
assert_eq!(
349344
data_len % 4,
350345
0,
351346
"Buffer length is not a multiple of 4 bytes!"
352347
);
353348
assert_eq!(
354-
self.as_device_ptr().as_raw() % 4,
349+
self.as_raw_ptr() % 4,
355350
0,
356351
"Buffer pointer is not aligned to at least 4 bytes!"
357352
);
358-
unsafe {
359-
cuda::cuMemsetD32_v2(self.as_device_ptr().as_raw(), value, data_len / 4).to_result()
360-
}
353+
unsafe { cuda::cuMemsetD32_v2(self.as_raw_ptr(), value, data_len / 4).to_result() }
361354
}
362355

363356
/// Sets the memory range of this buffer to contiguous `32-bit` values of `value` asynchronously.
@@ -377,24 +370,19 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
377370
#[track_caller]
378371
#[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
379372
pub unsafe fn set_32_async(&mut self, value: u32, stream: &Stream) -> CudaResult<()> {
380-
let data_len = size_of::<T>() * self.len();
373+
let data_len = self.size_in_bytes();
381374
assert_eq!(
382375
data_len % 4,
383376
0,
384377
"Buffer length is not a multiple of 4 bytes!"
385378
);
386379
assert_eq!(
387-
self.as_device_ptr().as_raw() % 4,
380+
self.as_raw_ptr() % 4,
388381
0,
389382
"Buffer pointer is not aligned to at least 4 bytes!"
390383
);
391-
cuda::cuMemsetD32Async(
392-
self.as_device_ptr().as_raw(),
393-
value,
394-
data_len / 4,
395-
stream.as_inner(),
396-
)
397-
.to_result()
384+
cuda::cuMemsetD32Async(self.as_raw_ptr(), value, data_len / 4, stream.as_inner())
385+
.to_result()
398386
}
399387
}
400388

@@ -405,10 +393,7 @@ impl<T: DeviceCopy + Zeroable> DeviceSlice<T> {
405393
// SAFETY: this is fine because Zeroable guarantees a zero byte-pattern is safe
406394
// for this type. And a slice of bytes can represent any type.
407395
let erased = unsafe {
408-
DeviceSlice::from_raw_parts_mut(
409-
self.as_device_ptr().cast::<u8>(),
410-
size_of::<T>() * self.len(),
411-
)
396+
DeviceSlice::from_raw_parts_mut(self.as_device_ptr().cast::<u8>(), self.size_in_bytes())
412397
};
413398
erased.set_8(0)
414399
}
@@ -420,14 +405,11 @@ impl<T: DeviceCopy + Zeroable> DeviceSlice<T> {
420405
/// This operation is async so it does not complete immediately, it uses stream-ordering semantics.
421406
/// Therefore you should not read/write from/to the memory range until the operation is complete.
422407
pub unsafe fn set_zero_async(&mut self, stream: &Stream) -> CudaResult<()> {
423-
if self.as_device_ptr().is_null() {
424-
return Ok(());
425-
}
426408
// SAFETY: this is fine because Zeroable guarantees a zero byte-pattern is safe
427409
// for this type. And a slice of bytes can represent any type.
428410
let erased = DeviceSlice::from_raw_parts_mut(
429411
self.as_device_ptr().cast::<u8>(),
430-
size_of::<T>() * self.len(),
412+
self.size_in_bytes(),
431413
);
432414
erased.set_8_async(0, stream)
433415
}
@@ -636,13 +618,17 @@ impl<T: DeviceCopy> DeviceSliceIndex<T> for RangeToInclusive<usize> {
636618
}
637619
}
638620

639-
impl<T: DeviceCopy> DeviceSlice<T> {
640-
pub fn index<Idx: DeviceSliceIndex<T>>(&self, idx: Idx) -> &DeviceSlice<T> {
641-
idx.index(self)
621+
impl<T: DeviceCopy, Idx: DeviceSliceIndex<T>> Index<Idx> for DeviceSlice<T> {
622+
type Output = DeviceSlice<T>;
623+
624+
fn index(&self, index: Idx) -> &DeviceSlice<T> {
625+
index.index(self)
642626
}
627+
}
643628

644-
pub fn index_mut<Idx: DeviceSliceIndex<T>>(&mut self, idx: Idx) -> &mut DeviceSlice<T> {
645-
idx.index_mut(self)
629+
impl<T: DeviceCopy, Idx: DeviceSliceIndex<T>> IndexMut<Idx> for DeviceSlice<T> {
630+
fn index_mut(&mut self, index: Idx) -> &mut DeviceSlice<T> {
631+
index.index_mut(self)
646632
}
647633
}
648634

@@ -654,15 +640,11 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> CopyDestination<I> for
654640
self.len() == val.len(),
655641
"destination and source slices have different lengths"
656642
);
657-
let size = mem::size_of::<T>() * self.len();
643+
let size = self.size_in_bytes();
658644
if size != 0 {
659645
unsafe {
660-
cuda::cuMemcpyHtoD_v2(
661-
self.as_device_ptr().as_raw(),
662-
val.as_ptr() as *const c_void,
663-
size,
664-
)
665-
.to_result()?
646+
cuda::cuMemcpyHtoD_v2(self.as_raw_ptr(), val.as_ptr() as *const c_void, size)
647+
.to_result()?
666648
}
667649
}
668650
Ok(())
@@ -674,15 +656,11 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> CopyDestination<I> for
674656
self.len() == val.len(),
675657
"destination and source slices have different lengths"
676658
);
677-
let size = mem::size_of::<T>() * self.len();
659+
let size = self.size_in_bytes();
678660
if size != 0 {
679661
unsafe {
680-
cuda::cuMemcpyDtoH_v2(
681-
val.as_mut_ptr() as *mut c_void,
682-
self.as_device_ptr().as_raw(),
683-
size,
684-
)
685-
.to_result()?
662+
cuda::cuMemcpyDtoH_v2(val.as_mut_ptr() as *mut c_void, self.as_raw_ptr(), size)
663+
.to_result()?
686664
}
687665
}
688666
Ok(())
@@ -694,16 +672,9 @@ impl<T: DeviceCopy> CopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
694672
self.len() == val.len(),
695673
"destination and source slices have different lengths"
696674
);
697-
let size = mem::size_of::<T>() * self.len();
675+
let size = self.size_in_bytes();
698676
if size != 0 {
699-
unsafe {
700-
cuda::cuMemcpyDtoD_v2(
701-
self.as_device_ptr().as_raw(),
702-
val.as_device_ptr().as_raw(),
703-
size,
704-
)
705-
.to_result()?
706-
}
677+
unsafe { cuda::cuMemcpyDtoD_v2(self.as_raw_ptr(), val.as_raw_ptr(), size).to_result()? }
707678
}
708679
Ok(())
709680
}
@@ -713,16 +684,9 @@ impl<T: DeviceCopy> CopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
713684
self.len() == val.len(),
714685
"destination and source slices have different lengths"
715686
);
716-
let size = mem::size_of::<T>() * self.len();
687+
let size = self.size_in_bytes();
717688
if size != 0 {
718-
unsafe {
719-
cuda::cuMemcpyDtoD_v2(
720-
val.as_device_ptr().as_raw(),
721-
self.as_device_ptr().as_raw(),
722-
size,
723-
)
724-
.to_result()?
725-
}
689+
unsafe { cuda::cuMemcpyDtoD_v2(val.as_raw_ptr(), self.as_raw_ptr(), size).to_result()? }
726690
}
727691
Ok(())
728692
}
@@ -745,10 +709,10 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> AsyncCopyDestination<I>
745709
self.len() == val.len(),
746710
"destination and source slices have different lengths"
747711
);
748-
let size = mem::size_of::<T>() * self.len();
712+
let size = self.size_in_bytes();
749713
if size != 0 {
750714
cuda::cuMemcpyHtoDAsync_v2(
751-
self.as_device_ptr().as_raw(),
715+
self.as_raw_ptr(),
752716
val.as_ptr() as *const c_void,
753717
size,
754718
stream.as_inner(),
@@ -764,11 +728,11 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> AsyncCopyDestination<I>
764728
self.len() == val.len(),
765729
"destination and source slices have different lengths"
766730
);
767-
let size = mem::size_of::<T>() * self.len();
731+
let size = self.size_in_bytes();
768732
if size != 0 {
769733
cuda::cuMemcpyDtoHAsync_v2(
770734
val.as_mut_ptr() as *mut c_void,
771-
self.as_device_ptr().as_raw(),
735+
self.as_raw_ptr(),
772736
size,
773737
stream.as_inner(),
774738
)
@@ -783,15 +747,10 @@ impl<T: DeviceCopy> AsyncCopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
783747
self.len() == val.len(),
784748
"destination and source slices have different lengths"
785749
);
786-
let size = mem::size_of::<T>() * self.len();
750+
let size = self.size_in_bytes();
787751
if size != 0 {
788-
cuda::cuMemcpyDtoDAsync_v2(
789-
self.as_device_ptr().as_raw(),
790-
val.as_device_ptr().as_raw(),
791-
size,
792-
stream.as_inner(),
793-
)
794-
.to_result()?
752+
cuda::cuMemcpyDtoDAsync_v2(self.as_raw_ptr(), val.as_raw_ptr(), size, stream.as_inner())
753+
.to_result()?
795754
}
796755
Ok(())
797756
}
@@ -801,15 +760,10 @@ impl<T: DeviceCopy> AsyncCopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
801760
self.len() == val.len(),
802761
"destination and source slices have different lengths"
803762
);
804-
let size = mem::size_of::<T>() * self.len();
763+
let size = self.size_in_bytes();
805764
if size != 0 {
806-
cuda::cuMemcpyDtoDAsync_v2(
807-
val.as_device_ptr().as_raw(),
808-
self.as_device_ptr().as_raw(),
809-
size,
810-
stream.as_inner(),
811-
)
812-
.to_result()?
765+
cuda::cuMemcpyDtoDAsync_v2(val.as_raw_ptr(), self.as_raw_ptr(), size, stream.as_inner())
766+
.to_result()?
813767
}
814768
Ok(())
815769
}

0 commit comments

Comments
 (0)