Skip to content

Commit 51e3c3d

Browse files
authored
[Offload] Implement 'olIsValidBinary' in offload and clean up (llvm#159658)
Summary: This exposes the 'isDeviceCompatible' routine for checking if a binary *can* be loaded. This is useful if people don't want to consume errors everywhere when figuring out which image to put to what device. I don't know if this is a good name, I was thining like `olIsCompatible` or whatever. Let me know what you think. Long term I'd like to be able to do something similar to what OpenMP does where we can conditionally only initialize devices if we need them. That's going to be support needed if we want this to be more generic.
1 parent bf83516 commit 51e3c3d

File tree

7 files changed

+86
-21
lines changed

7 files changed

+86
-21
lines changed

offload/liboffload/API/Program.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@ def olCreateProgram : Function {
2424
let returns = [];
2525
}
2626

27+
def olIsValidBinary : Function {
28+
let desc = "Validate if the binary image pointed to by `ProgData` is compatible with the device.";
29+
let details = ["The provided `ProgData` will not be loaded onto the device"];
30+
let params = [
31+
Param<"ol_device_handle_t", "Device", "handle of the device", PARAM_IN>,
32+
Param<"const void*", "ProgData", "pointer to the program binary data", PARAM_IN>,
33+
Param<"size_t", "ProgDataSize", "size of the program binary in bytes", PARAM_IN>,
34+
Param<"bool*", "Valid", "output is true if the image is compatible", PARAM_OUT>
35+
];
36+
let returns = [];
37+
}
38+
2739
def olDestroyProgram : Function {
2840
let desc = "Destroy the program and free all underlying resources.";
2941
let details = [];

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,6 @@ Error olMemFill_impl(ol_queue_handle_t Queue, void *Ptr, size_t PatternSize,
887887

888888
Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
889889
size_t ProgDataSize, ol_program_handle_t *Program) {
890-
// Make a copy of the program binary in case it is released by the caller.
891890
StringRef Buffer(reinterpret_cast<const char *>(ProgData), ProgDataSize);
892891
Expected<plugin::DeviceImageTy *> Res =
893892
Device->Device->loadBinary(Device->Device->Plugin, Buffer);
@@ -899,6 +898,14 @@ Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
899898
return Error::success();
900899
}
901900

901+
Error olIsValidBinary_impl(ol_device_handle_t Device, const void *ProgData,
902+
size_t ProgDataSize, bool *IsValid) {
903+
StringRef Buffer(reinterpret_cast<const char *>(ProgData), ProgDataSize);
904+
*IsValid = Device->Device->Plugin.isDeviceCompatible(
905+
Device->Device->getDeviceId(), Buffer);
906+
return Error::success();
907+
}
908+
902909
Error olDestroyProgram_impl(ol_program_handle_t Program) {
903910
auto &Device = Program->Image->getDevice();
904911
if (auto Err = Device.unloadBinary(Program->Image))

offload/libomptarget/PluginManager.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,10 @@ void PluginManager::registerLib(__tgt_bin_desc *Desc) {
219219
// Scan the RTLs that have associated images until we find one that supports
220220
// the current image.
221221
for (auto &R : plugins()) {
222-
if (!R.is_plugin_compatible(Img))
222+
StringRef Buffer(reinterpret_cast<const char *>(Img->ImageStart),
223+
utils::getPtrDiff(Img->ImageEnd, Img->ImageStart));
224+
225+
if (!R.isPluginCompatible(Buffer))
223226
continue;
224227

225228
if (!initializePlugin(R))
@@ -242,7 +245,7 @@ void PluginManager::registerLib(__tgt_bin_desc *Desc) {
242245
continue;
243246
}
244247

245-
if (!R.is_device_compatible(DeviceId, Img))
248+
if (!R.isDeviceCompatible(DeviceId, Buffer))
246249
continue;
247250

248251
DP("Image " DPxMOD " is compatible with RTL %s device %d!\n",

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,10 +1378,10 @@ struct GenericPluginTy {
13781378

13791379
/// Returns non-zero if the \p Image is compatible with the plugin. This
13801380
/// function does not require the plugin to be initialized before use.
1381-
int32_t is_plugin_compatible(__tgt_device_image *Image);
1381+
int32_t isPluginCompatible(StringRef Image);
13821382

13831383
/// Returns non-zero if the \p Image is compatible with the device.
1384-
int32_t is_device_compatible(int32_t DeviceId, __tgt_device_image *Image);
1384+
int32_t isDeviceCompatible(int32_t DeviceId, StringRef Image);
13851385

13861386
/// Returns non-zero if the plugin device has been initialized.
13871387
int32_t is_device_initialized(int32_t DeviceId) const;

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1713,28 +1713,25 @@ Expected<bool> GenericPluginTy::checkBitcodeImage(StringRef Image) const {
17131713

17141714
int32_t GenericPluginTy::is_initialized() const { return Initialized; }
17151715

1716-
int32_t GenericPluginTy::is_plugin_compatible(__tgt_device_image *Image) {
1717-
StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
1718-
utils::getPtrDiff(Image->ImageEnd, Image->ImageStart));
1719-
1716+
int32_t GenericPluginTy::isPluginCompatible(StringRef Image) {
17201717
auto HandleError = [&](Error Err) -> bool {
17211718
[[maybe_unused]] std::string ErrStr = toString(std::move(Err));
17221719
DP("Failure to check validity of image %p: %s", Image, ErrStr.c_str());
17231720
return false;
17241721
};
1725-
switch (identify_magic(Buffer)) {
1722+
switch (identify_magic(Image)) {
17261723
case file_magic::elf:
17271724
case file_magic::elf_relocatable:
17281725
case file_magic::elf_executable:
17291726
case file_magic::elf_shared_object:
17301727
case file_magic::elf_core: {
1731-
auto MatchOrErr = checkELFImage(Buffer);
1728+
auto MatchOrErr = checkELFImage(Image);
17321729
if (Error Err = MatchOrErr.takeError())
17331730
return HandleError(std::move(Err));
17341731
return *MatchOrErr;
17351732
}
17361733
case file_magic::bitcode: {
1737-
auto MatchOrErr = checkBitcodeImage(Buffer);
1734+
auto MatchOrErr = checkBitcodeImage(Image);
17381735
if (Error Err = MatchOrErr.takeError())
17391736
return HandleError(std::move(Err));
17401737
return *MatchOrErr;
@@ -1744,36 +1741,32 @@ int32_t GenericPluginTy::is_plugin_compatible(__tgt_device_image *Image) {
17441741
}
17451742
}
17461743

1747-
int32_t GenericPluginTy::is_device_compatible(int32_t DeviceId,
1748-
__tgt_device_image *Image) {
1749-
StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
1750-
utils::getPtrDiff(Image->ImageEnd, Image->ImageStart));
1751-
1744+
int32_t GenericPluginTy::isDeviceCompatible(int32_t DeviceId, StringRef Image) {
17521745
auto HandleError = [&](Error Err) -> bool {
17531746
[[maybe_unused]] std::string ErrStr = toString(std::move(Err));
17541747
DP("Failure to check validity of image %p: %s", Image, ErrStr.c_str());
17551748
return false;
17561749
};
1757-
switch (identify_magic(Buffer)) {
1750+
switch (identify_magic(Image)) {
17581751
case file_magic::elf:
17591752
case file_magic::elf_relocatable:
17601753
case file_magic::elf_executable:
17611754
case file_magic::elf_shared_object:
17621755
case file_magic::elf_core: {
1763-
auto MatchOrErr = checkELFImage(Buffer);
1756+
auto MatchOrErr = checkELFImage(Image);
17641757
if (Error Err = MatchOrErr.takeError())
17651758
return HandleError(std::move(Err));
17661759
if (!*MatchOrErr)
17671760
return false;
17681761

17691762
// Perform plugin-dependent checks for the specific architecture if needed.
1770-
auto CompatibleOrErr = isELFCompatible(DeviceId, Buffer);
1763+
auto CompatibleOrErr = isELFCompatible(DeviceId, Image);
17711764
if (Error Err = CompatibleOrErr.takeError())
17721765
return HandleError(std::move(Err));
17731766
return *CompatibleOrErr;
17741767
}
17751768
case file_magic::bitcode: {
1776-
auto MatchOrErr = checkBitcodeImage(Buffer);
1769+
auto MatchOrErr = checkBitcodeImage(Image);
17771770
if (Error Err = MatchOrErr.takeError())
17781771
return HandleError(std::move(Err));
17791772
return *MatchOrErr;

offload/unittests/OffloadAPI/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_offload_unittest("platform"
3535

3636
add_offload_unittest("program"
3737
program/olCreateProgram.cpp
38+
program/olIsValidBinary.cpp
3839
program/olDestroyProgram.cpp)
3940

4041
add_offload_unittest("queue"
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===------- Offload API tests - olIsValidBinary --------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "../common/Fixtures.hpp"
10+
#include <OffloadAPI.h>
11+
#include <gtest/gtest.h>
12+
13+
using olIsValidBinaryTest = OffloadDeviceTest;
14+
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olIsValidBinaryTest);
15+
16+
TEST_P(olIsValidBinaryTest, Success) {
17+
18+
std::unique_ptr<llvm::MemoryBuffer> DeviceBin;
19+
ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin));
20+
ASSERT_GE(DeviceBin->getBufferSize(), 0lu);
21+
22+
bool IsValid = false;
23+
ASSERT_SUCCESS(olIsValidBinary(Device, DeviceBin->getBufferStart(),
24+
DeviceBin->getBufferSize(), &IsValid));
25+
ASSERT_TRUE(IsValid);
26+
27+
ASSERT_SUCCESS(
28+
olIsValidBinary(Device, DeviceBin->getBufferStart(), 0, &IsValid));
29+
ASSERT_FALSE(IsValid);
30+
}
31+
32+
TEST_P(olIsValidBinaryTest, Invalid) {
33+
34+
std::unique_ptr<llvm::MemoryBuffer> DeviceBin;
35+
ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin));
36+
ASSERT_GE(DeviceBin->getBufferSize(), 0lu);
37+
38+
bool IsValid = false;
39+
ASSERT_SUCCESS(
40+
olIsValidBinary(Device, DeviceBin->getBufferStart(), 0, &IsValid));
41+
ASSERT_FALSE(IsValid);
42+
}
43+
44+
TEST_P(olIsValidBinaryTest, NullPointer) {
45+
bool IsValid = false;
46+
ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
47+
olIsValidBinary(Device, nullptr, 42, &IsValid));
48+
ASSERT_FALSE(IsValid);
49+
}

0 commit comments

Comments
 (0)