Skip to content

Commit d48bda5

Browse files
authored
[flang][cuda] Handle zero sized allocation correctly (llvm#160929)
Like on the host allocate 1 byte when zero size is requested.
1 parent 24bc1a6 commit d48bda5

File tree

2 files changed

+18
-13
lines changed

2 files changed

+18
-13
lines changed

flang-rt/lib/cuda/memory.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,22 @@ extern "C" {
2525
void *RTDEF(CUFMemAlloc)(
2626
std::size_t bytes, unsigned type, const char *sourceFile, int sourceLine) {
2727
void *ptr = nullptr;
28-
if (bytes != 0) {
29-
if (type == kMemTypeDevice) {
30-
if (Fortran::runtime::executionEnvironment.cudaDeviceIsManaged) {
31-
CUDA_REPORT_IF_ERROR(
32-
cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal));
33-
} else {
34-
CUDA_REPORT_IF_ERROR(cudaMalloc((void **)&ptr, bytes));
35-
}
36-
} else if (type == kMemTypeManaged || type == kMemTypeUnified) {
28+
bytes = bytes ? bytes : 1;
29+
if (type == kMemTypeDevice) {
30+
if (Fortran::runtime::executionEnvironment.cudaDeviceIsManaged) {
3731
CUDA_REPORT_IF_ERROR(
3832
cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal));
39-
} else if (type == kMemTypePinned) {
40-
CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&ptr, bytes));
4133
} else {
42-
Terminator terminator{sourceFile, sourceLine};
43-
terminator.Crash("unsupported memory type");
34+
CUDA_REPORT_IF_ERROR(cudaMalloc((void **)&ptr, bytes));
4435
}
36+
} else if (type == kMemTypeManaged || type == kMemTypeUnified) {
37+
CUDA_REPORT_IF_ERROR(
38+
cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal));
39+
} else if (type == kMemTypePinned) {
40+
CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&ptr, bytes));
41+
} else {
42+
Terminator terminator{sourceFile, sourceLine};
43+
terminator.Crash("unsupported memory type");
4544
}
4645
return ptr;
4746
}

flang-rt/unittests/Runtime/CUDA/Memory.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ TEST(MemoryCUFTest, SimpleAllocTramsferFree) {
3535
RTNAME(CUFMemFree)((void *)dev, kMemTypeDevice, __FILE__, __LINE__);
3636
}
3737

38+
TEST(MemoryCUFTest, AllocZero) {
39+
int *dev = (int *)RTNAME(CUFMemAlloc)(0, kMemTypeDevice, __FILE__, __LINE__);
40+
EXPECT_TRUE(dev != 0);
41+
RTNAME(CUFMemFree)((void *)dev, kMemTypeDevice, __FILE__, __LINE__);
42+
}
43+
3844
static OwningPtr<Descriptor> createAllocatable(
3945
Fortran::common::TypeCategory tc, int kind, int rank = 1) {
4046
return Descriptor::Create(TypeCode{tc, kind}, kind, nullptr, rank, nullptr,

0 commit comments

Comments
 (0)