92 changes: 92 additions & 0 deletions torch_xla/csrc/xla_generator.cpp
@@ -5,10 +5,102 @@
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/GeneratorImpl.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/CallOnce.h>
#include <c10/util/intrusive_ptr.h>

// C++ standard library headers
#include <cstring>
#include <deque>
#include <vector>

// torch_xla headers
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/runtime.h"

namespace at {

namespace detail {

namespace {

// Number of XLA devices visible to this process.
static int64_t num_xla_devices;

// Ensures default_gens_xla is initialized once.
static std::deque<c10::once_flag> xla_gens_init_flag;

// Default, global XLA generators, one per XLA device.
static std::vector<at::Generator> default_gens_xla;

/*
 * Populates the global variables related to XLA generators.
 * Safe to call repeatedly; the underlying initialization runs only once.
 */
static void initXLAGenVector() {
// Ensure the device count is queried only once.
static bool num_xla_device_init_flag [[maybe_unused]] = []() {
// Get the local number of XLA devices from the runtime.
auto maybe_client = torch_xla::runtime::GetComputationClient();
if (!maybe_client.ok()) {
// If runtime client initialization failed, default to 1 device
num_xla_devices = 1;
} else {
auto* client = maybe_client.value();
num_xla_devices = static_cast<int64_t>(client->GetNumDevices());
}
xla_gens_init_flag.resize(num_xla_devices);
default_gens_xla.resize(num_xla_devices);
return true;
}();
}

} // anonymous namespace

/**
 * PyTorch maintains a collection of default generators that are
 * initialized once. These default generators hold the global running state
 * of pseudo-random number generation whenever the user does not explicitly
 * specify a generator.
 * getDefaultXLAGenerator returns the default generator for a particular
 * XLA device.
 */
const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index) {
initXLAGenVector();
c10::DeviceIndex idx = device_index;
if (idx == -1) {
idx = 0; // Default to device 0 for XLA
} else {
TORCH_CHECK(idx >= 0 && idx < num_xla_devices);
}
c10::call_once(xla_gens_init_flag[idx], [&] {
default_gens_xla[idx] = at::make_generator<XLAGeneratorImpl>(idx);
default_gens_xla[idx].seed();
});
return default_gens_xla[idx];
}
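
// Usage sketch (illustrative only, assuming at least one XLA device is
// visible): a caller could fetch the shared default generator and reseed it
// through the generic at::Generator handle, which forwards to the underlying
// XLAGeneratorImpl. The handle is a cheap reference-counted wrapper, so the
// copy below still refers to the same global default state.
//
//   at::Generator gen = at::detail::getDefaultXLAGenerator(/*device_index=*/0);
//   gen.set_current_seed(42);
//   TORCH_CHECK(gen.current_seed() == 42);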

/**
 * Utility to create an XLAGeneratorImpl and return it wrapped in an
 * at::Generator handle.
 */
at::Generator createXLAGenerator(c10::DeviceIndex device_index) {
initXLAGenVector();
c10::DeviceIndex idx = device_index;
if (idx == -1) {
idx = torch_xla::bridge::GetCurrentAtenDevice()
.index(); // Use current XLA device
}
TORCH_CHECK(idx >= 0 && idx < num_xla_devices,
"The device_index is invalid.");
auto gen = at::make_generator<XLAGeneratorImpl>(idx);
auto xla_gen = at::check_generator<XLAGeneratorImpl>(gen);
xla_gen->set_current_seed(c10::default_rng_seed_val);
return gen;
}
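
// Usage sketch (illustrative only): create a private generator whose state is
// independent of the shared default generator.
//
//   at::Generator private_gen =
//       at::detail::createXLAGenerator(/*device_index=*/0);
//   private_gen.set_current_seed(2718281828ULL);
//
// Random ops that accept a c10::optional<at::Generator> argument can then
// consume private_gen without touching the default generator's state.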

} // namespace detail
} // namespace at

namespace at {

11 changes: 10 additions & 1 deletion torch_xla/csrc/xla_generator.h
@@ -2,6 +2,8 @@

#include <ATen/core/Generator.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/util/intrusive_ptr.h>

#include <cstdint>
@@ -53,4 +55,11 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl {
c10::intrusive_ptr<XLAGeneratorState> state_;
};

} // namespace at
namespace detail {

const at::Generator& getDefaultXLAGenerator(c10::DeviceIndex device_index = -1);
at::Generator createXLAGenerator(c10::DeviceIndex device_index = -1);

} // namespace detail

} // namespace at
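
A minimal consumer-side sketch of the new declarations (assumptions: the header is included via the torch_xla/csrc/xla_generator.h path used elsewhere in this patch, and the runtime reports at least one device). With the default device_index of -1, getDefaultXLAGenerator resolves to device 0, while createXLAGenerator resolves to the current ATen XLA device:

#include "torch_xla/csrc/xla_generator.h"

void ExampleUse() {
  // Shared, process-wide default generator (lazily created and seeded).
  const at::Generator& default_gen = at::detail::getDefaultXLAGenerator();
  // Fresh generator seeded with c10::default_rng_seed_val, independent of the default.
  at::Generator fresh_gen = at::detail::createXLAGenerator();
  (void)default_gen;
  (void)fresh_gen;
}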