From 07ca2391005471be6f1b8851b7a5fcd96cf91a66 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 5 Aug 2025 05:14:42 +0000 Subject: [PATCH 01/12] Add xla random generator. --- .github/scripts/run_tests.sh | 1 + BUILD | 7 ++- test/cpp/BUILD | 12 ++++ test/cpp/run_tests.sh | 1 + test/cpp/test_xla_generator.cpp | 103 +++++++++++++++++++++++++++++++ torch_xla/csrc/BUILD | 2 + torch_xla/csrc/xla_generator.cpp | 82 ++++++++++++++++++++++++ torch_xla/csrc/xla_generator.h | 56 +++++++++++++++++ 8 files changed, 261 insertions(+), 3 deletions(-) create mode 100644 test/cpp/test_xla_generator.cpp create mode 100644 torch_xla/csrc/xla_generator.cpp create mode 100644 torch_xla/csrc/xla_generator.h diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh index d685cc40ee49..ccdc0b5e3d70 100755 --- a/.github/scripts/run_tests.sh +++ b/.github/scripts/run_tests.sh @@ -55,6 +55,7 @@ function run_torch_xla_cpp_tests() { "test_tensor" # disable test_xla_backend_intf since it is flaky on upstream #"test_xla_backend_intf" + "test_xla_generator" "test_xla_sharding" "test_runtime" "test_status_dont_show_cpp_stacktraces" diff --git a/BUILD b/BUILD index ee4fa07844ac..900dfa4bc3b2 100644 --- a/BUILD +++ b/BUILD @@ -72,15 +72,16 @@ test_suite( "//test/cpp:test_aten_xla_tensor_4", "//test/cpp:test_aten_xla_tensor_5", "//test/cpp:test_aten_xla_tensor_6", + "//test/cpp:test_debug_macros", "//test/cpp:test_ir", "//test/cpp:test_lazy", "//test/cpp:test_replication", - "//test/cpp:test_tensor", - "//test/cpp:test_xla_sharding", "//test/cpp:test_runtime", "//test/cpp:test_status_dont_show_cpp_stacktraces", "//test/cpp:test_status_show_cpp_stacktraces", - "//test/cpp:test_debug_macros", + "//test/cpp:test_tensor", + "//test/cpp:test_xla_generator", + "//test/cpp:test_xla_sharding", "//torch_xla/csrc/runtime:pjrt_computation_client_test", # "//torch_xla/csrc/runtime:ifrt_computation_client_test", ], diff --git a/test/cpp/BUILD b/test/cpp/BUILD index e752eab4f670..dab678af767d 
100644 --- a/test/cpp/BUILD +++ b/test/cpp/BUILD @@ -202,3 +202,15 @@ ptxla_cc_test( "@com_google_googletest//:gtest_main", ], ) + +ptxla_cc_test( + name = "test_xla_generator", + srcs = ["test_xla_generator.cpp"], + deps = [ + ":cpp_test_util", + ":torch_xla_test", + "//torch_xla/csrc:tensor", + "//torch_xla/csrc:aten_cuda_functions", + "@com_google_googletest//:gtest_main", + ], +) \ No newline at end of file diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh index 8c3fea6bcdc8..2da0ccb55699 100755 --- a/test/cpp/run_tests.sh +++ b/test/cpp/run_tests.sh @@ -100,6 +100,7 @@ if [[ "$RUN_CPP_TESTS" == "cpp_tests" ]]; then # disable test_xla_backend_intf since it is flaky on upstream #"test_xla_backend_intf" "test_xla_sharding" + "test_xla_generator" "test_runtime" "test_status_dont_show_cpp_stacktraces" "test_status_show_cpp_stacktraces" diff --git a/test/cpp/test_xla_generator.cpp b/test/cpp/test_xla_generator.cpp new file mode 100644 index 000000000000..687f7d4eea93 --- /dev/null +++ b/test/cpp/test_xla_generator.cpp @@ -0,0 +1,103 @@ +#include +#include +#include "test/cpp/torch_xla_test.h" +#include "torch_xla/csrc/xla_generator.h" + +namespace torch_xla { +namespace cpp_test { + +// Test fixture for XLAGenerator tests +class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest { + protected: + void SetUp() { + // Create a generator for XLA device 0 + gen_ = at::make_generator(0); + } + + at::Generator gen_; +}; + +TEST_F(XLAGeneratorTest, Constructor) { + // Check that the generator was created for the correct device + ASSERT_EQ(gen_.device().type(), at::DeviceType::XLA); + ASSERT_EQ(gen_.device().index(), 0); + + // Check that the initial seed is 0 + ASSERT_EQ(gen_.current_seed(), 0); +} + +TEST_F(XLAGeneratorTest, Seed) { + // Test setting and getting the current seed + uint64_t seed_val = 12345; + gen_.set_current_seed(seed_val); + ASSERT_EQ(gen_.current_seed(), seed_val); + + // Test the seed() method, which should set a non-deterministic 
seed + uint64_t old_seed = gen_.current_seed(); + uint64_t new_seed = gen_.seed(); + // The new seed should be different from the old one and set as the current seed + ASSERT_NE(new_seed, old_seed); + ASSERT_EQ(gen_.current_seed(), new_seed); +} + +TEST_F(XLAGeneratorTest, GetAndSetState) { + uint64_t seed_val = 98765; + uint64_t offset_val = 0; + + // Set seed and offset on the original generator + gen_.set_current_seed(seed_val); + gen_.set_offset(offset_val); + + // Get the state from the original generator + at::Tensor state_tensor = gen_.get_state(); + + // Create a new generator + auto new_gen = at::make_generator(1); + ASSERT_NE(new_gen.current_seed(), seed_val); + + // Set the state of the new generator + new_gen.set_state(state_tensor); + + // Verify the state of the new generator + ASSERT_EQ(new_gen.current_seed(), seed_val); + ASSERT_EQ(new_gen.get_offset(), offset_val); +} + +TEST_F(XLAGeneratorTest, SetStateValidation) { + // Test that set_state throws with incorrect tensor properties + auto new_gen = at::make_generator(0); + + // Incorrect size + auto wrong_size_tensor = at::empty({10}, at::kByte); + EXPECT_THROW(new_gen.set_state(wrong_size_tensor), c10::Error); + + // Incorrect dtype + auto wrong_dtype_tensor = at::empty({16}, at::kInt); + EXPECT_THROW(new_gen.set_state(wrong_dtype_tensor), c10::Error); +} + +TEST_F(XLAGeneratorTest, Clone) { + uint64_t seed_val = 1; + uint64_t offset_val = 0; + + // Set state on the original generator + gen_.set_current_seed(seed_val); + gen_.set_offset(offset_val); + + // Clone the generator + auto cloned_gen = gen_.clone(); + + // Verify that the cloned generator has the same state but is a different object + ASSERT_NE(std::addressof(cloned_gen), std::addressof(gen_)); + ASSERT_EQ(cloned_gen.device(), gen_.device()); + ASSERT_EQ(cloned_gen.current_seed(), gen_.current_seed()); + ASSERT_EQ(cloned_gen.get_offset(), offset_val); + + // Modify the original generator's seed and check that the clone is unaffected + 
gen_.set_current_seed(9999); + ASSERT_EQ(cloned_gen.current_seed(), seed_val); + ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed()); +} + +} // namespace cpp_test +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 31ab65dbbcaf..a871feaa346e 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -64,6 +64,7 @@ ptxla_cc_library( "torch_util.cpp", "view.cpp", "xla_backend_impl.cpp", + "xla_generator.cpp", "xla_graph_executor.cpp", "xla_lower_util.cpp", "xla_op_builder.cpp", @@ -107,6 +108,7 @@ ptxla_cc_library( "torch_util.h", "view.h", "xla_backend_impl.h", + "xla_generator.h", "xla_graph_executor.h", "xla_lower_util.h", "xla_op_builder.h", diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp new file mode 100644 index 000000000000..492f086bdceb --- /dev/null +++ b/torch_xla/csrc/xla_generator.cpp @@ -0,0 +1,82 @@ +#include "xla_generator.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index) + : c10::GeneratorImpl{Device(DeviceType::XLA, device_index), DispatchKeySet(c10::DispatchKey::XLA)} { + state_ = c10::make_intrusive(); +} + +XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index, c10::intrusive_ptr state) + : c10::GeneratorImpl{Device(DeviceType::XLA, device_index), DispatchKeySet(c10::DispatchKey::XLA)}, state_(std::move(state)) {} + +DeviceType XLAGeneratorImpl::device_type() { + return DeviceType::XLA; +} + +std::shared_ptr XLAGeneratorImpl::clone() const { + return std::shared_ptr(clone_impl()); +} + +XLAGeneratorImpl* XLAGeneratorImpl::clone_impl() const { + return new XLAGeneratorImpl(device_.index(), state_->clone()); +} + +void XLAGeneratorImpl::set_current_seed(uint64_t seed) { + state_->seed_ = seed; +} + +uint64_t XLAGeneratorImpl::current_seed() const { + return state_->seed_; +} + +uint64_t XLAGeneratorImpl::seed() { + 
uint64_t random = c10::detail::getNonDeterministicRandom(true); + set_current_seed(random); + return random; +} + +void XLAGeneratorImpl::set_offset(uint64_t offset) { + state_->offset_ = offset; +} + +uint64_t XLAGeneratorImpl::get_offset() const { + return state_->offset_; +} + +/* Serialize the generator state into a CPU tensor. */ +c10::intrusive_ptr XLAGeneratorImpl::get_state() const { + static const size_t seed_size = sizeof(uint64_t); + static const size_t offset_size = sizeof(uint64_t); + static const size_t total_size = seed_size + offset_size; + + auto state_tensor = at::empty({(int64_t)total_size}, at::TensorOptions().dtype(at::kByte).device(at::kCPU)); + uint8_t* data_ptr = state_tensor.data_ptr(); + memcpy(data_ptr, &state_->seed_, seed_size); + memcpy(data_ptr + seed_size, &state_->offset_, offset_size); + return state_tensor.getIntrusivePtr(); +} + +void XLAGeneratorImpl::set_state(const c10::TensorImpl& new_state) { + static const size_t seed_size = sizeof(uint64_t); + static const size_t offset_size = sizeof(uint64_t); + static const size_t total_size = seed_size + offset_size; + + TORCH_CHECK(new_state.numel() == total_size, "The given state must be a byte tensor of size ", total_size, ", but was size ", new_state.numel()); + TORCH_CHECK(new_state.dtype() == at::kByte, "The given state must be a byte tensor, but was ", new_state.dtype()); + TORCH_CHECK(new_state.is_cpu(), "The given state must be a CPU tensor"); + + auto new_rng_state = new_state.data_dtype_initialized(); + memcpy(&state_->seed_, new_rng_state, seed_size); + memcpy(&state_->offset_, new_rng_state + seed_size, offset_size); +} + +} // namespace at diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h new file mode 100644 index 000000000000..b8b7dc46e9ac --- /dev/null +++ b/torch_xla/csrc/xla_generator.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include +#include + +#include + +namespace at { + +// Holds the actual state variables for the XLA generator. 
+struct XLAGeneratorState : c10::intrusive_ptr_target { + uint64_t seed_ = 0; + uint64_t offset_ = 0; + + // Constructor + XLAGeneratorState(uint64_t seed = 0, uint64_t offset = 0) + : seed_(seed), offset_(offset) {} + + // Cloning method + c10::intrusive_ptr clone() { + return c10::make_intrusive(seed_, offset_); + } +}; + +struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl { + // Constructors + XLAGeneratorImpl(DeviceIndex device_index = -1); + XLAGeneratorImpl(DeviceIndex device_index, c10::intrusive_ptr state); + ~XLAGeneratorImpl() override = default; + + // Cloning support + std::shared_ptr clone() const; + + // --- Core Virtual Methods to Override --- + void set_current_seed(uint64_t seed) override; + uint64_t current_seed() const override; + uint64_t seed() override; + void set_offset(uint64_t offset) override; + uint64_t get_offset() const override; + c10::intrusive_ptr get_state() const override; + void set_state(const c10::TensorImpl& new_state) override; + + // --- Additional Methods --- + static c10::DeviceType device_type(); + + private: + // Private clone implementation + XLAGeneratorImpl* clone_impl() const override; + + // The actual state is held in a separate, cloneable object. 
+ c10::intrusive_ptr state_; + +}; + +} // namespace at \ No newline at end of file From 1d67a02d76f9e3931d5a04d3a25b97a0bb13c7d4 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 13 Aug 2025 00:13:27 +0000 Subject: [PATCH 02/12] format cpp files --- test/cpp/test_xla_generator.cpp | 131 ++++++++++++++++--------------- torch_xla/csrc/xla_generator.cpp | 52 ++++++------ torch_xla/csrc/xla_generator.h | 4 +- 3 files changed, 96 insertions(+), 91 deletions(-) diff --git a/test/cpp/test_xla_generator.cpp b/test/cpp/test_xla_generator.cpp index 687f7d4eea93..d45991f72d39 100644 --- a/test/cpp/test_xla_generator.cpp +++ b/test/cpp/test_xla_generator.cpp @@ -1,5 +1,6 @@ #include #include + #include "test/cpp/torch_xla_test.h" #include "torch_xla/csrc/xla_generator.h" @@ -9,94 +10,96 @@ namespace cpp_test { // Test fixture for XLAGenerator tests class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest { protected: - void SetUp() { - // Create a generator for XLA device 0 - gen_ = at::make_generator(0); - } + void SetUp() { + // Create a generator for XLA device 0 + gen_ = at::make_generator(0); + } - at::Generator gen_; + at::Generator gen_; }; TEST_F(XLAGeneratorTest, Constructor) { - // Check that the generator was created for the correct device - ASSERT_EQ(gen_.device().type(), at::DeviceType::XLA); - ASSERT_EQ(gen_.device().index(), 0); + // Check that the generator was created for the correct device + ASSERT_EQ(gen_.device().type(), at::DeviceType::XLA); + ASSERT_EQ(gen_.device().index(), 0); - // Check that the initial seed is 0 - ASSERT_EQ(gen_.current_seed(), 0); + // Check that the initial seed is 0 + ASSERT_EQ(gen_.current_seed(), 0); } TEST_F(XLAGeneratorTest, Seed) { - // Test setting and getting the current seed - uint64_t seed_val = 12345; - gen_.set_current_seed(seed_val); - ASSERT_EQ(gen_.current_seed(), seed_val); - - // Test the seed() method, which should set a non-deterministic seed - uint64_t old_seed = gen_.current_seed(); - uint64_t 
new_seed = gen_.seed(); - // The new seed should be different from the old one and set as the current seed - ASSERT_NE(new_seed, old_seed); - ASSERT_EQ(gen_.current_seed(), new_seed); + // Test setting and getting the current seed + uint64_t seed_val = 12345; + gen_.set_current_seed(seed_val); + ASSERT_EQ(gen_.current_seed(), seed_val); + + // Test the seed() method, which should set a non-deterministic seed + uint64_t old_seed = gen_.current_seed(); + uint64_t new_seed = gen_.seed(); + // The new seed should be different from the old one and set as the current + // seed + ASSERT_NE(new_seed, old_seed); + ASSERT_EQ(gen_.current_seed(), new_seed); } TEST_F(XLAGeneratorTest, GetAndSetState) { - uint64_t seed_val = 98765; - uint64_t offset_val = 0; + uint64_t seed_val = 98765; + uint64_t offset_val = 0; - // Set seed and offset on the original generator - gen_.set_current_seed(seed_val); - gen_.set_offset(offset_val); + // Set seed and offset on the original generator + gen_.set_current_seed(seed_val); + gen_.set_offset(offset_val); - // Get the state from the original generator - at::Tensor state_tensor = gen_.get_state(); + // Get the state from the original generator + at::Tensor state_tensor = gen_.get_state(); - // Create a new generator - auto new_gen = at::make_generator(1); - ASSERT_NE(new_gen.current_seed(), seed_val); + // Create a new generator + auto new_gen = at::make_generator(1); + ASSERT_NE(new_gen.current_seed(), seed_val); - // Set the state of the new generator - new_gen.set_state(state_tensor); + // Set the state of the new generator + new_gen.set_state(state_tensor); - // Verify the state of the new generator - ASSERT_EQ(new_gen.current_seed(), seed_val); - ASSERT_EQ(new_gen.get_offset(), offset_val); + // Verify the state of the new generator + ASSERT_EQ(new_gen.current_seed(), seed_val); + ASSERT_EQ(new_gen.get_offset(), offset_val); } TEST_F(XLAGeneratorTest, SetStateValidation) { - // Test that set_state throws with incorrect tensor properties 
- auto new_gen = at::make_generator(0); + // Test that set_state throws with incorrect tensor properties + auto new_gen = at::make_generator(0); - // Incorrect size - auto wrong_size_tensor = at::empty({10}, at::kByte); - EXPECT_THROW(new_gen.set_state(wrong_size_tensor), c10::Error); + // Incorrect size + auto wrong_size_tensor = at::empty({10}, at::kByte); + EXPECT_THROW(new_gen.set_state(wrong_size_tensor), c10::Error); - // Incorrect dtype - auto wrong_dtype_tensor = at::empty({16}, at::kInt); - EXPECT_THROW(new_gen.set_state(wrong_dtype_tensor), c10::Error); + // Incorrect dtype + auto wrong_dtype_tensor = at::empty({16}, at::kInt); + EXPECT_THROW(new_gen.set_state(wrong_dtype_tensor), c10::Error); } TEST_F(XLAGeneratorTest, Clone) { - uint64_t seed_val = 1; - uint64_t offset_val = 0; - - // Set state on the original generator - gen_.set_current_seed(seed_val); - gen_.set_offset(offset_val); - - // Clone the generator - auto cloned_gen = gen_.clone(); - - // Verify that the cloned generator has the same state but is a different object - ASSERT_NE(std::addressof(cloned_gen), std::addressof(gen_)); - ASSERT_EQ(cloned_gen.device(), gen_.device()); - ASSERT_EQ(cloned_gen.current_seed(), gen_.current_seed()); - ASSERT_EQ(cloned_gen.get_offset(), offset_val); - - // Modify the original generator's seed and check that the clone is unaffected - gen_.set_current_seed(9999); - ASSERT_EQ(cloned_gen.current_seed(), seed_val); - ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed()); + uint64_t seed_val = 1; + uint64_t offset_val = 0; + + // Set state on the original generator + gen_.set_current_seed(seed_val); + gen_.set_offset(offset_val); + + // Clone the generator + auto cloned_gen = gen_.clone(); + + // Verify that the cloned generator has the same state but is a different + // object + ASSERT_NE(std::addressof(cloned_gen), std::addressof(gen_)); + ASSERT_EQ(cloned_gen.device(), gen_.device()); + ASSERT_EQ(cloned_gen.current_seed(), gen_.current_seed()); + 
ASSERT_EQ(cloned_gen.get_offset(), offset_val); + + // Modify the original generator's seed and check that the clone is unaffected + gen_.set_current_seed(9999); + ASSERT_EQ(cloned_gen.current_seed(), seed_val); + ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed()); } } // namespace cpp_test diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp index 492f086bdceb..5d0a7c15866b 100644 --- a/torch_xla/csrc/xla_generator.cpp +++ b/torch_xla/csrc/xla_generator.cpp @@ -1,26 +1,30 @@ #include "xla_generator.h" + +#include #include #include -#include -#include -#include #include +#include #include +#include + #include namespace at { XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index) - : c10::GeneratorImpl{Device(DeviceType::XLA, device_index), DispatchKeySet(c10::DispatchKey::XLA)} { + : c10::GeneratorImpl{Device(DeviceType::XLA, device_index), + DispatchKeySet(c10::DispatchKey::XLA)} { state_ = c10::make_intrusive(); } -XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index, c10::intrusive_ptr state) - : c10::GeneratorImpl{Device(DeviceType::XLA, device_index), DispatchKeySet(c10::DispatchKey::XLA)}, state_(std::move(state)) {} +XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index, + c10::intrusive_ptr state) + : c10::GeneratorImpl{Device(DeviceType::XLA, device_index), + DispatchKeySet(c10::DispatchKey::XLA)}, + state_(std::move(state)) {} -DeviceType XLAGeneratorImpl::device_type() { - return DeviceType::XLA; -} +DeviceType XLAGeneratorImpl::device_type() { return DeviceType::XLA; } std::shared_ptr XLAGeneratorImpl::clone() const { return std::shared_ptr(clone_impl()); @@ -30,13 +34,9 @@ XLAGeneratorImpl* XLAGeneratorImpl::clone_impl() const { return new XLAGeneratorImpl(device_.index(), state_->clone()); } -void XLAGeneratorImpl::set_current_seed(uint64_t seed) { - state_->seed_ = seed; -} +void XLAGeneratorImpl::set_current_seed(uint64_t seed) { state_->seed_ = seed; } -uint64_t XLAGeneratorImpl::current_seed() 
const { - return state_->seed_; -} +uint64_t XLAGeneratorImpl::current_seed() const { return state_->seed_; } uint64_t XLAGeneratorImpl::seed() { uint64_t random = c10::detail::getNonDeterministicRandom(true); @@ -44,13 +44,9 @@ uint64_t XLAGeneratorImpl::seed() { return random; } -void XLAGeneratorImpl::set_offset(uint64_t offset) { - state_->offset_ = offset; -} +void XLAGeneratorImpl::set_offset(uint64_t offset) { state_->offset_ = offset; } -uint64_t XLAGeneratorImpl::get_offset() const { - return state_->offset_; -} +uint64_t XLAGeneratorImpl::get_offset() const { return state_->offset_; } /* Serialize the generator state into a CPU tensor. */ c10::intrusive_ptr XLAGeneratorImpl::get_state() const { @@ -58,7 +54,9 @@ c10::intrusive_ptr XLAGeneratorImpl::get_state() const { static const size_t offset_size = sizeof(uint64_t); static const size_t total_size = seed_size + offset_size; - auto state_tensor = at::empty({(int64_t)total_size}, at::TensorOptions().dtype(at::kByte).device(at::kCPU)); + auto state_tensor = + at::empty({(int64_t)total_size}, + at::TensorOptions().dtype(at::kByte).device(at::kCPU)); uint8_t* data_ptr = state_tensor.data_ptr(); memcpy(data_ptr, &state_->seed_, seed_size); memcpy(data_ptr + seed_size, &state_->offset_, offset_size); @@ -70,8 +68,12 @@ void XLAGeneratorImpl::set_state(const c10::TensorImpl& new_state) { static const size_t offset_size = sizeof(uint64_t); static const size_t total_size = seed_size + offset_size; - TORCH_CHECK(new_state.numel() == total_size, "The given state must be a byte tensor of size ", total_size, ", but was size ", new_state.numel()); - TORCH_CHECK(new_state.dtype() == at::kByte, "The given state must be a byte tensor, but was ", new_state.dtype()); + TORCH_CHECK(new_state.numel() == total_size, + "The given state must be a byte tensor of size ", total_size, + ", but was size ", new_state.numel()); + TORCH_CHECK(new_state.dtype() == at::kByte, + "The given state must be a byte tensor, but was ", + 
new_state.dtype()); TORCH_CHECK(new_state.is_cpu(), "The given state must be a CPU tensor"); auto new_rng_state = new_state.data_dtype_initialized(); @@ -79,4 +81,4 @@ void XLAGeneratorImpl::set_state(const c10::TensorImpl& new_state) { memcpy(&state_->offset_, new_rng_state + seed_size, offset_size); } -} // namespace at +} // namespace at diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h index b8b7dc46e9ac..330d32861200 100644 --- a/torch_xla/csrc/xla_generator.h +++ b/torch_xla/csrc/xla_generator.h @@ -26,7 +26,8 @@ struct XLAGeneratorState : c10::intrusive_ptr_target { struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl { // Constructors XLAGeneratorImpl(DeviceIndex device_index = -1); - XLAGeneratorImpl(DeviceIndex device_index, c10::intrusive_ptr state); + XLAGeneratorImpl(DeviceIndex device_index, + c10::intrusive_ptr state); ~XLAGeneratorImpl() override = default; // Cloning support @@ -50,7 +51,6 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl { // The actual state is held in a separate, cloneable object. 
c10::intrusive_ptr state_; - }; } // namespace at \ No newline at end of file From 79ff99bce64b34b0e7d6d52b4d1998c17cd44567 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 20 Oct 2025 03:15:34 +0000 Subject: [PATCH 03/12] Add helper functions `getDefaultXLAGenerator` and `createXLAGenerator` to XLA random number generator --- torch_xla/csrc/xla_generator.cpp | 89 ++++++++++++++++++++++++++++++++ torch_xla/csrc/xla_generator.h | 11 +++- 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp index 5d0a7c15866b..5102be5df4ee 100644 --- a/torch_xla/csrc/xla_generator.cpp +++ b/torch_xla/csrc/xla_generator.cpp @@ -6,9 +6,98 @@ #include #include #include +#include #include +#include + +// XLA headers +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/aten_xla_bridge.h" #include +#include +#include + +namespace at { + +namespace detail { + +namespace { + +// Total number of XLA devices in the system. +static int64_t num_xla_devices; + +// Ensures default_gens_xla is initialized once. +static std::deque xla_gens_init_flag; + +// Default, global XLA generators, one per XLA device. +static std::vector default_gens_xla; + +/* + * Populates the global variables related to XLA generators + * Warning: this function must only be called once! + */ +static void initXLAGenVector() { + // Ensures we only call deviceCount only once. 
+ static bool num_xla_device_init_flag [[maybe_unused]] = []() { + // Get local num of XLA devices + auto maybe_client = torch_xla::runtime::GetComputationClient(); + if (!maybe_client.ok()) { + // If runtime client initialization failed, default to 1 device + num_xla_devices = 1; + } else { + auto* client = maybe_client.value(); + num_xla_devices = static_cast(client->GetNumDevices()); + } + xla_gens_init_flag.resize(num_xla_devices); + default_gens_xla.resize(num_xla_devices); + return true; + }(); +} + +} // anonymous namespace + +/** + * PyTorch maintains a collection of default generators that get + * initialized once. The purpose of these default generators is to + * maintain a global running state of the pseudo random number generation, + * when a user does not explicitly mention any generator. + * getDefaultXLAGenerator gets the default generator for a particular + * XLA device. + */ +const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) { + initXLAGenVector(); + c10::DeviceIndex idx = device_index; + if (idx == -1) { + idx = 0; // Default to device 0 for XLA + } else { + TORCH_CHECK(idx >= 0 && idx < num_xla_devices); + } + c10::call_once(xla_gens_init_flag[idx], [&] { + default_gens_xla[idx] = at::make_generator(idx); + default_gens_xla[idx].seed(); + }); + return default_gens_xla[idx]; +} + +/** + * Utility to create a XLAGeneratorImpl. 
Returns a shared_ptr + */ +at::Generator createXLAGenerator(c10::DeviceIndex device_index) { + initXLAGenVector(); + c10::DeviceIndex idx = device_index; + if (idx == -1) { + idx = torch_xla::bridge::GetCurrentAtenDevice().index(); // Use current XLA device + } + TORCH_CHECK(idx >= 0 && idx < num_xla_devices, "The device_index is invalid."); + auto gen = at::make_generator(idx); + auto xla_gen = at::check_generator(gen); + xla_gen->set_current_seed(c10::default_rng_seed_val); + return gen; +} + +} // namespace detail +} // namespace at namespace at { diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h index 330d32861200..62621f7c37c9 100644 --- a/torch_xla/csrc/xla_generator.h +++ b/torch_xla/csrc/xla_generator.h @@ -2,6 +2,8 @@ #include #include +#include +#include #include #include @@ -53,4 +55,11 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl { c10::intrusive_ptr state_; }; -} // namespace at \ No newline at end of file +namespace detail { + +const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index = -1); +at::Generator createXLAGenerator(c10::DeviceIndex device_index = -1); + +} // namespace detail + +} // namespace at From 79d4b4236aebd2f4b7624237015f30ca4be4b15f Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 20 Oct 2025 03:17:51 +0000 Subject: [PATCH 04/12] implement `XLAHooks` and register it to PyTorch when loaded. 
--- torch_xla/csrc/BUILD | 19 +++++++ torch_xla/csrc/xla_hooks.cpp | 99 ++++++++++++++++++++++++++++++++++++ torch_xla/csrc/xla_hooks.h | 40 +++++++++++++++ 3 files changed, 158 insertions(+) create mode 100644 torch_xla/csrc/xla_hooks.cpp create mode 100644 torch_xla/csrc/xla_hooks.h diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 8132ae733160..9c8455be184a 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -270,6 +270,7 @@ ptxla_cc_library( ":status", ":tensor", ":version", + ":xla_hooks", "//torch_xla/csrc/runtime", "//torch_xla/csrc/runtime:pjrt_computation_client", "//torch_xla/csrc/runtime:metrics", @@ -374,3 +375,21 @@ cc_library( "@com_google_absl//absl/status:statusor", ], ) + +ptxla_cc_library( + name = "xla_hooks", + srcs = [ + "xla_hooks.cpp", + ], + hdrs = [ + "xla_hooks.h", + ], + deps = [ + "//torch_xla/csrc:device", + "//torch_xla/csrc:tensor", + "//torch_xla/csrc/runtime:computation_client", + "//torch_xla/csrc/runtime", + "//torch_xla/csrc/runtime:xla_util", + ], +) + diff --git a/torch_xla/csrc/xla_hooks.cpp b/torch_xla/csrc/xla_hooks.cpp new file mode 100644 index 000000000000..257dc7677fef --- /dev/null +++ b/torch_xla/csrc/xla_hooks.cpp @@ -0,0 +1,99 @@ +#include "xla_hooks.h" + +#include +#include + +// PyTorch integration headers +#include +#include +#include +#include +#include +#include +#include + +// XLA headers +#include "xla_generator.h" +#include "xla_backend_impl.h" +#include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/runtime.h" + + +namespace torch_xla::detail { + +void XLAHooks::init() const { + C10_LOG_API_USAGE_ONCE("aten.init.xla"); + + // Initialize XLA backend - this registers XLA functions and sets up + // the backend infrastructure + torch_xla::InitXlaBackend(); +} + +bool XLAHooks::hasXLA() const { + return isAvailable(); +} + +bool XLAHooks::isAvailable() const { + 
try { + return deviceCount() > 0; + } catch (...) { + // If device enumeration fails, XLA is not available + return false; + } +} + +std::string XLAHooks::showConfig() const { + std::ostringstream oss; + oss << "XLA Backend Configuration:\n"; + oss << " - XLA devices available: " << deviceCount() << "\n"; + return oss.str(); +} + +c10::DeviceIndex XLAHooks::deviceCount() const { + auto maybe_client = torch_xla::runtime::GetComputationClient(); + if (!maybe_client.ok()) { + // If runtime client initialization failed, return 0 devices + return 0; + } + + auto* client = maybe_client.value(); + return static_cast(client->GetNumDevices()); +} + +c10::DeviceIndex XLAHooks::getCurrentDevice() const { + return bridge::GetCurrentAtenDevice().index(); +} + +bool XLAHooks::hasPrimaryContext(c10::DeviceIndex device_index) const { + TORCH_CHECK(false, "hasPrimaryContext is not implemented."); +} + +bool XLAHooks::isPinnedPtr(const void* data) const { + TORCH_CHECK(false, "isPinnedPtr is not implemented."); +} + +c10::Allocator* XLAHooks::getPinnedMemoryAllocator() const { + TORCH_CHECK(false, "getPinnedMemoryAllocator is not implemented."); +} + +c10::Device XLAHooks::getDeviceFromPtr(void* data) const { + TORCH_CHECK(false, "getDeviceFromPtr is not implemented."); +} + +const at::Generator& XLAHooks::getDefaultGenerator(c10::DeviceIndex device_index) const { + return at::detail::getDefaultXLAGenerator(device_index); +} + +at::Generator XLAHooks::getNewGenerator(c10::DeviceIndex device_index) const { + // Create and return a new XLA generator using the make_generator template function + return at::make_generator(device_index); +} + +} // namespace torch_xla::detail + +// Register XLA hooks with PyTorch on module load +namespace at { +REGISTER_XLA_HOOKS(torch_xla::detail::XLAHooks) +} // namespace at diff --git a/torch_xla/csrc/xla_hooks.h b/torch_xla/csrc/xla_hooks.h new file mode 100644 index 000000000000..f56c039a8a95 --- /dev/null +++ b/torch_xla/csrc/xla_hooks.h @@ -0,0 
+1,40 @@ +#pragma once + +#include + +// PyTorch integration headers +#include +#include +#include +#include +#include + +namespace torch_xla::detail { + +// XLA hooks implementation following PyTorch patterns +struct XLAHooks : public at::XLAHooksInterface { + XLAHooks(const at::XLAHooksArgs& args) {} + + // Core accelerator interface methods + void init() const override; + bool hasXLA() const override; + bool isAvailable() const override; + bool isBuilt() const override { return true; } + std::string showConfig() const override; + + // Device management + c10::DeviceIndex deviceCount() const override; + c10::DeviceIndex getCurrentDevice() const override; + bool hasPrimaryContext(c10::DeviceIndex device_index) const override; + + // Memory management + bool isPinnedPtr(const void* data) const override; + c10::Allocator* getPinnedMemoryAllocator() const override; + c10::Device getDeviceFromPtr(void* data) const override; + + // Generator methods + const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index = -1) const override; + at::Generator getNewGenerator(c10::DeviceIndex device_index = -1) const override; +}; + +} // namespace torch_xla::detail From 914d708989aac940caa5a8b6c5ce15a8705f82bc Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 20 Oct 2025 03:39:41 +0000 Subject: [PATCH 05/12] format --- torch_xla/csrc/xla_generator.cpp | 24 +++++++++++++----------- torch_xla/csrc/xla_generator.h | 2 +- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp index 5102be5df4ee..56aa2fe3bf49 100644 --- a/torch_xla/csrc/xla_generator.cpp +++ b/torch_xla/csrc/xla_generator.cpp @@ -5,19 +5,19 @@ #include #include #include -#include #include -#include +#include #include +#include // XLA headers -#include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/aten_xla_bridge.h" - #include #include #include +#include "torch_xla/csrc/aten_xla_bridge.h" +#include 
"torch_xla/csrc/runtime/computation_client.h" + namespace at { namespace detail { @@ -55,7 +55,7 @@ static void initXLAGenVector() { }(); } -} // anonymous namespace +} // anonymous namespace /** * PyTorch maintains a collection of default generators that get @@ -69,7 +69,7 @@ const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) { initXLAGenVector(); c10::DeviceIndex idx = device_index; if (idx == -1) { - idx = 0; // Default to device 0 for XLA + idx = 0; // Default to device 0 for XLA } else { TORCH_CHECK(idx >= 0 && idx < num_xla_devices); } @@ -87,17 +87,19 @@ at::Generator createXLAGenerator(c10::DeviceIndex device_index) { initXLAGenVector(); c10::DeviceIndex idx = device_index; if (idx == -1) { - idx = torch_xla::bridge::GetCurrentAtenDevice().index(); // Use current XLA device + idx = torch_xla::bridge::GetCurrentAtenDevice() + .index(); // Use current XLA device } - TORCH_CHECK(idx >= 0 && idx < num_xla_devices, "The device_index is invalid."); + TORCH_CHECK(idx >= 0 && idx < num_xla_devices, + "The device_index is invalid."); auto gen = at::make_generator(idx); auto xla_gen = at::check_generator(gen); xla_gen->set_current_seed(c10::default_rng_seed_val); return gen; } -} // namespace detail -} // namespace at +} // namespace detail +} // namespace at namespace at { diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h index 62621f7c37c9..0d0173157dfd 100644 --- a/torch_xla/csrc/xla_generator.h +++ b/torch_xla/csrc/xla_generator.h @@ -60,6 +60,6 @@ namespace detail { const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index = -1); at::Generator createXLAGenerator(c10::DeviceIndex device_index = -1); -} // namespace detail +} // namespace detail } // namespace at From 1224531fd0aad36177cc3b765a840cc4bf39ed2c Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 20 Oct 2025 03:42:57 +0000 Subject: [PATCH 06/12] Revert "implement `XLAHooks` and register it to PyTorch when loaded." 
This reverts commit 79d4b4236aebd2f4b7624237015f30ca4be4b15f. --- torch_xla/csrc/BUILD | 19 ------- torch_xla/csrc/xla_hooks.cpp | 99 ------------------------------------ torch_xla/csrc/xla_hooks.h | 40 --------------- 3 files changed, 158 deletions(-) delete mode 100644 torch_xla/csrc/xla_hooks.cpp delete mode 100644 torch_xla/csrc/xla_hooks.h diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 9c8455be184a..8132ae733160 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -270,7 +270,6 @@ ptxla_cc_library( ":status", ":tensor", ":version", - ":xla_hooks", "//torch_xla/csrc/runtime", "//torch_xla/csrc/runtime:pjrt_computation_client", "//torch_xla/csrc/runtime:metrics", @@ -375,21 +374,3 @@ cc_library( "@com_google_absl//absl/status:statusor", ], ) - -ptxla_cc_library( - name = "xla_hooks", - srcs = [ - "xla_hooks.cpp", - ], - hdrs = [ - "xla_hooks.h", - ], - deps = [ - "//torch_xla/csrc:device", - "//torch_xla/csrc:tensor", - "//torch_xla/csrc/runtime:computation_client", - "//torch_xla/csrc/runtime", - "//torch_xla/csrc/runtime:xla_util", - ], -) - diff --git a/torch_xla/csrc/xla_hooks.cpp b/torch_xla/csrc/xla_hooks.cpp deleted file mode 100644 index 257dc7677fef..000000000000 --- a/torch_xla/csrc/xla_hooks.cpp +++ /dev/null @@ -1,99 +0,0 @@ -#include "xla_hooks.h" - -#include -#include - -// PyTorch integration headers -#include -#include -#include -#include -#include -#include -#include - -// XLA headers -#include "xla_generator.h" -#include "xla_backend_impl.h" -#include "torch_xla/csrc/aten_xla_bridge.h" -#include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/runtime.h" - - -namespace torch_xla::detail { - -void XLAHooks::init() const { - C10_LOG_API_USAGE_ONCE("aten.init.xla"); - - // Initialize XLA backend - this registers XLA functions and sets up - // the backend infrastructure - torch_xla::InitXlaBackend(); -} - -bool XLAHooks::hasXLA() const { - 
return isAvailable(); -} - -bool XLAHooks::isAvailable() const { - try { - return deviceCount() > 0; - } catch (...) { - // If device enumeration fails, XLA is not available - return false; - } -} - -std::string XLAHooks::showConfig() const { - std::ostringstream oss; - oss << "XLA Backend Configuration:\n"; - oss << " - XLA devices available: " << deviceCount() << "\n"; - return oss.str(); -} - -c10::DeviceIndex XLAHooks::deviceCount() const { - auto maybe_client = torch_xla::runtime::GetComputationClient(); - if (!maybe_client.ok()) { - // If runtime client initialization failed, return 0 devices - return 0; - } - - auto* client = maybe_client.value(); - return static_cast(client->GetNumDevices()); -} - -c10::DeviceIndex XLAHooks::getCurrentDevice() const { - return bridge::GetCurrentAtenDevice().index(); -} - -bool XLAHooks::hasPrimaryContext(c10::DeviceIndex device_index) const { - TORCH_CHECK(false, "hasPrimaryContext is not implemented."); -} - -bool XLAHooks::isPinnedPtr(const void* data) const { - TORCH_CHECK(false, "isPinnedPtr is not implemented."); -} - -c10::Allocator* XLAHooks::getPinnedMemoryAllocator() const { - TORCH_CHECK(false, "getPinnedMemoryAllocator is not implemented."); -} - -c10::Device XLAHooks::getDeviceFromPtr(void* data) const { - TORCH_CHECK(false, "getDeviceFromPtr is not implemented."); -} - -const at::Generator& XLAHooks::getDefaultGenerator(c10::DeviceIndex device_index) const { - return at::detail::getDefaultXLAGenerator(device_index); -} - -at::Generator XLAHooks::getNewGenerator(c10::DeviceIndex device_index) const { - // Create and return a new XLA generator using the make_generator template function - return at::make_generator(device_index); -} - -} // namespace torch_xla::detail - -// Register XLA hooks with PyTorch on module load -namespace at { -REGISTER_XLA_HOOKS(torch_xla::detail::XLAHooks) -} // namespace at diff --git a/torch_xla/csrc/xla_hooks.h b/torch_xla/csrc/xla_hooks.h deleted file mode 100644 index 
f56c039a8a95..000000000000 --- a/torch_xla/csrc/xla_hooks.h +++ /dev/null @@ -1,40 +0,0 @@ -#pragma once - -#include - -// PyTorch integration headers -#include -#include -#include -#include -#include - -namespace torch_xla::detail { - -// XLA hooks implementation following PyTorch patterns -struct XLAHooks : public at::XLAHooksInterface { - XLAHooks(const at::XLAHooksArgs& args) {} - - // Core accelerator interface methods - void init() const override; - bool hasXLA() const override; - bool isAvailable() const override; - bool isBuilt() const override { return true; } - std::string showConfig() const override; - - // Device management - c10::DeviceIndex deviceCount() const override; - c10::DeviceIndex getCurrentDevice() const override; - bool hasPrimaryContext(c10::DeviceIndex device_index) const override; - - // Memory management - bool isPinnedPtr(const void* data) const override; - c10::Allocator* getPinnedMemoryAllocator() const override; - c10::Device getDeviceFromPtr(void* data) const override; - - // Generator methods - const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index = -1) const override; - at::Generator getNewGenerator(c10::DeviceIndex device_index = -1) const override; -}; - -} // namespace torch_xla::detail From 80b2078302ccc600c96d688fbf7132efa5339094 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 21 Oct 2025 03:41:56 +0000 Subject: [PATCH 07/12] Add missing include --- torch_xla/csrc/xla_generator.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp index 56aa2fe3bf49..e7d2115552a8 100644 --- a/torch_xla/csrc/xla_generator.cpp +++ b/torch_xla/csrc/xla_generator.cpp @@ -17,6 +17,7 @@ #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/runtime.h" namespace at { From c65ce8f157354760aa72ea5bb2eafddb3aa4ab7c Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 29 Oct 2025 17:48:59 
+0000 Subject: [PATCH 08/12] improve error reporting and function naming. --- torch_xla/csrc/xla_generator.cpp | 34 ++++++++++++++------------------ torch_xla/csrc/xla_generator.h | 10 +++++++--- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp index e7d2115552a8..d69d740768ad 100644 --- a/torch_xla/csrc/xla_generator.cpp +++ b/torch_xla/csrc/xla_generator.cpp @@ -10,14 +10,15 @@ #include #include -// XLA headers #include #include #include +#include "absl/status/status.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/runtime.h" +#include "torch_xla/csrc/status.h" namespace at { @@ -38,18 +39,12 @@ static std::vector default_gens_xla; * Populates the global variables related to XLA generators * Warning: this function must only be called once! */ -static void initXLAGenVector() { +static void InitXLAGenVector() { // Ensures we only call deviceCount only once. static bool num_xla_device_init_flag [[maybe_unused]] = []() { // Get local num of XLA devices - auto maybe_client = torch_xla::runtime::GetComputationClient(); - if (!maybe_client.ok()) { - // If runtime client initialization failed, default to 1 device - num_xla_devices = 1; - } else { - auto* client = maybe_client.value(); - num_xla_devices = static_cast(client->GetNumDevices()); - } + XLA_ASSIGN_OR_THROW(auto c_client, torch_xla::runtime::GetComputationClient()); + num_xla_devices = static_cast(c_client->GetNumDevices()); xla_gens_init_flag.resize(num_xla_devices); default_gens_xla.resize(num_xla_devices); return true; @@ -63,16 +58,16 @@ static void initXLAGenVector() { * initialized once. The purpose of these default generators is to * maintain a global running state of the pseudo random number generation, * when a user does not explicitly mention any generator. 
- * getDefaultXLAGenerator gets the default generator for a particular + * GetDefaultXLAGenerator gets the default generator for a particular * XLA device. */ -const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) { - initXLAGenVector(); +absl::StatusOr GetDefaultXLAGenerator(c10::DeviceIndex device_index) { + InitXLAGenVector(); c10::DeviceIndex idx = device_index; if (idx == -1) { idx = 0; // Default to device 0 for XLA - } else { - TORCH_CHECK(idx >= 0 && idx < num_xla_devices); + } else if (idx < -1 || idx >= num_xla_devices) { + return absl::InvalidArgumentError("Invalid device index for XLA generator. Provided index: " + std::to_string(idx)); } c10::call_once(xla_gens_init_flag[idx], [&] { default_gens_xla[idx] = at::make_generator(idx); @@ -84,15 +79,16 @@ const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) { /** * Utility to create a XLAGeneratorImpl. Returns a shared_ptr */ -at::Generator createXLAGenerator(c10::DeviceIndex device_index) { - initXLAGenVector(); +absl::StatusOr CreateXLAGenerator(c10::DeviceIndex device_index) { + InitXLAGenVector(); c10::DeviceIndex idx = device_index; if (idx == -1) { idx = torch_xla::bridge::GetCurrentAtenDevice() .index(); // Use current XLA device } - TORCH_CHECK(idx >= 0 && idx < num_xla_devices, - "The device_index is invalid."); + else if (idx < -1 || idx >= num_xla_devices) { + return absl::InvalidArgumentError("Invalid device index for XLA generator. 
Provided index: " + std::to_string(idx)); + } auto gen = at::make_generator(idx); auto xla_gen = at::check_generator(gen); xla_gen->set_current_seed(c10::default_rng_seed_val); diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h index 0d0173157dfd..0dce39ddee16 100644 --- a/torch_xla/csrc/xla_generator.h +++ b/torch_xla/csrc/xla_generator.h @@ -4,10 +4,14 @@ #include #include #include +#include +#include #include - #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" + namespace at { // Holds the actual state variables for the XLA generator. @@ -57,8 +61,8 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl { namespace detail { -const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index = -1); -at::Generator createXLAGenerator(c10::DeviceIndex device_index = -1); +absl::StatusOr GetDefaultXLAGenerator(c10::DeviceIndex device_index = -1); +absl::StatusOr CreateXLAGenerator(c10::DeviceIndex device_index = -1); } // namespace detail From e7453366b7533022052eedb4938364a4d33a376e Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 29 Oct 2025 17:53:04 +0000 Subject: [PATCH 09/12] format --- torch_xla/csrc/xla_generator.cpp | 20 +++++++++++++------- torch_xla/csrc/xla_generator.h | 7 +++++-- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp index d69d740768ad..9099639dc4a2 100644 --- a/torch_xla/csrc/xla_generator.cpp +++ b/torch_xla/csrc/xla_generator.cpp @@ -43,7 +43,8 @@ static void InitXLAGenVector() { // Ensures we only call deviceCount only once. 
static bool num_xla_device_init_flag [[maybe_unused]] = []() { // Get local num of XLA devices - XLA_ASSIGN_OR_THROW(auto c_client, torch_xla::runtime::GetComputationClient()); + XLA_ASSIGN_OR_THROW(auto c_client, + torch_xla::runtime::GetComputationClient()); num_xla_devices = static_cast(c_client->GetNumDevices()); xla_gens_init_flag.resize(num_xla_devices); default_gens_xla.resize(num_xla_devices); @@ -61,13 +62,16 @@ static void InitXLAGenVector() { * GetDefaultXLAGenerator gets the default generator for a particular * XLA device. */ -absl::StatusOr GetDefaultXLAGenerator(c10::DeviceIndex device_index) { +absl::StatusOr GetDefaultXLAGenerator( + c10::DeviceIndex device_index) { InitXLAGenVector(); c10::DeviceIndex idx = device_index; if (idx == -1) { idx = 0; // Default to device 0 for XLA } else if (idx < -1 || idx >= num_xla_devices) { - return absl::InvalidArgumentError("Invalid device index for XLA generator. Provided index: " + std::to_string(idx)); + return absl::InvalidArgumentError( + "Invalid device index for XLA generator. Provided index: " + + std::to_string(idx)); } c10::call_once(xla_gens_init_flag[idx], [&] { default_gens_xla[idx] = at::make_generator(idx); @@ -79,15 +83,17 @@ absl::StatusOr GetDefaultXLAGenerator(c10::DeviceIndex dev /** * Utility to create a XLAGeneratorImpl. Returns a shared_ptr */ -absl::StatusOr CreateXLAGenerator(c10::DeviceIndex device_index) { +absl::StatusOr CreateXLAGenerator( + c10::DeviceIndex device_index) { InitXLAGenVector(); c10::DeviceIndex idx = device_index; if (idx == -1) { idx = torch_xla::bridge::GetCurrentAtenDevice() .index(); // Use current XLA device - } - else if (idx < -1 || idx >= num_xla_devices) { - return absl::InvalidArgumentError("Invalid device index for XLA generator. Provided index: " + std::to_string(idx)); + } else if (idx < -1 || idx >= num_xla_devices) { + return absl::InvalidArgumentError( + "Invalid device index for XLA generator. 
Provided index: " + + std::to_string(idx)); } auto gen = at::make_generator(idx); auto xla_gen = at::check_generator(gen); diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h index 0dce39ddee16..8001737e795c 100644 --- a/torch_xla/csrc/xla_generator.h +++ b/torch_xla/csrc/xla_generator.h @@ -7,6 +7,7 @@ #include #include #include + #include #include "absl/status/status.h" @@ -61,8 +62,10 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl { namespace detail { -absl::StatusOr GetDefaultXLAGenerator(c10::DeviceIndex device_index = -1); -absl::StatusOr CreateXLAGenerator(c10::DeviceIndex device_index = -1); +absl::StatusOr GetDefaultXLAGenerator( + c10::DeviceIndex device_index = -1); +absl::StatusOr CreateXLAGenerator( + c10::DeviceIndex device_index = -1); } // namespace detail From e89e659ca440e7db9d7d51e75e8cfe88d18901ca Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Fri, 31 Oct 2025 20:52:36 +0000 Subject: [PATCH 10/12] Add unit tests for `GetDefaultXLAGenerator` and `CreateXLAGenerator` --- test/cpp/test_xla_generator.cpp | 119 +++++++++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 1 deletion(-) diff --git a/test/cpp/test_xla_generator.cpp b/test/cpp/test_xla_generator.cpp index d45991f72d39..14663d3cb377 100644 --- a/test/cpp/test_xla_generator.cpp +++ b/test/cpp/test_xla_generator.cpp @@ -1,6 +1,8 @@ #include #include +#include + #include "test/cpp/torch_xla_test.h" #include "torch_xla/csrc/xla_generator.h" @@ -18,6 +20,20 @@ class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest { at::Generator gen_; }; +// Ensure PJRT is configured to a CPU backend for tests that touch the PJRT +// runtime. +static void EnsurePjrtCpuBackend() { + const char* pjrt = std::getenv("PJRT_DEVICE"); + if (pjrt == nullptr || pjrt[0] == '\0') { + // Use CPU backend with a single device by default. 
+ setenv("PJRT_DEVICE", "CPU", 1); + } + const char* cpu_devices = std::getenv("CPU_NUM_DEVICES"); + if (cpu_devices == nullptr || cpu_devices[0] == '\0') { + setenv("CPU_NUM_DEVICES", "1", 0); + } +} + TEST_F(XLAGeneratorTest, Constructor) { // Check that the generator was created for the correct device ASSERT_EQ(gen_.device().type(), at::DeviceType::XLA); @@ -102,5 +118,106 @@ TEST_F(XLAGeneratorTest, Clone) { ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed()); } +TEST_F(XLAGeneratorTest, GetDefaultXLAGenerator) { + EnsurePjrtCpuBackend(); + // Test getting default generator for device 0 + auto result = at::detail::GetDefaultXLAGenerator(0); + ASSERT_TRUE(result.ok()) << "Failed to get default generator: " + << result.status(); + + const at::Generator& default_gen = result.value(); + ASSERT_EQ(default_gen.device().type(), at::DeviceType::XLA); + ASSERT_EQ(default_gen.device().index(), 0); + + // Test getting default generator with -1 (should default to device 0) + auto result_default = at::detail::GetDefaultXLAGenerator(-1); + ASSERT_TRUE(result_default.ok()) + << "Failed to get default generator with -1: " << result_default.status(); + + const at::Generator& default_gen_neg1 = result_default.value(); + ASSERT_EQ(default_gen_neg1.device().type(), at::DeviceType::XLA); + ASSERT_EQ(default_gen_neg1.device().index(), 0); + + // Test that subsequent calls return the same generator instance + auto result2 = at::detail::GetDefaultXLAGenerator(0); + ASSERT_TRUE(result2.ok()); + const at::Generator& default_gen2 = result2.value(); + ASSERT_EQ(std::addressof(default_gen), std::addressof(default_gen2)); +} + +TEST_F(XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) { + EnsurePjrtCpuBackend(); + // Test with invalid device indices + auto result_neg2 = at::detail::GetDefaultXLAGenerator(-2); + ASSERT_FALSE(result_neg2.ok()); + ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status())); + + // Test with very large device index (assuming there aren't 1000 XLA 
devices) + auto result_large = at::detail::GetDefaultXLAGenerator(1000); + ASSERT_FALSE(result_large.ok()); + ASSERT_TRUE(absl::IsInvalidArgument(result_large.status())); +} + +TEST_F(XLAGeneratorTest, CreateXLAGenerator) { + EnsurePjrtCpuBackend(); + // Test creating generator for device 0 + auto result = at::detail::CreateXLAGenerator(0); + ASSERT_TRUE(result.ok()) << "Failed to create generator: " << result.status(); + + at::Generator created_gen = result.value(); + ASSERT_EQ(created_gen.device().type(), at::DeviceType::XLA); + ASSERT_EQ(created_gen.device().index(), 0); + + // Test that the generator is initialized with default seed + ASSERT_EQ(created_gen.current_seed(), c10::default_rng_seed_val); + + // Test creating generator with -1 (should use current device) + auto result_default = at::detail::CreateXLAGenerator(-1); + ASSERT_TRUE(result_default.ok()) + << "Failed to create generator with -1: " << result_default.status(); + + at::Generator created_gen_neg1 = result_default.value(); + ASSERT_EQ(created_gen_neg1.device().type(), at::DeviceType::XLA); + // Device index should be >= 0 (actual device depends on current XLA device) + ASSERT_GE(created_gen_neg1.device().index(), 0); +} + +TEST_F(XLAGeneratorTest, CreateXLAGeneratorUniqueness) { + EnsurePjrtCpuBackend(); + // Test that each call creates a new generator instance + auto result1 = at::detail::CreateXLAGenerator(0); + auto result2 = at::detail::CreateXLAGenerator(0); + + ASSERT_TRUE(result1.ok()); + ASSERT_TRUE(result2.ok()); + + at::Generator gen1 = result1.value(); + at::Generator gen2 = result2.value(); + + // Should be different instances + ASSERT_NE(std::addressof(gen1), std::addressof(gen2)); + + // But should have same device and initial seed + ASSERT_EQ(gen1.device(), gen2.device()); + ASSERT_EQ(gen1.current_seed(), gen2.current_seed()); + + // Modifying one should not affect the other + gen1.set_current_seed(12345); + ASSERT_NE(gen1.current_seed(), gen2.current_seed()); +} + 
+TEST_F(XLAGeneratorTest, CreateXLAGeneratorInvalidDevice) { + EnsurePjrtCpuBackend(); + // Test with invalid device indices + auto result_neg2 = at::detail::CreateXLAGenerator(-2); + ASSERT_FALSE(result_neg2.ok()); + ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status())); + + // Test with very large device index (assuming there aren't 1000 XLA devices) + auto result_large = at::detail::CreateXLAGenerator(1000); + ASSERT_FALSE(result_large.ok()); + ASSERT_TRUE(absl::IsInvalidArgument(result_large.status())); +} + } // namespace cpp_test -} // namespace torch_xla \ No newline at end of file +} // namespace torch_xla From 4cc594fb1ac92fb21566e27a325b0006462273d2 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 4 Nov 2025 00:33:10 +0000 Subject: [PATCH 11/12] Improve InitXLAGenVector function and the unit tests accordingly. --- test/cpp/test_xla_generator.cpp | 97 +++++++++++++++++++++++--------- torch_xla/csrc/xla_generator.cpp | 71 ++++++++++++++--------- 2 files changed, 115 insertions(+), 53 deletions(-) diff --git a/test/cpp/test_xla_generator.cpp b/test/cpp/test_xla_generator.cpp index 14663d3cb377..0dbf4d83a0c0 100644 --- a/test/cpp/test_xla_generator.cpp +++ b/test/cpp/test_xla_generator.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -9,9 +10,45 @@ namespace torch_xla { namespace cpp_test { +// Ensure PJRT is configured to a CPU backend for tests that touch the PJRT +// runtime. Optionally allow overriding the environment values by passing +// `pjrt_device` and/or `cpu_num_devices`. 
+static void EnsurePjrtCpuBackend(const char* pjrt_device = nullptr, + const char* cpu_num_devices = nullptr) { + // PJRT_DEVICE: override if provided, otherwise set default if not present + if (pjrt_device != nullptr && pjrt_device[0] != '\0') { + // Force override of any existing value + setenv("PJRT_DEVICE", pjrt_device, 1); + } else { + const char* pjrt = std::getenv("PJRT_DEVICE"); + if (pjrt == nullptr || pjrt[0] == '\0') { + // Use CPU backend with a single device by default. + setenv("PJRT_DEVICE", "CPU", 1); + } + } + + // CPU_NUM_DEVICES: override if provided, otherwise set default if not present + if (cpu_num_devices != nullptr && cpu_num_devices[0] != '\0') { + // Force override of any existing value + setenv("CPU_NUM_DEVICES", cpu_num_devices, 1); + } else { + const char* cpu_devices = std::getenv("CPU_NUM_DEVICES"); + if (cpu_devices == nullptr || cpu_devices[0] == '\0') { + // Default to a single CPU device. Preserve existing behavior of not + // overwriting if already present (use overwrite=0 to match previous + // semantics). + setenv("CPU_NUM_DEVICES", "1", 0); + } + } +} + // Test fixture for XLAGenerator tests class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest { protected: + // Runs once before the test suite / test case to ensure PJRT is configured + // before any XLA runtime initialization happens in per-test SetUp(). + static void SetUpTestCase() { EnsurePjrtCpuBackend("CPU", "2"); } + void SetUp() { // Create a generator for XLA device 0 gen_ = at::make_generator(0); @@ -20,20 +57,6 @@ class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest { at::Generator gen_; }; -// Ensure PJRT is configured to a CPU backend for tests that touch the PJRT -// runtime. -static void EnsurePjrtCpuBackend() { - const char* pjrt = std::getenv("PJRT_DEVICE"); - if (pjrt == nullptr || pjrt[0] == '\0') { - // Use CPU backend with a single device by default. 
- setenv("PJRT_DEVICE", "CPU", 1); - } - const char* cpu_devices = std::getenv("CPU_NUM_DEVICES"); - if (cpu_devices == nullptr || cpu_devices[0] == '\0') { - setenv("CPU_NUM_DEVICES", "1", 0); - } -} - TEST_F(XLAGeneratorTest, Constructor) { // Check that the generator was created for the correct device ASSERT_EQ(gen_.device().type(), at::DeviceType::XLA); @@ -142,7 +165,18 @@ TEST_F(XLAGeneratorTest, GetDefaultXLAGenerator) { auto result2 = at::detail::GetDefaultXLAGenerator(0); ASSERT_TRUE(result2.ok()); const at::Generator& default_gen2 = result2.value(); - ASSERT_EQ(std::addressof(default_gen), std::addressof(default_gen2)); + ASSERT_EQ(default_gen, default_gen2); + + // Test getting non-default device generator + auto result_device1 = at::detail::GetDefaultXLAGenerator(1); + ASSERT_TRUE(result_device1.ok()) + << "Failed to get default generator for device 1: " + << result_device1.status(); + + const at::Generator& default_gen_device1 = result_device1.value(); + ASSERT_EQ(default_gen_device1.device().type(), at::DeviceType::XLA); + ASSERT_EQ(default_gen_device1.device().index(), 1); + ASSERT_NE(default_gen_device1, default_gen); } TEST_F(XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) { @@ -151,32 +185,36 @@ TEST_F(XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) { auto result_neg2 = at::detail::GetDefaultXLAGenerator(-2); ASSERT_FALSE(result_neg2.ok()); ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status())); + ASSERT_THAT(result_neg2.status().message(), + testing::HasSubstr("Invalid XLA device index")); // Test with very large device index (assuming there aren't 1000 XLA devices) - auto result_large = at::detail::GetDefaultXLAGenerator(1000); + auto result_large = at::detail::GetDefaultXLAGenerator(100); ASSERT_FALSE(result_large.ok()); ASSERT_TRUE(absl::IsInvalidArgument(result_large.status())); + ASSERT_THAT(result_large.status().message(), + testing::HasSubstr("Invalid XLA device index")); } TEST_F(XLAGeneratorTest, CreateXLAGenerator) {
- EnsurePjrtCpuBackend(); + EnsurePjrtCpuBackend("CPU", "2"); // Test creating generator for device 0 - auto result = at::detail::CreateXLAGenerator(0); + auto result = at::detail::CreateXLAGenerator(1); ASSERT_TRUE(result.ok()) << "Failed to create generator: " << result.status(); at::Generator created_gen = result.value(); ASSERT_EQ(created_gen.device().type(), at::DeviceType::XLA); - ASSERT_EQ(created_gen.device().index(), 0); + ASSERT_EQ(created_gen.device().index(), 1); // Test that the generator is initialized with default seed ASSERT_EQ(created_gen.current_seed(), c10::default_rng_seed_val); // Test creating generator with -1 (should use current device) - auto result_default = at::detail::CreateXLAGenerator(-1); - ASSERT_TRUE(result_default.ok()) - << "Failed to create generator with -1: " << result_default.status(); + auto result_current = at::detail::CreateXLAGenerator(-1); + ASSERT_TRUE(result_current.ok()) + << "Failed to create generator with -1: " << result_current.status(); - at::Generator created_gen_neg1 = result_default.value(); + at::Generator created_gen_neg1 = result_current.value(); ASSERT_EQ(created_gen_neg1.device().type(), at::DeviceType::XLA); // Device index should be >= 0 (actual device depends on current XLA device) ASSERT_GE(created_gen_neg1.device().index(), 0); @@ -194,8 +232,9 @@ TEST_F(XLAGeneratorTest, CreateXLAGeneratorUniqueness) { at::Generator gen1 = result1.value(); at::Generator gen2 = result2.value(); - // Should be different instances - ASSERT_NE(std::addressof(gen1), std::addressof(gen2)); + // Should be different instances (compare generators, not their stack + // addresses) + ASSERT_NE(gen1, gen2); // But should have same device and initial seed ASSERT_EQ(gen1.device(), gen2.device()); @@ -212,11 +251,15 @@ TEST_F(XLAGeneratorTest, CreateXLAGeneratorInvalidDevice) { auto result_neg2 = at::detail::CreateXLAGenerator(-2); ASSERT_FALSE(result_neg2.ok()); ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status())); + 
ASSERT_THAT(result_neg2.status().message(), + testing::HasSubstr("Invalid XLA device index")); - // Test with very large device index (assuming there aren't 1000 XLA devices) - auto result_large = at::detail::CreateXLAGenerator(1000); + // Test with very large device index (assuming there aren't 100 XLA devices) + auto result_large = at::detail::CreateXLAGenerator(100); ASSERT_FALSE(result_large.ok()); ASSERT_TRUE(absl::IsInvalidArgument(result_large.status())); + ASSERT_THAT(result_large.status().message(), + testing::HasSubstr("Invalid XLA device index")); } } // namespace cpp_test diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp index 9099639dc4a2..87ab76f35131 100644 --- a/torch_xla/csrc/xla_generator.cpp +++ b/torch_xla/csrc/xla_generator.cpp @@ -39,17 +39,41 @@ static std::vector default_gens_xla; * Populates the global variables related to XLA generators * Warning: this function must only be called once! */ -static void InitXLAGenVector() { - // Ensures we only call deviceCount only once. - static bool num_xla_device_init_flag [[maybe_unused]] = []() { +static absl::Status InitXLAGenVector() { + // Ensure we only perform initialization once and propagate status + static c10::once_flag init_once_flag; + static absl::Status init_status; + c10::call_once(init_once_flag, [&] { // Get local num of XLA devices - XLA_ASSIGN_OR_THROW(auto c_client, - torch_xla::runtime::GetComputationClient()); + auto c_client_or = torch_xla::runtime::GetComputationClient(); + if (!c_client_or.ok()) { + init_status = c_client_or.status(); + return; + } + auto c_client = *c_client_or; num_xla_devices = static_cast(c_client->GetNumDevices()); xla_gens_init_flag.resize(num_xla_devices); default_gens_xla.resize(num_xla_devices); - return true; - }(); + init_status = absl::OkStatus(); + }); + return init_status; +} + +// Validates and normalizes an XLA device index. +// If requested_index == -1, fallback_index will be used. 
+// Returns InvalidArgument if the resolved index is out of range. +static absl::StatusOr NormalizeXLADeviceIndex( + c10::DeviceIndex requested_index) { + c10::DeviceIndex idx = requested_index; + if (idx == -1) { + idx = torch_xla::bridge::GetCurrentAtenDevice().index(); + } + if (idx < 0 || idx >= num_xla_devices) { + return absl::InvalidArgumentError( + "Invalid device index for XLA generator. Provided index: " + + std::to_string(idx)); + } + return idx; } } // anonymous namespace @@ -64,15 +88,13 @@ static void InitXLAGenVector() { */ absl::StatusOr GetDefaultXLAGenerator( c10::DeviceIndex device_index) { - InitXLAGenVector(); - c10::DeviceIndex idx = device_index; - if (idx == -1) { - idx = 0; // Default to device 0 for XLA - } else if (idx < -1 || idx >= num_xla_devices) { - return absl::InvalidArgumentError( - "Invalid device index for XLA generator. Provided index: " + - std::to_string(idx)); - } + XLA_RETURN_IF_ERROR(InitXLAGenVector(), + "Failed to initialize XLA generators"); + // Normalize and validate the target device index; default to current device + // when unspecified + XLA_ASSIGN_OR_RETURN(c10::DeviceIndex idx, + NormalizeXLADeviceIndex(device_index), + "Invalid XLA device index"); c10::call_once(xla_gens_init_flag[idx], [&] { default_gens_xla[idx] = at::make_generator(idx); default_gens_xla[idx].seed(); @@ -85,16 +107,13 @@ absl::StatusOr GetDefaultXLAGenerator( */ absl::StatusOr CreateXLAGenerator( c10::DeviceIndex device_index) { - InitXLAGenVector(); - c10::DeviceIndex idx = device_index; - if (idx == -1) { - idx = torch_xla::bridge::GetCurrentAtenDevice() - .index(); // Use current XLA device - } else if (idx < -1 || idx >= num_xla_devices) { - return absl::InvalidArgumentError( - "Invalid device index for XLA generator. 
Provided index: " + - std::to_string(idx)); - } + XLA_RETURN_IF_ERROR(InitXLAGenVector(), + "Failed to initialize XLA generators"); + // Normalize and validate the target device index; default to current device + // when unspecified + XLA_ASSIGN_OR_RETURN(c10::DeviceIndex idx, + NormalizeXLADeviceIndex(device_index), + "Invalid XLA device index"); auto gen = at::make_generator(idx); auto xla_gen = at::check_generator(gen); xla_gen->set_current_seed(c10::default_rng_seed_val); From c93e10db850424a7152b1a47a7fe1cfd8bef51b5 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Thu, 6 Nov 2025 17:31:11 +0000 Subject: [PATCH 12/12] Address feedbacks. --- test/cpp/test_xla_generator.cpp | 8 ++------ torch_xla/csrc/xla_generator.cpp | 20 ++++++-------------- 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/test/cpp/test_xla_generator.cpp b/test/cpp/test_xla_generator.cpp index 0dbf4d83a0c0..c89b0f957bee 100644 --- a/test/cpp/test_xla_generator.cpp +++ b/test/cpp/test_xla_generator.cpp @@ -142,7 +142,6 @@ TEST_F(XLAGeneratorTest, Clone) { } TEST_F(XLAGeneratorTest, GetDefaultXLAGenerator) { - EnsurePjrtCpuBackend(); // Test getting default generator for device 0 auto result = at::detail::GetDefaultXLAGenerator(0); ASSERT_TRUE(result.ok()) << "Failed to get default generator: " @@ -160,6 +159,7 @@ TEST_F(XLAGeneratorTest, GetDefaultXLAGenerator) { const at::Generator& default_gen_neg1 = result_default.value(); ASSERT_EQ(default_gen_neg1.device().type(), at::DeviceType::XLA); ASSERT_EQ(default_gen_neg1.device().index(), 0); + ASSERT_EQ(default_gen, default_gen_neg1); // Test that subsequent calls return the same generator instance auto result2 = at::detail::GetDefaultXLAGenerator(0); @@ -180,7 +180,6 @@ TEST_F(XLAGeneratorTest, GetDefaultXLAGenerator) { } TEST_F(XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) { - EnsurePjrtCpuBackend(); // Test with invalid device indices auto result_neg2 = at::detail::GetDefaultXLAGenerator(-2); ASSERT_FALSE(result_neg2.ok()); 
@@ -197,8 +196,7 @@ TEST_F(XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) { } TEST_F(XLAGeneratorTest, CreateXLAGenerator) { - EnsurePjrtCpuBackend("CPU", "2"); - // Test creating generator for device 0 + // Test creating generator for device 1 auto result = at::detail::CreateXLAGenerator(1); ASSERT_TRUE(result.ok()) << "Failed to create generator: " << result.status(); @@ -221,7 +219,6 @@ TEST_F(XLAGeneratorTest, CreateXLAGenerator) { } TEST_F(XLAGeneratorTest, CreateXLAGeneratorUniqueness) { - EnsurePjrtCpuBackend(); // Test that each call creates a new generator instance auto result1 = at::detail::CreateXLAGenerator(0); auto result2 = at::detail::CreateXLAGenerator(0); @@ -246,7 +243,6 @@ TEST_F(XLAGeneratorTest, CreateXLAGeneratorUniqueness) { } TEST_F(XLAGeneratorTest, CreateXLAGeneratorInvalidDevice) { - EnsurePjrtCpuBackend(); // Test with invalid device indices auto result_neg2 = at::detail::CreateXLAGenerator(-2); ASSERT_FALSE(result_neg2.ok()); diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp index 87ab76f35131..0e311bda2632 100644 --- a/torch_xla/csrc/xla_generator.cpp +++ b/torch_xla/csrc/xla_generator.cpp @@ -40,27 +40,19 @@ static std::vector default_gens_xla; * Warning: this function must only be called once! 
*/ static absl::Status InitXLAGenVector() { - // Ensure we only perform initialization once and propagate status - static c10::once_flag init_once_flag; - static absl::Status init_status; - c10::call_once(init_once_flag, [&] { - // Get local num of XLA devices - auto c_client_or = torch_xla::runtime::GetComputationClient(); - if (!c_client_or.ok()) { - init_status = c_client_or.status(); - return; - } - auto c_client = *c_client_or; + static absl::Status init_status = []() { + XLA_ASSIGN_OR_RETURN(auto c_client, + torch_xla::runtime::GetComputationClient()); num_xla_devices = static_cast(c_client->GetNumDevices()); xla_gens_init_flag.resize(num_xla_devices); default_gens_xla.resize(num_xla_devices); - init_status = absl::OkStatus(); - }); + return absl::OkStatus(); + }(); return init_status; } // Validates and normalizes an XLA device index. -// If requested_index == -1, fallback_index will be used. +// If requested_index == -1, the current device index is used. // Returns InvalidArgument if the resolved index is out of range. static absl::StatusOr NormalizeXLADeviceIndex( c10::DeviceIndex requested_index) {