
Conversation

@iwknow (Contributor) commented on Aug 24, 2025:

Currently, the implementation of torch.Generator only supports the "cpu" and "cuda" device types. https://github.com/pytorch/pytorch/blob/main/torch/csrc/Generator.cpp#L55-L61

This change enables torch.Generator to support more device types by allowing any device backend to register its own generator factory through a Generator Registry, similar to what the DeviceGuardImpl registry does today.

Key Changes:

New registry API:

  • Added GeneratorRegistry.h and GeneratorRegistry.cpp in c10/core/impl.
  • API supports registerGenerator(DeviceType, GeneratorFactory), unregisterGenerator(DeviceType), and getGeneratorFactory(DeviceType).
  • Uses c10::DeviceType as the key and stores a factory function returning c10::intrusive_ptr<c10::GeneratorImpl>. (A sketch of the resulting surface follows this list.)
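A minimal sketch of the registry surface described above; the signatures here are assumptions inferred from the bullets and from the constructor hunk reviewed later in this thread, not the exact header:

```cpp
// c10/core/impl/GeneratorRegistry.h (sketch, not the final header)
#include <c10/core/Device.h>
#include <c10/core/GeneratorImpl.h>
#include <c10/util/intrusive_ptr.h>

namespace c10::impl {

// A factory constructs a backend-specific generator for a concrete device.
using GeneratorFactory =
    c10::intrusive_ptr<c10::GeneratorImpl> (*)(c10::Device device);

// Called by backends (typically at module load) to install/remove a factory.
void registerGenerator(c10::DeviceType device_type, GeneratorFactory factory);
void unregisterGenerator(c10::DeviceType device_type);

// Lookup helpers consulted by the torch.Generator constructor path.
GeneratorFactory getGeneratorFactory(c10::DeviceType device_type);
bool hasGenerator(c10::DeviceType device_type);
c10::intrusive_ptr<c10::GeneratorImpl> makeGenerator(c10::Device device);

} // namespace c10::impl
```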

Python/C++ integration:

  • The registry is consulted in the torch.Generator constructor path for non-CPU/CUDA devices.
  • If a factory is registered for the requested device, it constructs the appropriate generator; otherwise, raises an error.

Backend extensibility:

  • Out-of-tree backends (e.g., torch_xla, torch-directml, torch_npu) can now register their custom generator implementations at module load via a static registrar object.
    Example usage:
```cpp
namespace {
// Registers the XLA generator factory when the torch_xla module is loaded.
struct Registrar {
  Registrar() {
    at::detail::registerGenerator(c10::DeviceType::XLA, &CreateXlaGenerator);
  }
} registrar_instance;
} // anonymous namespace
```

This allows torch.Generator(device='xla') to return an XlaGeneratorImpl when the torch_xla extension is imported.
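For completeness, the same lookup is reachable from C++; a hedged sketch using the hasGenerator/makeGenerator helpers that appear in this PR's diff (the device index 0 is illustrative):

```cpp
// Fragment: assumes the registry helpers from this PR exist and torch_xla
// has already registered its factory.
#include <ATen/core/Generator.h>

void demo() {
  c10::Device device(c10::DeviceType::XLA, /*index=*/0);
  if (c10::impl::hasGenerator(device.type())) {
    // Wraps the backend-provided XlaGeneratorImpl.
    at::Generator gen(c10::impl::makeGenerator(device));
  }
}
```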

pytorch-bot (bot) commented on Aug 24, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161369

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 880d298 with merge base 8951df0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla (bot) commented on Aug 24, 2025:

CLA Signed

The committers listed above are authorized under a signed CLA.

@iwknow (Contributor, Author) commented on Aug 24, 2025:

Related issue: pytorch/xla#9159

@EikanWang (Collaborator) left a comment:

Regarding out-of-tree accelerators, I think the PrivateUse1-based mechanism should be the recommended option: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/detail/PrivateUse1HooksInterface.h#L76-L77

@iwknow (Contributor, Author) commented on Aug 25, 2025:

Hi @EikanWang, thanks for reviewing!
I don't think PrivateUse1 is viable for the XLA use case, if my understanding is correct.

  1. Backends like XLA are first-class backends in PyTorch, with their own DispatchKey::XLA and DeviceType::XLA. XLA is "out-of-tree" simply because it is developed and maintained in a different package. This is fundamentally different from the PrivateUse1 use case, where the entire backend is registered under PrivateUse1, including DispatchKey::PrivateUse1.
  2. Or, if you are suggesting keeping XLA as-is and only making RNG kernels explicitly consume a PrivateUse1 generator, that wouldn't work either. torch.device('xla') is already bound to DeviceType::XLA. You can't rebind the name "xla" to PrivateUse1 in the same process/build, and registering a PrivateUse1 backend under the "xla" name would conflict with the existing device.

@FFFrog (Collaborator) commented on Aug 26, 2025:

@EikanWang Thanks for the mention.

@iwknow PyTorch supports registering generators for out-of-tree backends by inheriting from the AcceleratorHooksInterface class. You can refer to the link below for more information.

Therefore, it seems to me that you might need to add new files named XLAHooksInterface.h and XLAHooksInterface.cpp to the PyTorch repo implementing XLAHooksInterface, and then create new XLAHooks.h and XLAHooks.cpp files in the XLA repo that inherit from XLAHooksInterface, implement the custom XLAHooks class, and register it with PyTorch; roughly, something like the sketch below.
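(Sketch only: the class names, macros, and method set below are assumptions modeled on PyTorch's existing accelerator hooks interfaces such as XPUHooksInterface, not the final code.)

```cpp
// In the PyTorch repo: aten/src/ATen/detail/XLAHooksInterface.h (sketch)
#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <c10/util/Registry.h>

namespace at {

struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface {
  ~XLAHooksInterface() override = default;

  // Satisfies the pure-virtual check from AcceleratorHooksInterface.
  bool hasPrimaryContext(DeviceIndex device_index) const override {
    return false;
  }

  virtual bool hasXLA() const {
    return false;
  }

  // Mirrors the generator hook on the base interface; the real
  // implementation comes from the registered torch_xla hooks.
  virtual const Generator& getDefaultGenerator(
      DeviceIndex device_index = -1) const {
    TORCH_CHECK(false, "Cannot get XLA generator without torch_xla loaded.");
  }
};

struct TORCH_API XLAHooksArgs {};

TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs);
#define REGISTER_XLA_HOOKS(clsname) \
  C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname)

namespace detail {
const XLAHooksInterface& getXLAHooks(); // lazily instantiated from the registry
} // namespace detail

} // namespace at
```

On the pytorch/xla side, a concrete XLAHooks class would then inherit from at::XLAHooksInterface, override these methods with real implementations, and register itself via REGISTER_XLA_HOOKS(XLAHooks).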

Comment on lines 62 to 66:
```cpp
} else if (c10::impl::hasGenerator(device_type)) {
  self->cdata = at::Generator(c10::impl::makeGenerator(device));
} else {
  throw std::runtime_error("No generator available for device type: " +
      c10::toString(device_type));
```
Review comment (Collaborator):

So, we do not need to modify this one; all the accelerators except CPU will follow the same AcceleratorHooksInterface logic.

@soulitzer added the "triaged" label on Aug 26, 2025
@iwknow (Contributor, Author) commented on Aug 26, 2025:

> @EikanWang Thanks for the mention.
>
> @iwknow PyTorch supports registering generators for out-of-tree backends by inheriting from the AcceleratorHooksInterface class. You can refer to the link below for more information.
>
> Therefore, it seems to me that you might need to add new files named XLAHooksInterface.h and XLAHooksInterface.cpp to the PyTorch repo implementing XLAHooksInterface, and then create new XLAHooks.h and XLAHooks.cpp files in the XLA repo that inherit from XLAHooksInterface, implement the custom XLAHooks class, and register it with PyTorch.

This is actually a brilliant idea; I will take a closer look.

@iwknow (Contributor, Author) commented on Aug 30, 2025:

Hi @FFFrog,
I've updated the PR based on your suggestion. Please review; thank you very much!

Comment on lines 47 to 49:
```cpp
virtual DeviceIndex getNumDevices() const {
  return 0;
}
```
Review comment (Collaborator):

Why do we need this one? Can deviceCount from base class AcceleratorHooksInterface help?

Reply (Contributor, Author):

Good catch. I guess it was copied from getNumGPUs in XPUHooksInterface.h. Removed in favor of DeviceIndex deviceCount from AcceleratorHooksInterface.

@FFFrog (Collaborator) commented on Aug 30, 2025:

It would be better if we changed the code below to return detail::getXLAHooks().hasXLA(); as well.

```cpp
static bool hasXLA() {
  return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
}
```

@iwknow (Contributor, Author) commented on Aug 31, 2025:

@FFFrog Thanks for reviewing. PR updated.

@FFFrog added the "topic: not user facing" label on Sep 1, 2025
@FFFrog (Collaborator) left a comment:

Thanks, LGTM, let's trigger CI first.

@iwknow, I don't have valid approval permissions, so we'll have to get approval from @albanD.

Hey, @albanD, please take a look at this, thanks.

Also, it's necessary to find a better way to implement Hooks registration; otherwise, every time we add a new backend to PyTorch, we'll need to add a Hooks interface for it.

@iwknow (Contributor, Author) commented on Sep 3, 2025:

@albanD, a kind ping.

@iwknow (Contributor, Author) commented on Sep 9, 2025:

Hi @FFFrog, can you please nominate another approver, as it seems that albanD is unresponsive?

@FFFrog (Collaborator) commented on Sep 10, 2025:

> Hi @FFFrog, can you please nominate another approver, as it seems that albanD is unresponsive?

Sure, Hey @albanD @ezyang @malfet, could you please help to take a look at this one? Thank you.

@albanD (Collaborator) left a comment:

That sounds OK to add to make pytorch/XLA work. Do you have a sister PR on the XLA side to make it work that we should look at at the same time?
Also, I would like one of the pytorch/XLA maintainers to comment here to make sure this is good on their end.

@FFFrog XLA is a bit of a weird backend for historical reasons. Indeed, we don't plan on adding any new backend that doesn't go through PrivateUse1 going forward.

@iwknow (Contributor, Author) commented on Sep 11, 2025:

> That sounds OK to add to make pytorch/XLA work. Do you have a sister PR on the XLA side to make it work that we should look at at the same time? Also, I would like one of the pytorch/XLA maintainers to comment here to make sure this is good on their end.
>
> @FFFrog XLA is a bit of a weird backend for historical reasons. Indeed, we don't plan on adding any new backend that doesn't go through PrivateUse1 going forward.

Sister PR for XLA: yes, I do have the code change for XLA; it's basically a concrete implementation of the interface, but it will live in the pytorch/XLA package. Would you like to review it?

@qihqi has been reviewing the generator change; I think he can take a look at this change as well. @qihqi, please leave a comment if this change looks good to you and is what the torch/XLA team wants.

@FFFrog (Collaborator) commented on Sep 12, 2025:

> @FFFrog XLA is a bit of a weird backend for historical reasons. Indeed, we don't plan on adding any new backend that doesn't go through PrivateUse1 going forward.

Thank you, I got it.

@iwknow (Contributor, Author) commented on Sep 25, 2025:

@qihqi, pinging again.

@iwknow (Contributor, Author) commented on Oct 14, 2025:

@FFFrog @albanD can you please merge this change? Thanks!

@FFFrog (Collaborator) commented on Oct 15, 2025:

> @FFFrog @albanD can you please merge this change? Thanks!

Maintainer approval is required before merging (you can merge on your own, see this link for more information)

@ysiraichi (Collaborator) commented:
@iwknow Could you open a PR in PyTorch/XLA with the XLAHooks implementation, pinning this PR? I think it's a good measure to check that PyTorch/XLA tests pass, given that there is no PyTorch/XLA CI in PyTorch anymore.

@iwknow (Contributor, Author) commented on Oct 15, 2025:

> @iwknow Could you open a PR in PyTorch/XLA with the XLAHooks implementation, pinning this PR? I think it's a good measure to check that PyTorch/XLA tests pass, given that there is no PyTorch/XLA CI in PyTorch anymore.

Will do shortly. Meanwhile, @guangyey, please review and approve.

@guangyey (Collaborator) commented:
Hi @iwknow, I don't have permission to merge this PR. You need to get approval from @albanD.

iwknow and others added 3 commits on October 19, 2025, one of them noting: "Also, rename one method in DeviceAccelerator to align with the interface."
@iwknow (Contributor, Author) commented on Oct 20, 2025:

@ysiraichi FYI, this is the XLA hooks implementation: pytorch/xla#9683

@albanD (Collaborator) commented on Oct 20, 2025:

Thanks for the linked PR. I'm not sure if something else is needed on the XLA side, but I'm happy to merge this one if it's enough.

@albanD (Collaborator) left a comment:

Accepting here since @qihqi accepted as well and it doesn't impact other components.

```cpp
DETECT_AND_ASSIGN_ACCELERATOR_COMP(HIP)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(MPS)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(HPU)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(XLA)
```
Review comment (Collaborator):

Adding xla here has quite deep implications, FYI. It will change the behavior of many higher-level systems w.r.t. the xla device.

Reply (Contributor, Author):

Should we add it then? It seems like the right thing to do, to my knowledge. If you feel we shouldn't make this change, I can revert.

Reply (Collaborator):

@albanD Could you give some examples of the behaviors that might change? Without this line, are we still going to see significant behavior changes w.r.t. PyTorch/XLA?

Reply (Collaborator):

All the code that checks the current accelerator will now pick up xla when it didn't before.
https://github.com/search?q=repo%3Apytorch%2Fpytorch+current_accelerator&type=code are the ones in core, but there are many more out-of-core now.

Autograd will start to do stream handling for you, autocast will also behave differently, pinned memory will start trying to use your host allocator, and out-of-core repos should use generic code instead of xla-specific code.

Yes, this concern is only about this one line (and the related update below in this file). The rest of the changes are lower impact indeed.

Reply (Contributor, Author):

After taking a closer look at getAcceleratorHooksInterface, at::getAccelerator only makes a difference when opt_device_type doesn't have a value. Searching for usages of getAcceleratorHooksInterface (https://github.com/search?q=repo%3Apytorch%2Fpytorch%20getAcceleratorHooksInterface&type=code), it turns out it is only used in a few places, and torch/csrc/autograd/engine.cpp is the only place that doesn't pass opt_device_type. The random number generator path checks and passes opt_device_type, so getAcceleratorHooksInterface works properly even if at::getAccelerator doesn't recognize XLA.

Considering that evaluating the impact of changing at::getAccelerator is way beyond the scope of this change, I created a separate issue #166054 to track the at::getAccelerator change and reverted the related change from this PR.

@iwknow (Contributor, Author) commented on Oct 22, 2025:

@albanD please merge if there are no other questions. Thanks!

@FFFrog (Collaborator) commented on Oct 23, 2025:

@pytorchbot merge

@pytorch-bot added the "ciflow/trunk" (Trigger trunk jobs on your pull request) label on Oct 23, 2025
@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@iwknow (Contributor, Author) commented on Oct 23, 2025:

@FFFrog I am confused by the merge status. Did the merge fail? I see "Closed with unmerged commits", but it seems that the changes were committed to mainline: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/detail/XLAHooksInterface.h

@FFFrog (Collaborator) commented on Oct 24, 2025:

> @FFFrog I am confused by the merge status. Did the merge fail? I see "Closed with unmerged commits", but it seems that the changes were committed to mainline: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/detail/XLAHooksInterface.h

Don't worry about the status of the PR; the changes have already been merged into trunk. :D

Labels

ciflow/trunk · Merged · open source · topic: not user facing · triaged
