25
25
#include " Shared/Utils.h"
26
26
#include " omptarget.h"
27
27
28
+ #include " llvm/Support/Error.h"
29
+
30
+ namespace llvm {
31
+
28
32
// / Base class of per-device allocator.
29
33
class DeviceAllocatorTy {
30
34
public:
31
35
virtual ~DeviceAllocatorTy () = default ;
32
36
33
37
// / Allocate a memory of size \p Size . \p HstPtr is used to assist the
34
38
// / allocation.
35
- virtual void *allocate (size_t Size, void *HstPtr,
36
- TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0;
39
+ virtual Expected<void *>
40
+ allocate (size_t Size, void *HstPtr,
41
+ TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0 ;
37
42
38
- virtual int free (void *TgtPtr, TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0;
43
+ // / Delete the pointer \p TgtPtr on the device
44
+ virtual Error free (void *TgtPtr,
45
+ TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0;
39
46
};
40
47
41
48
// / Class of memory manager. The memory manager is per-device by using
@@ -133,17 +140,17 @@ class MemoryManagerTy {
133
140
size_t SizeThreshold = 1U << 13 ;
134
141
135
142
// / Request memory from target device
136
- void *allocateOnDevice (size_t Size, void *HstPtr) const {
143
+ Expected< void *> allocateOnDevice (size_t Size, void *HstPtr) const {
137
144
return DeviceAllocator.allocate (Size, HstPtr, TARGET_ALLOC_DEVICE);
138
145
}
139
146
140
147
// / Deallocate data on device
141
- int deleteOnDevice (void *Ptr) const { return DeviceAllocator.free (Ptr); }
148
+ Error deleteOnDevice (void *Ptr) const { return DeviceAllocator.free (Ptr); }
142
149
143
150
// / This function is called when it tries to allocate memory on device but the
144
151
// / device returns out of memory. It will first free all memory in the
145
152
// / FreeList and try to allocate again.
146
- void *freeAndAllocate (size_t Size, void *HstPtr) {
153
+ Expected< void *> freeAndAllocate (size_t Size, void *HstPtr) {
147
154
std::vector<void *> RemoveList;
148
155
149
156
// Deallocate all memory in FreeList
@@ -153,7 +160,8 @@ class MemoryManagerTy {
153
160
if (List.empty ())
154
161
continue ;
155
162
for (const NodeTy &N : List) {
156
- deleteOnDevice (N.Ptr );
163
+ if (auto Err = deleteOnDevice (N.Ptr ))
164
+ return Err;
157
165
RemoveList.push_back (N.Ptr );
158
166
}
159
167
FreeLists[I].clear ();
@@ -174,14 +182,22 @@ class MemoryManagerTy {
174
182
// / allocate directly on the device. If a \p nullptr is returned, it might
175
183
// / be because the device is OOM. In that case, it will free all unused
176
184
// / memory and then try again.
177
- void *allocateOrFreeAndAllocateOnDevice (size_t Size, void *HstPtr) {
178
- void *TgtPtr = allocateOnDevice (Size, HstPtr);
185
+ Expected<void *> allocateOrFreeAndAllocateOnDevice (size_t Size,
186
+ void *HstPtr) {
187
+ auto TgtPtrOrErr = allocateOnDevice (Size, HstPtr);
188
+ if (!TgtPtrOrErr)
189
+ return TgtPtrOrErr.takeError ();
190
+
191
+ void *TgtPtr = *TgtPtrOrErr;
179
192
// We cannot get memory from the device. It might be due to OOM. Let's
180
193
// free all memory in FreeLists and try again.
181
194
if (TgtPtr == nullptr ) {
182
195
DP (" Failed to get memory on device. Free all memory in FreeLists and "
183
196
" try again.\n " );
184
- TgtPtr = freeAndAllocate (Size, HstPtr);
197
+ TgtPtrOrErr = freeAndAllocate (Size, HstPtr);
198
+ if (!TgtPtrOrErr)
199
+ return TgtPtrOrErr.takeError ();
200
+ TgtPtr = *TgtPtrOrErr;
185
201
}
186
202
187
203
if (TgtPtr == nullptr )
@@ -203,16 +219,17 @@ class MemoryManagerTy {
203
219
204
220
// / Destructor
205
221
~MemoryManagerTy () {
206
- for (auto Itr = PtrToNodeTable.begin (); Itr != PtrToNodeTable.end ();
207
- ++Itr) {
208
- assert (Itr->second .Ptr && " nullptr in map table" );
209
- deleteOnDevice (Itr->second .Ptr );
222
+ for (auto &PtrToNode : PtrToNodeTable) {
223
+ assert (PtrToNode.second .Ptr && " nullptr in map table" );
224
+ if (auto Err = deleteOnDevice (PtrToNode.second .Ptr ))
225
+ REPORT (" Failure to delete memory: %s\n " ,
226
+ toString (std::move (Err)).data ());
210
227
}
211
228
}
212
229
213
230
// / Allocate memory of size \p Size from target device. \p HstPtr is used to
214
231
// / assist the allocation.
215
- void *allocate (size_t Size, void *HstPtr) {
232
+ Expected< void *> allocate (size_t Size, void *HstPtr) {
216
233
// If the size is zero, we will not bother the target device. Just return
217
234
// nullptr directly.
218
235
if (Size == 0 )
@@ -227,11 +244,14 @@ class MemoryManagerTy {
227
244
DP (" %zu is greater than the threshold %zu. Allocate it directly from "
228
245
" device\n " ,
229
246
Size, SizeThreshold);
230
- void *TgtPtr = allocateOrFreeAndAllocateOnDevice (Size, HstPtr);
247
+ auto TgtPtrOrErr = allocateOrFreeAndAllocateOnDevice (Size, HstPtr);
248
+ if (!TgtPtrOrErr)
249
+ return TgtPtrOrErr.takeError ();
231
250
232
- DP (" Got target pointer " DPxMOD " . Return directly.\n " , DPxPTR (TgtPtr));
251
+ DP (" Got target pointer " DPxMOD " . Return directly.\n " ,
252
+ DPxPTR (*TgtPtrOrErr));
233
253
234
- return TgtPtr ;
254
+ return *TgtPtrOrErr ;
235
255
}
236
256
237
257
NodeTy *NodePtr = nullptr ;
@@ -259,8 +279,11 @@ class MemoryManagerTy {
259
279
if (NodePtr == nullptr ) {
260
280
DP (" Cannot find a node in the FreeLists. Allocate on device.\n " );
261
281
// Allocate one on device
262
- void *TgtPtr = allocateOrFreeAndAllocateOnDevice (Size, HstPtr);
282
+ auto TgtPtrOrErr = allocateOrFreeAndAllocateOnDevice (Size, HstPtr);
283
+ if (!TgtPtrOrErr)
284
+ return TgtPtrOrErr.takeError ();
263
285
286
+ void *TgtPtr = *TgtPtrOrErr;
264
287
if (TgtPtr == nullptr )
265
288
return nullptr ;
266
289
@@ -281,7 +304,7 @@ class MemoryManagerTy {
281
304
}
282
305
283
306
// / Deallocate memory pointed by \p TgtPtr
284
- int free (void *TgtPtr) {
307
+ Error free (void *TgtPtr) {
285
308
DP (" MemoryManagerTy::free: target memory " DPxMOD " .\n " , DPxPTR (TgtPtr));
286
309
287
310
NodeTy *P = nullptr ;
@@ -313,7 +336,7 @@ class MemoryManagerTy {
313
336
FreeLists[B].insert (*P);
314
337
}
315
338
316
- return OFFLOAD_SUCCESS ;
339
+ return Error::success () ;
317
340
}
318
341
319
342
// / Get the size threshold from the environment variable
@@ -343,4 +366,6 @@ class MemoryManagerTy {
343
366
constexpr const size_t MemoryManagerTy::BucketSize[];
344
367
constexpr const int MemoryManagerTy::NumBuckets;
345
368
369
+ } // namespace llvm
370
+
346
371
#endif // LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_H
0 commit comments