Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/CodeGen_D3D12Compute_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,12 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) {
} else if (op->name == "pow_f32" && can_prove(op->args[0] > 0)) {
// If we know pow(x, y) is called with x > 0, we can use HLSL's pow
// directly.
stream << "pow(" << print_expr(op->args[0]) << ", " << print_expr(op->args[1]) << ")";
Expr equiv = Call::make(op->type, "pow", op->args, Call::PureExtern);
equiv.accept(this);
} else if (op->is_intrinsic(Call::round)) {
// HLSL's round intrinsic has the correct semantics for our rounding.
print_assignment(op->type, "round(" + print_expr(op->args[0]) + ")");
Expr equiv = Call::make(op->type, "round", op->args, Call::PureExtern);
equiv.accept(this);
} else {
CodeGen_GPU_C::visit(op);
}
Expand Down
28 changes: 28 additions & 0 deletions test/correctness/math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,33 @@ fun_2(uint32_t, uint32_t, absd, absd)
call_2(double, name, steps, start1, end1, start2, end2); \
} while (0)

// For D3D12Compute, we lower directly to HLSL's pow function if the base is provably positive.
// This test ensures that lowering is correct.
void test_pow_positive() {
printf("Testing pow(x, y) where x > 0\n");
TestArgs<float> args(256, 0.0f, 10.0f);
Func test_pow_positive("test_pow_positive");
Var x("x"), xi("xi");
test_pow_positive(x) = pow(1.5f, args.data(x));

Target target = get_jit_target_from_environment();
if (target.has_gpu_feature()) {
test_pow_positive.gpu_tile(x, xi, 16); //.vectorize(xi, 2);
} else if (target.has_feature(Target::HVX)) {
test_pow_positive.hexagon();
}
Buffer<float> result = test_pow_positive.realize({args.data.extent(0)}, target);
for (int i = 0; i < args.data.extent(0); i++) {
float c_result = pow(1.5f, args.data(i));
if (!relatively_equal(c_result, result(i), target)) {
fprintf(stderr, "For pow(1.5f, %.20f) == %.20f from C and %.20f from %s.\n",
(double)args.data(i), (double)c_result, (double)result(i),
target.to_string().c_str());
num_errors++;
}
}
}

} // namespace

int main(int argc, char **argv) {
Expand Down Expand Up @@ -299,6 +326,7 @@ int main(int argc, char **argv) {
call_1_float_types(trunc, 256, -25, 25);

call_2_float_types(pow, 256, -10.0, 10.0, -4.0f, 4.0f);
test_pow_positive();

const int8_t int8_min = std::numeric_limits<int8_t>::min();
const int16_t int16_min = std::numeric_limits<int16_t>::min();
Expand Down
Loading