19 changes: 19 additions & 0 deletions torch_xla/csrc/BUILD
@@ -270,6 +270,7 @@ ptxla_cc_library(
":status",
":tensor",
":version",
":xla_hooks",
"//torch_xla/csrc/runtime",
"//torch_xla/csrc/runtime:pjrt_computation_client",
"//torch_xla/csrc/runtime:metrics",
@@ -374,3 +375,21 @@ cc_library(
"@com_google_absl//absl/status:statusor",
],
)

ptxla_cc_library(
name = "xla_hooks",
srcs = [
"xla_hooks.cpp",
],
hdrs = [
"xla_hooks.h",
],
deps = [
"//torch_xla/csrc:device",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc/runtime:computation_client",
"//torch_xla/csrc/runtime",
"//torch_xla/csrc/runtime:xla_util",
],
)

98 changes: 98 additions & 0 deletions torch_xla/csrc/xla_hooks.cpp
@@ -0,0 +1,98 @@
#include "xla_hooks.h"

#include <iostream>
#include <sstream>

// PyTorch integration headers
#include <ATen/core/Generator.h>
#include <ATen/detail/XLAHooksInterface.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/intrusive_ptr.h>

// XLA headers
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "xla_backend_impl.h"
#include "xla_generator.h"

namespace torch_xla::detail {

void XLAHooks::init() const {
C10_LOG_API_USAGE_ONCE("aten.init.xla");

// Initialize XLA backend - this registers XLA functions and sets up
// the backend infrastructure
torch_xla::InitXlaBackend();
}

bool XLAHooks::hasXLA() const { return isAvailable(); }

bool XLAHooks::isAvailable() const {
try {
return deviceCount() > 0;
} catch (...) {
// If device enumeration fails, XLA is not available
return false;
}
}

std::string XLAHooks::showConfig() const {
std::ostringstream oss;
oss << "XLA Backend Configuration:\n";
oss << " - XLA devices available: " << deviceCount() << "\n";
return oss.str();
}

c10::DeviceIndex XLAHooks::deviceCount() const {
auto maybe_client = torch_xla::runtime::GetComputationClient();
if (!maybe_client.ok()) {
// If runtime client initialization failed, return 0 devices
return 0;
}

auto* client = maybe_client.value();
return static_cast<c10::DeviceIndex>(client->GetNumDevices());
}

c10::DeviceIndex XLAHooks::getCurrentDevice() const {
return bridge::GetCurrentAtenDevice().index();
}

bool XLAHooks::hasPrimaryContext(c10::DeviceIndex device_index) const {
TORCH_CHECK(false, "hasPrimaryContext is not implemented.");
}

bool XLAHooks::isPinnedPtr(const void* data) const {
TORCH_CHECK(false, "isPinnedPtr is not implemented.");
}

c10::Allocator* XLAHooks::getPinnedMemoryAllocator() const {
TORCH_CHECK(false, "getPinnedMemoryAllocator is not implemented.");
}

c10::Device XLAHooks::getDeviceFromPtr(void* data) const {
TORCH_CHECK(false, "getDeviceFromPtr is not implemented.");
}

const at::Generator& XLAHooks::getDefaultGenerator(
c10::DeviceIndex device_index) const {
return at::detail::getDefaultXLAGenerator(device_index);
}

at::Generator XLAHooks::getNewGenerator(c10::DeviceIndex device_index) const {
// Create and return a new XLA generator using the make_generator template
// function
return at::make_generator<at::XLAGeneratorImpl>(device_index);
}

} // namespace torch_xla::detail

// Register XLA hooks with PyTorch on module load
namespace at {
REGISTER_XLA_HOOKS(torch_xla::detail::XLAHooks)
} // namespace at
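
A minimal usage sketch (not part of the diff) of how code outside torch_xla could reach the hooks registered above. It assumes at::detail::getXLAHooks() is the accessor exposed by ATen/detail/XLAHooksInterface.h, mirroring the accessors for the other backend hooks, and that the torch_xla library is linked so the registration has actually run:

#include <ATen/detail/XLAHooksInterface.h>
#include <iostream>

int main() {
  // Dispatches to torch_xla::detail::XLAHooks once REGISTER_XLA_HOOKS has run.
  const at::XLAHooksInterface& hooks = at::detail::getXLAHooks();
  if (hooks.isAvailable()) {
    std::cout << hooks.showConfig();
  } else {
    std::cout << "XLA backend not available\n";
  }
  return 0;
}
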
42 changes: 42 additions & 0 deletions torch_xla/csrc/xla_hooks.h
@@ -0,0 +1,42 @@
#pragma once

#include <string>

// PyTorch integration headers
#include <ATen/core/Generator.h>
#include <ATen/detail/XLAHooksInterface.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>

namespace torch_xla::detail {

// XLA hooks implementation following PyTorch patterns
struct XLAHooks : public at::XLAHooksInterface {
XLAHooks(const at::XLAHooksArgs& args) {}

// Core accelerator interface methods
void init() const override;
bool hasXLA() const override;
bool isAvailable() const override;
bool isBuilt() const override { return true; }
std::string showConfig() const override;

// Device management
c10::DeviceIndex deviceCount() const override;
c10::DeviceIndex getCurrentDevice() const override;
bool hasPrimaryContext(c10::DeviceIndex device_index) const override;

// Memory management
bool isPinnedPtr(const void* data) const override;
c10::Allocator* getPinnedMemoryAllocator() const override;
c10::Device getDeviceFromPtr(void* data) const override;

// Generator methods
const at::Generator& getDefaultGenerator(
c10::DeviceIndex device_index = -1) const override;
at::Generator getNewGenerator(
c10::DeviceIndex device_index = -1) const override;
};

} // namespace torch_xla::detail
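
And a hedged sketch of driving the generator methods declared above: make_seeded_xla_generator is a hypothetical helper name, and at::detail::getXLAHooks() is the same assumed accessor as in the previous sketch.

#include <ATen/core/Generator.h>
#include <ATen/detail/XLAHooksInterface.h>
#include <c10/core/Device.h>
#include <cstdint>

// Hypothetical helper: ask the hooks for a fresh XLA generator and seed it.
at::Generator make_seeded_xla_generator(c10::DeviceIndex device_index,
                                        uint64_t seed) {
  at::Generator gen = at::detail::getXLAHooks().getNewGenerator(device_index);
  gen.set_current_seed(seed);
  return gen;
}
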