Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/CodeGen_D3D12Compute_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,12 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) {
} else if (op->name == "pow_f32" && can_prove(op->args[0] > 0)) {
// If we know pow(x, y) is called with x > 0, we can use HLSL's pow
// directly.
stream << "pow(" << print_expr(op->args[0]) << ", " << print_expr(op->args[1]) << ")";
Expr equiv = Call::make(op->type, "pow", op->args, Call::PureExtern);
equiv.accept(this);
} else if (op->is_intrinsic(Call::round)) {
// HLSL's round intrinsic has the correct semantics for our rounding.
print_assignment(op->type, "round(" + print_expr(op->args[0]) + ")");
Expr equiv = Call::make(op->type, "round", op->args, Call::PureExtern);
equiv.accept(this);
} else {
CodeGen_GPU_C::visit(op);
}
Expand Down
28 changes: 28 additions & 0 deletions test/correctness/math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,33 @@ struct TestArgs {
}
};

// For D3D12Compute, we lower directly to HLSL's pow function if the base is provably positive.
// This test ensures that lowering is correct.
void test_pow_positive() {
printf("Testing pow(x, y) where x > 0\n");
TestArgs<float> args(256, 0.0f, 10.0f);
Func test_pow_positive("test_pow_positive");
Var x("x"), xi("xi");
test_pow_positive(x) = pow(1.5f, args.data(x));

Target target = get_jit_target_from_environment();
if (target.has_gpu_feature()) {
test_pow_positive.gpu_tile(x, xi, 16);
} else if (target.has_feature(Target::HVX)) {
test_pow_positive.hexagon();
}
Buffer<float> result = test_pow_positive.realize({args.data.extent(0)}, target);
for (int i = 0; i < args.data.extent(0); i++) {
float c_result = pow(1.5f, args.data(i));
if (!relatively_equal(c_result, result(i), target)) {
fprintf(stderr, "For pow(1.5f, %.20f) == %.20f from C and %.20f from %s.\n",
(double)args.data(i), (double)c_result, (double)result(i),
target.to_string().c_str());
num_errors++;
}
}
}

// Using macros to expand name as both a C function and an Expr fragment.
// It may well be possible to do this without macros, but that is left
// for another day.
Expand Down Expand Up @@ -299,6 +326,7 @@ int main(int argc, char **argv) {
call_1_float_types(trunc, 256, -25, 25);

call_2_float_types(pow, 256, -10.0, 10.0, -4.0f, 4.0f);
test_pow_positive();

const int8_t int8_min = std::numeric_limits<int8_t>::min();
const int16_t int16_min = std::numeric_limits<int16_t>::min();
Expand Down
Loading