diff --git a/README.md b/README.md index 9415a56..17d54fb 100644 --- a/README.md +++ b/README.md @@ -11,9 +11,10 @@ try to exploit benchmarking flaws to receive higher scores. To benchmark a kernel, two ingredients are needed: 1. The qualified name of the kernel function. It is important that the testing script itself does not import the kernel function, as this implies executing untrusted code. 2. A function that generates test/benchmark inputs. This function takes keyword arguments of configuration parameters, - as well as the reserved argument `seed` to randomize the problem. It returns two tuples: - The first contains the inputs for the kernel and will - be used to call the kernel function, and the second contains the expected output and the required absolute and relative tolerance. + as well as the reserved argument `seed` to randomize the problem. It returns the kernel arguments. + Any writable / checked output must be wrapped in `pygpubench.out(...)` together + with its expected result and optional tolerances. For in/out args whose initial contents matter, + pass `uses_current_value=True`. 
```python import torch @@ -29,7 +30,10 @@ def generate_test_case(*, seed, **kwargs): x, y = generate_input(**kwargs, seed=seed) expected = torch.empty_like(y) reference_kernel((expected, x)) - return (y, x), (expected, 1e-6, 1e-6) + return ( + pygpubench.out(y, expected=(expected, 1e-6, 1e-6)), + x, + ) res = pygpubench.do_bench_isolated("submission.kernel", generate_test_case, {"size": 1024}, 100, 5, discard=True) diff --git a/csrc/binding.cpp b/csrc/binding.cpp index 8084456..13788b9 100644 --- a/csrc/binding.cpp +++ b/csrc/binding.cpp @@ -18,8 +18,8 @@ void do_bench(int result_fd, int input_fd, const std::string& kernel_qualname, c signature.allocate(32, rng); auto config = read_benchmark_parameters(input_fd, signature.data()); BenchmarkManager mgr(result_fd, std::move(signature), config.Seed, discard, nvtx, landlock, mseal); - auto [args, expected] = mgr.setup_benchmark(nb::cast(test_generator), test_kwargs, config.Repeats); - mgr.do_bench_py(kernel_qualname, args, expected, reinterpret_cast(stream)); + auto [args, output_positions, input_output_positions, expected] = mgr.setup_benchmark(nb::cast(test_generator), test_kwargs, config.Repeats); + mgr.do_bench_py(kernel_qualname, args, output_positions, input_output_positions, expected, reinterpret_cast(stream)); } diff --git a/csrc/manager.cpp b/csrc/manager.cpp index 13fa001..1404dcf 100644 --- a/csrc/manager.cpp +++ b/csrc/manager.cpp @@ -146,14 +146,18 @@ BenchmarkManager::~BenchmarkManager() { cudaFree(mDeviceErrorBase); for (auto& event : mStartEvents) cudaEventDestroy(event); for (auto& event : mEndEvents) cudaEventDestroy(event); - for (auto& exp: mExpectedOutputs) cudaFree(exp.Value); + for (auto& expected_per_test : mExpectedOutputs) { + for (auto& exp : expected_per_test) cudaFree(exp.Value); + } } -std::pair, std::vector> BenchmarkManager::setup_benchmark(const nb::callable& generate_test_case, const nb::dict& kwargs, int repeats) { +std::tuple, std::vector>, std::vector>, std::vector> 
BenchmarkManager::setup_benchmark(const nb::callable& generate_test_case, const nb::dict& kwargs, int repeats) { std::mt19937_64 rng(mSeed); std::uniform_int_distribution dist(0, std::numeric_limits::max()); // generate one more input to handle warmup - std::vector kernel_args(repeats + 1); + std::vector call_args(repeats + 1); + std::vector> output_positions(repeats + 1); + std::vector> input_output_positions(repeats + 1); std::vector expected(repeats + 1); for (int i = 0; i < repeats + 1; i++) { // create new copy of the kwargs dict @@ -168,23 +172,74 @@ std::pair, std::vector> BenchmarkManager::setu call_kwargs["seed"] = dist(rng); auto gen = nb::cast(generate_test_case(**call_kwargs)); - kernel_args[i] = nb::cast(gen[0]); - expected[i] = nb::cast(gen[1]); + if (gen.size() != 4) { + throw std::runtime_error("generate_test_case must return a 4-tuple: (args, output_positions, input_output_positions, expected)"); + } + + call_args[i] = nb::cast(gen[0]); + nb::tuple output_positions_tuple = nb::cast(gen[1]); + nb::tuple input_output_positions_tuple = nb::cast(gen[2]); + expected[i] = nb::cast(gen[3]); + + if (output_positions_tuple.size() == 0) { + throw std::runtime_error("output_positions tuple must not be empty"); + } + if (expected[i].size() != output_positions_tuple.size()) { + throw std::runtime_error("expected tuple size must match output_positions tuple size"); + } + std::vector seen_output(call_args[i].size(), false); + output_positions[i].reserve(output_positions_tuple.size()); + for (int j = 0; j < output_positions_tuple.size(); j++) { + std::size_t pos = nb::cast(output_positions_tuple[j]); + if (pos >= static_cast(call_args[i].size())) { + throw std::runtime_error("output_positions contains an index outside the args tuple"); + } + if (seen_output[pos]) { + throw std::runtime_error("output_positions contains duplicate indices"); + } + seen_output[pos] = true; + output_positions[i].push_back(pos); + } + std::vector seen_input_output(call_args[i].size(), 
false); + input_output_positions[i].reserve(input_output_positions_tuple.size()); + for (int j = 0; j < input_output_positions_tuple.size(); j++) { + std::size_t pos = nb::cast(input_output_positions_tuple[j]); + if (pos >= static_cast(call_args[i].size())) { + throw std::runtime_error("input_output_positions contains an index outside the args tuple"); + } + if (!seen_output[pos]) { + throw std::runtime_error("input_output_positions must be a subset of output_positions"); + } + if (seen_input_output[pos]) { + throw std::runtime_error("input_output_positions contains duplicate indices"); + } + seen_input_output[pos] = true; + input_output_positions[i].push_back(pos); + } } - return std::make_pair(std::move(kernel_args), std::move(expected)); + return std::make_tuple(std::move(call_args), std::move(output_positions), std::move(input_output_positions), std::move(expected)); } bool can_convert_to_tensor(nb::handle obj) { return nb::isinstance(obj); } -auto BenchmarkManager::make_shadow_args(const nb::tuple& args, cudaStream_t stream) -> std::vector> { +auto BenchmarkManager::make_shadow_args(const nb::tuple& args, const std::vector& output_positions, const std::vector& input_output_positions, cudaStream_t stream) -> std::vector> { std::vector> shadow_args(args.size()); - int nargs = args.size(); + std::vector is_output(args.size(), false); + for (auto pos : output_positions) { + is_output.at(pos) = true; + } + for (auto pos : input_output_positions) { + is_output.at(pos) = false; + } std::random_device rd; std::mt19937 gen(rd()); std::uniform_int_distribution canary_seed_dist(0, 0xffffffff); - for (int i = 1; i < nargs; i++) { + for (std::size_t i = 0; i < static_cast(args.size()); i++) { + if (is_output[i]) { + continue; + } if (can_convert_to_tensor(args[i])) { nb_cuda_array arr = nb::cast(args[i]); void* shadow; @@ -225,6 +280,39 @@ void BenchmarkManager::validate_result(Expected& expected, const nb_cuda_array& } } +BenchmarkManager::Expected 
BenchmarkManager::parse_expected_spec(const nb::handle& obj) { + nb_cuda_array expected_array; + auto mode = BenchmarkManager::Expected::ExactMatch; + float rtol = 0.f; + float atol = 0.f; + + if (nb::isinstance(obj)) { + expected_array = nb::cast(obj); + } else { + nb::tuple expected_tuple = nb::cast(obj); + if (expected_tuple.size() == 0) { + throw std::runtime_error("Expected spec tuple must not be empty"); + } + if (expected_tuple.size() != 1 && expected_tuple.size() != 3) { + throw std::runtime_error("Expected spec tuple must have size 1 or 3"); + } + expected_array = nb::cast(expected_tuple[0]); + if (expected_tuple.size() == 3) { + rtol = nb::cast(expected_tuple[1]); + atol = nb::cast(expected_tuple[2]); + mode = BenchmarkManager::Expected::ApproxMatch; + } + } + + // copy expected values into memory not owned by torch, then wipe original + void* copy_mem; + CUDA_CHECK(cudaMalloc(&copy_mem, expected_array.nbytes())); + CUDA_CHECK(cudaMemcpy(copy_mem, expected_array.data(), expected_array.nbytes(), cudaMemcpyDeviceToDevice)); + CUDA_CHECK(cudaMemset(expected_array.data(), 0, expected_array.nbytes())); + + return {mode, copy_mem, expected_array.nbytes(), expected_array.dtype(), atol, rtol}; +} + void BenchmarkManager::clear_cache(cudaStream_t stream) { ::clear_cache(mDeviceDummyMemory, 2 * mL2CacheSize, mDiscardCache, stream); } @@ -250,35 +338,37 @@ BenchmarkManager::ShadowArgument& BenchmarkManager::ShadowArgument::operator=(Sh } void BenchmarkManager::do_bench_py( - const std::string& kernel_qualname, - const std::vector& args, - const std::vector& expected, - cudaStream_t stream) -{ + const std::string& kernel_qualname, + const std::vector& args, + const std::vector>& output_positions, + const std::vector>& input_output_positions, + const std::vector& expected, + cudaStream_t stream +) { if (args.size() < 5) { throw std::runtime_error("Not enough test cases to run benchmark"); } - if (expected.size() != args.size()) { - throw std::runtime_error("Expected 
results and test case list do not have the same length"); + if (output_positions.size() != args.size() || input_output_positions.size() != args.size() || expected.size() != args.size()) { + throw std::runtime_error("Expected results, output metadata, and test case lists do not have the same length"); } int calls = args.size() - 1; - // extract relevant infos from args and expected - // by convention, the first arg is the output tensor. - // TODO handle multiple outputs - std::vector outputs(args.size()); + // extract relevant infos from outputs and expected + std::vector> outputs(args.size()); for (int i = 0; i < args.size(); i++) { - outputs.at(i) = nb::cast(args.at(i)[0]); + outputs.at(i).reserve(output_positions.at(i).size()); + for (auto pos : output_positions.at(i)) { + outputs.at(i).push_back(nb::cast(args.at(i)[pos])); + } } // Generate "shadow" copies of input arguments std::vector shadow_arguments; - for (const auto & arg : args) { - shadow_arguments.emplace_back(make_shadow_args(arg, stream)); + for (int i = 0; i < args.size(); i++) { + shadow_arguments.emplace_back(make_shadow_args(args.at(i), output_positions.at(i), input_output_positions.at(i), stream)); } - // prepare expected outputs - setup_expected_outputs(args, expected); + setup_expected_outputs(output_positions, expected); // clean up as much python state as we can trigger_gc(); @@ -300,9 +390,28 @@ void BenchmarkManager::do_bench_py( // after this, we cannot trust python anymore nb::callable kernel = kernel_from_qualname(kernel_qualname); + auto prepare_args = [&](const ShadowArgumentList& shadow_args) { + for (auto& shadow_arg : shadow_args) { + if (shadow_arg) { + CUDA_CHECK(cudaMemcpyAsync(shadow_arg->Original.data(), shadow_arg->Shadow, shadow_arg->Original.nbytes(), cudaMemcpyDeviceToDevice, stream)); + } + } + + clear_cache(stream); + + // ok, now we revert the canaries. 
This _does_ bring in the corresponding cache lines, + // but they are very sparse (1/256), so that seems like an acceptable trade-off + for (auto& shadow_arg : shadow_args) { + if (shadow_arg) { + canaries(shadow_arg->Original.data(), shadow_arg->Original.nbytes(), shadow_arg->Seed, stream); + } + } + }; + // ok, first run for compilations etc nvtx_push("warmup"); CUDA_CHECK(cudaDeviceSynchronize()); + prepare_args(shadow_arguments.at(0)); kernel(*args.at(0)); CUDA_CHECK(cudaDeviceSynchronize()); nvtx_pop(); @@ -316,7 +425,7 @@ void BenchmarkManager::do_bench_py( // note: we are assuming here that calling the kernel multiple times for the same input is a safe operation // this is only potentially problematic for in-place kernels; CUDA_CHECK(cudaDeviceSynchronize()); - clear_cache(stream); + prepare_args(shadow_arguments.at(0)); kernel(*args.at(0)); CUDA_CHECK(cudaDeviceSynchronize()); std::chrono::high_resolution_clock::time_point cpu_end = std::chrono::high_resolution_clock::now(); @@ -379,24 +488,10 @@ void BenchmarkManager::do_bench_py( // unfortunately, we need to do this before clearing the cache, so there is a window of opportunity // *but* we deliberately modify a small subset of the inputs, which only get corrected immediately before // the user code call. - for (auto& shadow_arg : shadow_arguments.at(test_id)) { - if (shadow_arg) { - CUDA_CHECK(cudaMemcpyAsync(shadow_arg->Original.data(), shadow_arg->Shadow, shadow_arg->Original.nbytes(), cudaMemcpyDeviceToDevice, stream)); - } - } - nvtx_push("cc"); - clear_cache(stream); + prepare_args(shadow_arguments.at(test_id)); nvtx_pop(); - // ok, now we revert the canaries. 
This _does_ bring in the corresponding cache lines, - // but they are very sparse (1/256), so that seems like an acceptable trade-off - for (auto& shadow_arg : shadow_arguments.at(test_id)) { - if (shadow_arg) { - canaries(shadow_arg->Original.data(), shadow_arg->Original.nbytes(), shadow_arg->Seed, stream); - } - } - CUDA_CHECK(cudaEventRecord(mStartEvents.at(i), stream)); nvtx_push("kernel"); (void)kernel(*args.at(test_id)); @@ -404,7 +499,9 @@ void BenchmarkManager::do_bench_py( CUDA_CHECK(cudaEventRecord(mEndEvents.at(i), stream)); // immediately after the kernel, launch the checking code; if there is some unsynced work done on another stream, // this increases the chance of detection. - validate_result(mExpectedOutputs.at(test_id), outputs.at(test_id), check_seed_generator(rng), stream); + for (std::size_t j = 0; j < outputs.at(test_id).size(); j++) { + validate_result(mExpectedOutputs.at(test_id).at(j), outputs.at(test_id).at(j), check_seed_generator(rng), stream); + } } nvtx_pop(); @@ -456,25 +553,20 @@ float BenchmarkManager::measure_event_overhead(int repeats, cudaStream_t stream) return median; } -void BenchmarkManager::setup_expected_outputs(const std::vector& args, const std::vector& expected) { - mExpectedOutputs.resize(args.size()); - for (int i = 0; i < args.size(); i++) { +void BenchmarkManager::setup_expected_outputs(const std::vector>& output_positions, const std::vector& expected) { + for (auto& expected_per_test : mExpectedOutputs) { + for (auto& exp : expected_per_test) cudaFree(exp.Value); + } + mExpectedOutputs.clear(); + mExpectedOutputs.resize(output_positions.size()); + for (int i = 0; i < output_positions.size(); i++) { const nb::tuple& expected_tuple = expected.at(i); - nb_cuda_array expected_array = nb::cast(expected_tuple[0]); - - // make a copy of the expected result and put it in memory not owned by torch; overwrite the original - // so it cannot be read by cheating solutions. 
- void* copy_mem; - CUDA_CHECK(cudaMalloc(&copy_mem, expected_array.nbytes())); - CUDA_CHECK(cudaMemcpy(copy_mem, expected_array.data(), expected_array.nbytes(), cudaMemcpyDeviceToDevice)); - CUDA_CHECK(cudaMemset(expected_array.data(), 0, expected_array.nbytes())); - - if (expected.at(i).size() == 1) { - mExpectedOutputs.at(i) = {Expected::ExactMatch, copy_mem, expected_array.nbytes(), expected_array.dtype(), 0.f, 0.f}; - } else { - float rtol = nb::cast(expected_tuple[1]); - float atol = nb::cast(expected_tuple[2]); - mExpectedOutputs.at(i) = {Expected::ApproxMatch, copy_mem, expected_array.nbytes(), expected_array.dtype(), atol, rtol}; + if (expected_tuple.size() != output_positions.at(i).size()) { + throw std::runtime_error("Expected tuple size must match output_positions tuple size"); + } + mExpectedOutputs.at(i).reserve(expected_tuple.size()); + for (int j = 0; j < expected_tuple.size(); j++) { + mExpectedOutputs.at(i).push_back(parse_expected_spec(expected_tuple[j])); } } -} \ No newline at end of file +} diff --git a/csrc/manager.h b/csrc/manager.h index 1943049..6913789 100644 --- a/csrc/manager.h +++ b/csrc/manager.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -31,8 +32,16 @@ class BenchmarkManager { public: BenchmarkManager(int result_fd, ObfuscatedHexDigest signature, std::uint64_t seed, bool discard, bool nvtx, bool landlock, bool mseal); ~BenchmarkManager(); - std::pair, std::vector> setup_benchmark(const nb::callable& generate_test_case, const nb::dict& kwargs, int repeats); - void do_bench_py(const std::string& kernel_qualname, const std::vector& args, const std::vector& expected, cudaStream_t stream); + std::tuple, std::vector>, std::vector>, std::vector> + setup_benchmark(const nb::callable& generate_test_case, const nb::dict& kwargs, int repeats); + void do_bench_py( + const std::string& kernel_qualname, + const std::vector& args, + const std::vector>& output_positions, + const std::vector>& 
input_output_positions, + const std::vector& expected, + cudaStream_t stream + ); private: struct Expected { enum EMode { @@ -76,12 +85,13 @@ class BenchmarkManager { bool mLandlock = true; bool mSeal = true; std::uint64_t mSeed = -1; - std::vector mExpectedOutputs; + std::vector> mExpectedOutputs; FILE* mOutputPipe = nullptr; ObfuscatedHexDigest mSignature; - static ShadowArgumentList make_shadow_args(const nb::tuple& args, cudaStream_t stream); + static ShadowArgumentList make_shadow_args(const nb::tuple& args, const std::vector& output_positions, const std::vector& input_output_positions, cudaStream_t stream); + static Expected parse_expected_spec(const nb::handle& obj); void nvtx_push(const char* name); void nvtx_pop(); @@ -89,7 +99,7 @@ class BenchmarkManager { void validate_result(Expected& expected, const nb_cuda_array& result, unsigned seed, cudaStream_t stream); void clear_cache(cudaStream_t stream); float measure_event_overhead(int repeats, cudaStream_t stream); - void setup_expected_outputs(const std::vector& args, const std::vector& expected); + void setup_expected_outputs(const std::vector>& output_positions, const std::vector& expected); }; #endif //PYGPUBENCH_SRC_MANAGER_H diff --git a/python/pygpubench/__init__.py b/python/pygpubench/__init__.py index 41e6320..53d5c7c 100644 --- a/python/pygpubench/__init__.py +++ b/python/pygpubench/__init__.py @@ -16,17 +16,54 @@ __all__ = [ + "BenchmarkCase", "do_bench_isolated", "basic_stats", "BenchmarkResult", "BenchmarkStats", "DeterministicContext", "KernelFunction", + "OutputArg", "TestGeneratorInterface", + "ExpectedSpec", "ExpectedResult", + "out", ] +def out(value, *, expected, uses_current_value=False): + """Mark a writable / checked kernel argument.""" + return OutputArg( + value=value, + expected=expected, + uses_current_value=uses_current_value, + ) + + +def _normalize_test_case(case: BenchmarkCase) -> tuple[tuple, tuple[int, ...], tuple[int, ...], ExpectedResult]: + if not isinstance(case, tuple): + 
raise RuntimeError("generate_test_case must return a tuple of kernel arguments") + + kernel_args = [] + output_positions = [] + input_output_positions = [] + expected = [] + for idx, arg in enumerate(case): + if isinstance(arg, OutputArg): + kernel_args.append(arg.value) + output_positions.append(idx) + if arg.uses_current_value: + input_output_positions.append(idx) + expected.append(arg.expected) + else: + kernel_args.append(arg) + + if not output_positions: + raise RuntimeError("generate_test_case must include at least one pygpubench.out(...) argument") + + return tuple(kernel_args), tuple(output_positions), tuple(input_output_positions), tuple(expected) + + def _do_bench_impl(out_fd: "multiprocessing.connection.Connection", in_fd: "multiprocessing.connection.Connection", qualname: str, test_generator: TestGeneratorInterface, test_args: dict, stream: int = None, discard: bool = True, nvtx: bool = False, tb_conn: "multiprocessing.connection.Connection" = None, landlock=True, mseal=True): @@ -35,7 +72,9 @@ def _do_bench_impl(out_fd: "multiprocessing.connection.Connection", in_fd: "mult :param out_fd: Writable file descriptor to which benchmark results are written. :param in_fd: Readable file descriptor that communicates benchmark configuration to the runner. :param qualname: Fully qualified name of the kernel object, e.g. ``my_package.my_module.kernel``. - :param test_generator: A function that takes the test arguments (including a seed) and returns a test case; i.e., a tuple of (input, expected) + :param test_generator: A function that takes the test arguments (including a seed) and returns + kernel arguments in call order. Writable / checked args must be wrapped in `pygpubench.out(...)`. + If an output also depends on its initial contents, pass `uses_current_value=True`. :param test_args: keyword arguments to be passed to `test_generator`. Seed will be generated automatically. 
:param discard: If true, then cache lines are discarded as part of cache clearing before each benchmark run. :param nvtx: Whether to enable NVTX markers for the benchmark. Mostly useful for debugging. @@ -47,13 +86,16 @@ def _do_bench_impl(out_fd: "multiprocessing.connection.Connection", in_fd: "mult import torch stream = torch.cuda.current_stream().cuda_stream + def normalized_test_generator(**kwargs): + return _normalize_test_case(test_generator(**kwargs)) + try: with DeterministicContext(): _pygpubench.do_bench( out_fd.fileno(), in_fd.fileno(), qualname, - test_generator, + normalized_test_generator, test_args, stream, discard, @@ -152,7 +194,11 @@ def do_bench_isolated( mseal = True, ) -> BenchmarkResult: """ - Runs kernel benchmark (`do_bench_impl`) in a subprocess for proper isolation. + Runs a kernel benchmark in a subprocess for proper isolation. + + `test_generator(...)` must return kernel arguments in call order. + Writable / checked arguments must be wrapped in `pygpubench.out(...)`. + If an output also depends on its initial contents, pass `uses_current_value=True`. """ assert repeats > 1 diff --git a/python/pygpubench/_types.py b/python/pygpubench/_types.py index e1d0424..6a996eb 100644 --- a/python/pygpubench/_types.py +++ b/python/pygpubench/_types.py @@ -1,9 +1,27 @@ -from typing import Callable, Tuple +import dataclasses + +from typing import Any, Callable, Tuple Tensor = "torch.Tensor" -ExpectedResult = Tuple[Tensor] | Tuple[Tensor, float, float] +ExpectedSpec = Tensor | Tuple[Tensor] | Tuple[Tensor, float, float] +ExpectedResult = Tuple[ExpectedSpec, ...] +BenchmarkCase = Tuple[Any, ...] 
+ + +@dataclasses.dataclass(frozen=True) +class OutputArg: + value: Any + expected: ExpectedSpec + uses_current_value: bool = False KernelFunction = Callable[..., None] -TestGeneratorInterface = Callable[..., Tuple[Tuple, ExpectedResult]] +TestGeneratorInterface = Callable[..., BenchmarkCase] -__all__ = ["KernelFunction", "TestGeneratorInterface", "ExpectedResult"] +__all__ = [ + "BenchmarkCase", + "ExpectedResult", + "ExpectedSpec", + "KernelFunction", + "OutputArg", + "TestGeneratorInterface", +] diff --git a/test/grayscale.py b/test/grayscale.py index 21aa41b..26984aa 100644 --- a/test/grayscale.py +++ b/test/grayscale.py @@ -32,7 +32,10 @@ def generate_test_case(**kwargs): x, y = generate_input(**kwargs) expected = torch.empty_like(y) reference_kernel((expected, x)) - return (y, x), (expected, 1e-6, 1e-6) + return ( + pygpubench.out(y, expected=(expected, 1e-6, 1e-6)), + x, + ) # note: can't enable landlock or mseal when running on modal :( diff --git a/test/grayscale_multi.py b/test/grayscale_multi.py new file mode 100644 index 0000000..2ffde9b --- /dev/null +++ b/test/grayscale_multi.py @@ -0,0 +1,70 @@ +import pygpubench +import torch + + +def reference_kernel(data): + output_gray, output_red, data = data + weights = torch.tensor([0.2989, 0.5870, 0.1140], + device=data.device, + dtype=data.dtype) + output_gray[...] = torch.sum(data * weights, dim=-1) + output_red[...] = data[..., 0] + + +def generate_input(size: int, seed: int): + """ + Generates random RGB image tensor of the specified size. 
+ Returns: + Tensor of shape (size, size, 3) with values in [0, 1] + """ + gen = torch.Generator(device="cuda") + gen.manual_seed(seed) + + x = torch.rand( + size, size, 3, device="cuda", dtype=torch.float32, generator=gen + ).contiguous() + + y_gray = torch.empty(size, size, device="cuda", dtype=torch.float32).contiguous() + y_red = torch.empty(size, size, device="cuda", dtype=torch.float32).contiguous() + + return x, y_gray, y_red + + +def generate_test_case(**kwargs): + x, y_gray, y_red = generate_input(**kwargs) + expected_gray = torch.empty_like(y_gray) + expected_red = torch.empty_like(y_red) + reference_kernel((expected_gray, expected_red, x)) + return ( + pygpubench.out(y_gray, expected=(expected_gray, 1e-6, 1e-6)), + pygpubench.out(y_red, expected=expected_red), + x, + ) +if __name__ == "__main__": + kernels = ["valid_custom_kernel_eager", "valid_custom_kernel_compiled", "valid_custom_kernel_stream"] + for kernel in kernels: + print(kernel) + res = pygpubench.do_bench_isolated( + f"submission_multi.{kernel}", + generate_test_case, + {"size": 1024}, + 100, + 5, + discard=True, + ) + print("❌" if res.errors else "✅", pygpubench.basic_stats(res.time_us)) + + broken = ["wrong_custom_kernel_backward_race", "wrong_custom_kernel_forward_race"] + for kernel in broken: + print(kernel) + res = pygpubench.do_bench_isolated( + f"submission_multi.{kernel}", + generate_test_case, + {"size": 1024}, + 100, + 5, + discard=True, + ) + print("❌" if res.errors else "✅", pygpubench.basic_stats(res.time_us)) + + print("done") diff --git a/test/submission_multi.py b/test/submission_multi.py new file mode 100644 index 0000000..6816eaf --- /dev/null +++ b/test/submission_multi.py @@ -0,0 +1,57 @@ +import torch + + +_weights = torch.tensor([0.2989, 0.5870, 0.1140], + device="cuda:0", + dtype=torch.float32) + + +stream = torch.cuda.Stream(device="cuda:0") +event = torch.cuda.Event(enable_timing=False) + + +def valid_custom_kernel_eager(output_gray, output_red, data): + 
torch.sum(data * _weights, dim=-1, out=output_gray) + output_red.copy_(data[..., 0]) + + +@torch.compile +def valid_custom_kernel_compiled(output_gray, output_red, data): + torch.sum(data * _weights, dim=-1, out=output_gray) + output_red.copy_(data[..., 0]) + + +def wrong_custom_kernel_backward_race(output_gray, output_red, data): + with torch.cuda.stream(stream): + torch.sum(data * _weights, dim=-1, out=output_gray) + output_red.copy_(data[..., 0]) + event.record() + event.synchronize() + + +def wrong_custom_kernel_forward_race(output_gray, output_red, data): + event.record() + with torch.cuda.stream(stream): + event.synchronize() + torch.sum(data * _weights, dim=-1, out=output_gray) + output_red.copy_(data[..., 0]) + + +def valid_custom_kernel_stream(output_gray, output_red, data): + event.record() + with torch.cuda.stream(stream): + event.synchronize() + torch.sum(data * _weights, dim=-1, out=output_gray) + output_red.copy_(data[..., 0]) + event.record() + event.synchronize() + + +def wrong_custom_kernel_sneaky(output_gray, output_red, data): + event.record() + with torch.cuda.stream(stream): + event.synchronize() + torch.sum(data * _weights, dim=-1, out=output_gray) + output_red.copy_(data[..., 0]) + event.record() + event.synchronize() \ No newline at end of file