
[RFC]: Enable libtorch-ABI-stable vLLM cuda wheels #26946

@janeyx99

Description

Motivation.

PyTorch is building out a limited, stable ABI through which custom-op writers can build extensions against one torch version and run them with another. vLLM is an important part of the ecosystem that depends on torch. We'd like to help enable building vLLM wheels that are not tied to a specific torch version, to promote build-system simplicity and a better overall developer experience.

Proposed Change.

As PyTorch grows its stable ABI surface, we'd love vLLM's collaboration in testing out our stable APIs. We've gone through the repo and gathered the torch APIs below, which we are working to make stable. The hope is that once stable versions of these APIs exist, we can build vLLM CUDA wheels in a libtorch-stable manner, similar to what we did for FA3 in Dao-AILab/flash-attention#1791.

The APIs:

  • Various TORCH_LIBRARY registrations [link]
  • at::Tag [link]
  • Device stuff
    • torch::kDevice
    • at::Device
    • t.device().type()
    • get_device
  • SymInt [link]
    • We're wondering if we can replace these with ints and use Python meta functions to support compile
  • Tensor[] [link] and Int[] [link]
  • Various AT_DISPATCH macros [link]
    • AT_DISPATCH_SWITCH, AT_DISPATCH_CASE, AT_DISPATCH_BYTE_CASE
  • at::get_num_threads
  • at::cuda::OptionalCUDAGuard
  • at::cuda::getCurrentCUDAStream()
  • torch::from_blob
  • torch::full
  • tensor.index(torch::indexing::Slice(...))
  • torch::empty(..., torch::TensorOptions().device(...).dtype(...))
  • tensor.sum(dim)
  • equivalent_scalar_type_v
  • Tensor::const_data_ptr

Questions for vLLM maintainers:

  1. How does the plan sound? Is libtorch ABI stability of interest?
  2. Does the list of APIs seem complete? Did we miss any major APIs in our survey of the repo?
  3. Is there anything we can clarify?

Feedback Period.

As we're seeking an open, continuous collaboration, there is no hard deadline. That said, we think it is feasible to support vLLM CUDA kernels by PyTorch 2.10 (the branch cut is slotted for early December), so we'd like to get started on this as soon as possible.

CC List.

@mikaylagawarecki @zou3519 @albanD

Any Other Things.

No response

