diff --git a/src/CodeGen_D3D12Compute_Dev.cpp b/src/CodeGen_D3D12Compute_Dev.cpp index ad4f6451f918..e40dd1b316ef 100644 --- a/src/CodeGen_D3D12Compute_Dev.cpp +++ b/src/CodeGen_D3D12Compute_Dev.cpp @@ -388,10 +388,12 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) { } else if (op->name == "pow_f32" && can_prove(op->args[0] > 0)) { // If we know pow(x, y) is called with x > 0, we can use HLSL's pow // directly. - stream << "pow(" << print_expr(op->args[0]) << ", " << print_expr(op->args[1]) << ")"; + Expr equiv = Call::make(op->type, "pow", op->args, Call::PureExtern); + equiv.accept(this); } else if (op->is_intrinsic(Call::round)) { // HLSL's round intrinsic has the correct semantics for our rounding. - print_assignment(op->type, "round(" + print_expr(op->args[0]) + ")"); + Expr equiv = Call::make(op->type, "round", op->args, Call::PureExtern); + equiv.accept(this); } else { CodeGen_GPU_C::visit(op); } diff --git a/test/correctness/math.cpp b/test/correctness/math.cpp index 68ff3c0e56e8..9d09f19c9193 100644 --- a/test/correctness/math.cpp +++ b/test/correctness/math.cpp @@ -118,6 +118,33 @@ struct TestArgs { } }; +// For D3D12Compute, we lower directly to HLSL's pow function if the base is provably positive. +// This test ensures that lowering is correct. +void test_pow_positive() { + printf("Testing pow(x, y) where x > 0\n"); + TestArgs args(256, 0.0f, 10.0f); + Func test_pow_positive("test_pow_positive"); + Var x("x"), xi("xi"); + test_pow_positive(x) = pow(1.5f, args.data(x)); + + Target target = get_jit_target_from_environment(); + if (target.has_gpu_feature()) { + test_pow_positive.gpu_tile(x, xi, 16); + } else if (target.has_feature(Target::HVX)) { + test_pow_positive.hexagon(); + } + Buffer result = test_pow_positive.realize({args.data.extent(0)}, target); + for (int i = 0; i < args.data.extent(0); i++) { + float c_result = pow(1.5f, args.data(i)); + if (!relatively_equal(c_result, result(i), target)) { + fprintf(stderr, "For pow(1.5f, %.20f) == %.20f from C and %.20f from %s.\n", + (double)args.data(i), (double)c_result, (double)result(i), + target.to_string().c_str()); + num_errors++; + } + } +} + // Using macros to expand name as both a C function and an Expr fragment. // It may well be possible to do this without macros, but that is left // for another day. @@ -299,6 +326,7 @@ int main(int argc, char **argv) { call_1_float_types(trunc, 256, -25, 25); call_2_float_types(pow, 256, -10.0, 10.0, -4.0f, 4.0f); + test_pow_positive(); const int8_t int8_min = std::numeric_limits::min(); const int16_t int16_min = std::numeric_limits::min();