3 changes: 2 additions & 1 deletion test/stablehlo/test_stablehlo_custom_call.py
@@ -1,3 +1,4 @@
+import expecttest
 import sys
 import re
 import unittest
@@ -16,7 +17,7 @@
 m = Library("my_custom_library", "DEF")


-class StableHLOCustomCallExportTest(unittest.TestCase):
+class StableHLOCustomCallExportTest(expecttest.TestCase):

   def test_single_output(self):

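Note: switching the base class to expecttest.TestCase gives this suite access to expecttest's inline-expectation assertions (such as assertExpectedRaisesInline, used in the tests below). A minimal, self-contained sketch of the pattern; the test here is hypothetical, not part of this diff:

import expecttest
import unittest


class ExampleTest(expecttest.TestCase):

  def test_message(self):
    # Compares against the inline string; running with EXPECTTEST_ACCEPT=1
    # rewrites the string in-place to match the actual value.
    self.assertExpectedInline(str(2 + 2), """4""")


if __name__ == "__main__":
  unittest.main()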
51 changes: 51 additions & 0 deletions test/test_ops_error_message.py
@@ -1,3 +1,4 @@
+from typing import Callable
 import expecttest
 import os
 import torch
@@ -357,6 +358,56 @@ def gen_test_fn(kernel_size=[2, 2, 2], stride=[], padding=[0]):
expect="""avg_pool3d(): expected argument padding [1, 2] (size: 2) to have size of 3."""
)

def _get_custom_call_properties(self, mode):
match mode:
case "tpu":
return (torch_xla._XLAC._xla_tpu_custom_call, "", [])
case "stablehlo":
return (torch_xla._XLAC._xla_custom_call, "custom_op_target",
[False, "", 0, {}])

self.fail(f"expected `mode` ({mode}) to be either of ['tpu', 'stablehlo'].")

def _gen_custom_call_no_input(self, mode):
lib_custom_call, payload, args = self._get_custom_call_properties(
mode) # type: ignore[attr-defined]
return lambda: lib_custom_call([], payload, [[1]], [torch.int8], *args)

def _gen_custom_call_output_properties_size_mismatch(self, mode):
lib_custom_call, payload, args = self._get_custom_call_properties(
mode) # type: ignore[attr-defined]
input = torch.rand(10, device=torch_xla.device())
return lambda: lib_custom_call(
(input,), payload, [[1], [1]], [torch.int8], *args)

def test_stablehlo_custom_call(self):

self.assertExpectedRaisesInline(
exc_type=RuntimeError,
callable=self._gen_custom_call_no_input("stablehlo"),
expect="""custom_call(custom_op_target): expected at least 1 input tensor."""
)

self.assertExpectedRaisesInline(
exc_type=RuntimeError,
callable=self._gen_custom_call_output_properties_size_mismatch(
"stablehlo"),
expect="""custom_call(custom_op_target): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1)."""
)

def test_tpu_custom_call(self):

self.assertExpectedRaisesInline(
exc_type=RuntimeError,
callable=self._gen_custom_call_no_input("tpu"),
expect="""tpu_custom_call(): expected at least 1 input tensor.""")

self.assertExpectedRaisesInline(
exc_type=RuntimeError,
callable=self._gen_custom_call_output_properties_size_mismatch("tpu"),
expect="""tpu_custom_call(): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1)."""
)


if __name__ == "__main__":
unittest.main()
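To reproduce these error paths outside the test harness, here is a minimal sketch based on the exact calls the tests above make. It assumes an XLA-capable environment (e.g. PJRT_DEVICE=CPU); `custom_op_target` is the same placeholder target the tests use, not a registered kernel:

import torch
import torch_xla

# No inputs: rejected by the new non-empty-inputs check.
try:
  torch_xla._XLAC._xla_custom_call([], "custom_op_target", [[1]],
                                   [torch.int8], False, "", 0, {})
except RuntimeError as e:
  print(e)  # custom_call(custom_op_target): expected at least 1 input tensor.

# Two output shapes paired with a single output dtype: rejected by the
# shapes/dtypes size check.
x = torch.rand(10, device=torch_xla.device())
try:
  torch_xla._XLAC._xla_custom_call((x,), "custom_op_target", [[1], [1]],
                                   [torch.int8], False, "", 0, {})
except RuntimeError as e:
  print(e)  # custom_call(custom_op_target): expected the given output shapes
            # (size=2) to be of the same size as the given output dtypes
            # (size=1).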
48 changes: 18 additions & 30 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -347,21 +347,6 @@ std::vector<std::vector<int64_t>> CreateReduceGroups(const py::list& groups) {
   return replica_groups;
 }

-std::vector<at::Tensor> TpuCustomCall(
-    const std::vector<at::Tensor>& inputs, const std::string& payload,
-    const std::vector<std::vector<int64_t>>& output_shapes,
-    const std::vector<py::object>& output_dtypes) {
-  std::vector<at::ScalarType> dtypes;
-  dtypes.reserve(output_dtypes.size());
-  for (auto& dtype : output_dtypes) {
-    dtypes.push_back(reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
-  }
-  XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
-                      bridge::GetXlaTensors(inputs));
-  return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
-      xla_inputs, payload, output_shapes, dtypes));
-}
-
 std::vector<std::vector<int>> ExtractXlaDotGeneralDimVectors(
     const py::tuple& dimension_numbers) {
   // Expect Python arg `dimension_numbers` to be
@@ -3116,30 +3101,33 @@ void InitXlaModuleBindings(py::module m) {
"_xla_custom_call",
[](const std::vector<at::Tensor>& inputs, const std::string& target,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<py::object>& output_dtypes, bool has_side_effect,
const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
const std::string& backend_config, const int api_version,
const std::unordered_map<std::string, std::string>&
frontend_attributes) -> std::vector<at::Tensor> {
std::vector<at::ScalarType> dtypes;
dtypes.reserve(output_dtypes.size());
for (auto& dtype : output_dtypes) {
dtypes.push_back(
reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
}

XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs, bridge::GetXlaTensors(inputs));
auto xtensors = tensor_methods::custom_call(
xla_inputs, target,
output_shapes, dtypes, has_side_effect, backend_config,
api_version, frontend_attributes);
return bridge::AtenFromXlaTensors(std::move(xtensors));
XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
bridge::GetXlaTensors(inputs));
XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_outputs,
tensor_methods::custom_call(
xla_inputs, target, output_shapes, output_dtypes,
has_side_effect, backend_config, api_version,
frontend_attributes));

return bridge::AtenFromXlaTensors(std::move(xla_outputs));
})
.def("_xla_tpu_custom_call",
[](const std::vector<at::Tensor>& inputs, const std::string& payload,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<py::object>& output_dtypes)
const std::vector<at::ScalarType>& output_dtypes)
-> std::vector<at::Tensor> {
return TpuCustomCall(inputs, payload, output_shapes, output_dtypes);

XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
bridge::GetXlaTensors(inputs));
XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_outputs,
tensor_methods::tpu_custom_call(xla_inputs, payload, output_shapes, output_dtypes));

return bridge::AtenFromXlaTensors(std::move(xla_outputs));
})
.def("_xla_register_custom_call_target",
[](const std::string& fn_name, const py::capsule& function_ptr,
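Two consequences of the binding changes above are worth spelling out: `output_dtypes` is now `std::vector<at::ScalarType>`, so pybind11 converts the `torch.dtype` values passed from Python directly and the manual `THPDtype` extraction is gone; and since the `tensor_methods` functions now return `absl::StatusOr`, validation failures surface in Python as `RuntimeError` via `XLA_ASSIGN_OR_THROW`. A minimal sketch of the TPU-side behavior, mirroring the tests (assumes an XLA-capable environment):

import torch
import torch_xla

# dtypes are passed as plain torch.dtype values, e.g. torch.int8.
try:
  torch_xla._XLAC._xla_tpu_custom_call([], "", [[1]], [torch.int8])
except RuntimeError as e:
  print(e)  # tpu_custom_call(): expected at least 1 input tensor.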
193 changes: 115 additions & 78 deletions torch_xla/csrc/tensor_methods.cpp
@@ -299,27 +299,6 @@ absl::StatusOr<PoolNdInputsOwner> FillAndCheckPoolNdInputs(
   return PoolNdInputsOwner{kernel_size, stride, padding};
 }

-// Resizes and / or checks whether a list is of the given size. The list is only
-// resized if its size is 1. If it's empty, it's replaced with the provided
-// default first.
-std::vector<int64_t> CheckIntList(absl::Span<const int64_t> list, size_t length,
-                                  const std::string& name,
-                                  std::vector<int64_t> def = {}) {
-  std::vector<int64_t> result;
-  if (list.empty()) {
-    result = std::move(def);
-  } else {
-    result = torch::lazy::ToVector<int64_t>(list);
-  }
-  if (result.size() == 1 && length > 1) {
-    result.resize(length, result[0]);
-    return result;
-  }
-  XLA_CHECK_EQ(result.size(), length)
-      << "Invalid length for the '" << name << "' attribute";
-  return result;
-}
-
 // Returns a 1-D shape for batch norm weight or bias based on the input shape.
 xla::Shape BatchNormFeaturesShape(const XLATensorPtr& input) {
   xla::PrimitiveType input_element_type =
@@ -666,6 +645,92 @@ absl::Status CheckUniformRangeIsValid(double from, double to) {
   return absl::OkStatus();
 }

+// This check is used for both `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
+absl::Status CheckCustomCallNonEmptyInputs(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::optional<std::string>& target) {
+  if (inputs.empty()) {
+    std::string op = target.has_value()
+                         ? absl::StrCat("custom_call(", *target, ")")
+                         : "tpu_custom_call()";
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat(op, ": expected at least 1 input tensor.")));
+  }
+  return absl::OkStatus();
+}
+
+// This check is used for both `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
+absl::Status CheckCustomCallOutputPropertiesSize(
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes,
+    const std::optional<std::string>& target) {
+  if (output_shapes.size() != output_dtypes.size()) {
+    std::string op = target.has_value()
+                         ? absl::StrCat("custom_call(", *target, ")")
+                         : "tpu_custom_call()";
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        op, ": expected the given output shapes (size=", output_shapes.size(),
+        ") to be of the same size as the given output dtypes (size=",
+        output_dtypes.size(), ").")));
+  }
+  return absl::OkStatus();
+}
+
+// Common implementation shared by `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
+template <class F>
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> CustomCallImpl(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::optional<std::string>& target,
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes, F&& make_node) {
+  XLA_RETURN_IF_ERROR(CheckCustomCallNonEmptyInputs(inputs, target));
+  XLA_RETURN_IF_ERROR(CheckCustomCallOutputPropertiesSize(
+      output_shapes, output_dtypes, target));
+
+  const auto& first = inputs.front();
+  auto device = first->GetDevice();
+  auto output_range = c10::irange(output_shapes.size());
+
+  // `values`: vector with Lazy IR of `inputs`.
+  std::vector<torch::lazy::Value> values(inputs.size());
+  std::transform(
+      inputs.begin(), inputs.end(), values.begin(),
+      [](const XLATensorPtr& tensor) { return tensor->GetIrValue(); });
+
+  // `output_xla_shapes`: `xla::Shape` instances created from `output_shapes`
+  // and `output_dtypes`.
+  std::vector<xla::Shape> output_xla_shapes(output_shapes.size());
+  std::transform(output_range.begin(), output_range.end(),
+                 output_xla_shapes.begin(), [&](std::size_t i) {
+                   return xla::ShapeUtil::MakeShape(
+                       MakeXlaPrimitiveType(output_dtypes[i], &device),
+                       output_shapes[i]);
+                 });
+
+  auto node = make_node(values, output_xla_shapes);
+
+  // `outputs`: `XLATensorPtr` instances created from the `i`-th output of
+  // the `node` Lazy IR `Node`.
+  std::vector<XLATensorPtr> outputs(output_shapes.size());
+  std::transform(output_range.begin(), output_range.end(), outputs.begin(),
+                 [&](std::size_t i) {
+                   return first->CreateFrom(torch::lazy::Value(node, i),
+                                            output_dtypes[i],
+                                            /*delay_eager_execution=*/true);
+                 });
+
+  return outputs;
+}
+
 }  // namespace

 //////////////////////////////////////////////////////////////////////////////
@@ -886,40 +951,26 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
           torch::lazy::Value(node, 1)};
 }

-std::vector<XLATensorPtr> custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& target,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& target,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
     const std::string& backend_config, const int api_version,
     const std::unordered_map<std::string, std::string>& frontend_attributes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<CustomCall>(
-      values, target, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
-      has_side_effect, backend_config, api_version, frontend_attributes);
+  XLA_ASSIGN_OR_RETURN(
+      std::vector<absl_nonnull XLATensorPtr> outputs,
+      CustomCallImpl(inputs, target, output_shapes, output_dtypes,
+                     /* make_node= */
+                     [&](const std::vector<torch::lazy::Value>& values,
+                         const std::vector<xla::Shape>& output_xla_shapes) {
+                       return torch_xla::MakeNode<CustomCall>(
+                           values, target,
+                           xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
+                           has_side_effect, backend_config, api_version,
+                           frontend_attributes);
+                     }));

-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
     // Execute the HLO that will run the `customcall` and in one graph
@@ -954,37 +1005,23 @@ void custom_sharding_(
   input->SetShardingSpec(*sharding_spec);
 }

-std::vector<XLATensorPtr> tpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> tpu_custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<TpuCustomCall>(
-      values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), payload);
+  XLA_ASSIGN_OR_RETURN(
+      std::vector<absl_nonnull XLATensorPtr> outputs,
+      CustomCallImpl(
+          inputs, /* target= */ std::nullopt, output_shapes, output_dtypes,
+          /* make_node= */
+          [&](const std::vector<torch::lazy::Value>& values,
+              const std::vector<xla::Shape>& output_xla_shapes) {
+            return torch_xla::MakeNode<TpuCustomCall>(
+                values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
+                payload);
+          }));

-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
     // Execute the HLO that will run the `custom` and in one hlo
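The net effect of this refactor is that both entry points now delegate to `CustomCallImpl`, which validates its arguments and then pairs `output_shapes[i]` with `output_dtypes[i]` to build one `xla::Shape` per output. A short Python sketch of that pairing invariant (the values below are illustrative only):

import torch

output_shapes = [[4, 4], [4]]
output_dtypes = [torch.float32, torch.int32]

# CustomCallImpl requires the two lists to have the same length; output i
# is built from (output_shapes[i], output_dtypes[i]).
for shape, dtype in zip(output_shapes, output_dtypes, strict=True):
  print(f"output with shape={shape}, dtype={dtype}")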
10 changes: 6 additions & 4 deletions torch_xla/csrc/tensor_methods.h
@@ -92,8 +92,9 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
     const XLATensorPtr& input, const torch::lazy::Value& token,
     std::vector<std::pair<int64_t, int64_t>> source_target_pairs);

-std::vector<XLATensorPtr> custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& target,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& target,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
     const std::string& backend_config, const int api_version,
@@ -104,8 +105,9 @@ void custom_sharding_(
     const std::shared_ptr<XLATensor::ShardingSpec>& spec,
     const CustomSharding::Type& type = CustomSharding::Type::kSharding);

-std::vector<XLATensorPtr> tpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> tpu_custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes);
