Add Randomized SVDs #2999

Open
Intron7 wants to merge 4 commits into rapidsai:main from Intron7:randomized_svds

Intron7 commented Apr 9, 2026

This PR adds randomized SVDs to RAFT, based on Halko et al. (2009) and Tomás et al. (2024). I also added support for a very limited linear operator. The operator is C++ only and might be useful for sparse PCA in cuML. It is tested in the C++ layer but not exposed in Python. This PR mimics the implementation I wrote for rapids-singlecell.

copy-pr-bot bot commented Apr 9, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

aamijar added the non-breaking (Non-breaking change) and feature request (New feature or request) labels Apr 10, 2026
aamijar changed the title from "add randomized svds" to "Add randomized svds" Apr 10, 2026
aamijar changed the title from "Add randomized svds" to "Add Randomized SVDs" Apr 10, 2026
aamijar (Member) commented Apr 10, 2026

/ok to test 307d18e

copy-pr-bot bot commented Apr 10, 2026

/ok to test 307d18e

@aamijar, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

aamijar (Member) commented Apr 10, 2026

/ok to test e52ad48

viclafargue (Contributor) left a comment

Thanks! Here are some comments.

Comment on lines +128 to +129
indptr = indptr.astype(np.int32)
indices = indices.astype(np.int32)

Isn't there an integer overflow risk here? I think we should at least warn the user that the indices will be converted.
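
The snippet under review is Python, but the same concern applies on the C++ side of the bindings. As a minimal sketch of a checked conversion, assuming the indices arrive as 64-bit integers: `narrow_indices_checked` is a hypothetical helper (not part of RAFT) that refuses to narrow a value that does not fit in `int32` instead of letting it wrap silently.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper (not part of RAFT): narrow 64-bit CSR indices to
// 32-bit, failing loudly instead of silently wrapping around.
inline std::vector<std::int32_t> narrow_indices_checked(
  const std::vector<std::int64_t>& src)
{
  constexpr std::int64_t lo = std::numeric_limits<std::int32_t>::min();
  constexpr std::int64_t hi = std::numeric_limits<std::int32_t>::max();

  std::vector<std::int32_t> dst;
  dst.reserve(src.size());
  for (std::size_t i = 0; i < src.size(); ++i) {
    // Reject any value outside the representable int32 range.
    if (src[i] < lo || src[i] > hi) {
      throw std::overflow_error("index " + std::to_string(src[i]) +
                                " at position " + std::to_string(i) +
                                " does not fit in int32");
    }
    dst.push_back(static_cast<std::int32_t>(src[i]));
  }
  return dst;
}
```

Whether the right reaction is an exception or only a warning is a design choice for the PR; the sketch shows the range check itself.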

template <typename ValueTypeT>
struct sparse_svd_config {
/** @brief Number of singular values/vectors to compute */
int n_components;

We should maybe set a default for n_components. A C++ user who forgets to specify a value would otherwise read an uninitialized member, which is undefined behavior. Setting it to 0 would allow the mistake to be caught immediately by parameter validation.
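
A minimal sketch of that idea, assuming a validation step runs before any work is launched. Only the `n_components` field comes from the PR; `sparse_svd_config_sketch` and `validate_config` are hypothetical names used for illustration.

```cpp
#include <stdexcept>

// Simplified stand-in for the config under review. Only n_components is
// taken from the PR; the struct name is hypothetical.
template <typename ValueTypeT>
struct sparse_svd_config_sketch {
  /** @brief Number of singular values/vectors to compute (0 means unset) */
  int n_components = 0;
};

// Hypothetical validation step: reject an unset or out-of-range value
// before any computation starts.
template <typename ValueTypeT>
void validate_config(sparse_svd_config_sketch<ValueTypeT> const& config,
                     int n_rows,
                     int n_cols)
{
  int const max_rank = n_rows < n_cols ? n_rows : n_cols;
  if (config.n_components < 1 || config.n_components > max_rank) {
    throw std::invalid_argument(
      "n_components must be in [1, min(n_rows, n_cols)]");
  }
}
```

With the default member initializer, a default-constructed config always holds 0, so a forgotten assignment fails validation deterministically.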

Comment on lines +40 to +41
cdef sparse_svd_config[float] config_float
cdef sparse_svd_config[double] config_double

Using module-level global config objects is not thread-safe. Please instantiate them in the function where they are needed.

void sparse_randomized_svd(
raft::resources const& handle,
sparse_svd_config<ValueTypeT> const& config,
raft::device_csr_matrix_view<ValueTypeT, int, int, NNZTypeT> A,

Suggested change
- raft::device_csr_matrix_view<ValueTypeT, int, int, NNZTypeT> A,
+ raft::device_csr_matrix_view<const ValueTypeT, int, int, NNZTypeT> A,

The input matrix should be const.

Also, for the docs, we sometimes use the convention @param [in], @param [inout], @param [out]. A parameter marked @param [in] is expected to accept const data.

Omega.data_handle(), n, block_size),
Y.view());
} // Omega freed here
cholesky_qr2(handle, Y.view());

We should maybe check the return value of cholesky_qr2 calls and emit a warning when there's a fallback to standard QR.
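
A rough sketch of what checking the result could look like. It assumes the factorization step reports success as a `bool`, where `false` means CholeskyQR2 failed and a standard QR was used instead; the actual return type of `cholesky_qr2` in the PR may differ. `run_qr_step_checked` and `warn_fn` are hypothetical names, and the warning callback stands in for whatever logging facility the library uses.

```cpp
#include <functional>
#include <string>

// Hypothetical warning sink; a real implementation would forward to the
// library's logger.
using warn_fn = std::function<void(const std::string&)>;

// Assumption: the step returns true when CholeskyQR2 succeeded and false
// when it fell back to standard QR.
template <typename QrStep>
bool run_qr_step_checked(QrStep&& cholesky_qr2_step,
                         const std::string& where,
                         const warn_fn& warn)
{
  bool const used_cholesky = cholesky_qr2_step();
  if (!used_cholesky) {
    warn("CholeskyQR2 fell back to standard QR in " + where);
  }
  return used_cholesky;
}
```

Passing a label such as "range finder" or "power iteration" for `where` would let the user see which call hit the fallback.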

Comment on lines +54 to +55
int rows() const { return m_; }
int cols() const { return n_; }

Suggested change
- int rows() const { return m_; }
- int cols() const { return n_; }
+ int rows() const { return A_.structure_view().get_n_rows(); }
+ int cols() const { return A_.structure_view().get_n_cols(); }

Comment on lines +29 to +31
raft::device_csr_matrix_view<const ValueTypeT, int, int, NNZTypeT> A_;
int m_;
int n_;

A_ already stores the dimensions of the matrix, so storing m_ and n_ here is redundant: it adds complexity and introduces the possibility of bugs. The number of rows and cols can be derived directly from A_. Also, A_ should probably be a private member.
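
A small sketch of the shape this suggests. `csr_view_sketch` is a hypothetical stand-in for `raft::device_csr_matrix_view` (which exposes its dimensions through `structure_view()`), and `csr_linear_operator_sketch` is an illustrative name rather than the class from the PR.

```cpp
// Hypothetical minimal CSR view holding only the dimensions. The real
// type is raft::device_csr_matrix_view.
struct csr_view_sketch {
  int n_rows;
  int n_cols;
  int get_n_rows() const { return n_rows; }
  int get_n_cols() const { return n_cols; }
};

class csr_linear_operator_sketch {
 public:
  explicit csr_linear_operator_sketch(csr_view_sketch A) : A_(A) {}

  // Dimensions are derived from the view, so they cannot drift out of
  // sync with it.
  int rows() const { return A_.get_n_rows(); }
  int cols() const { return A_.get_n_cols(); }

 private:
  csr_view_sketch A_;  // single source of truth, not exposed to callers
};
```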

void sparse_randomized_svd(
raft::resources const& handle,
sparse_svd_config<ValueTypeT> const& config,
OperatorT const& op,

Even though the csr_linear_operator struct is certainly useful, the user should not ever need to construct this exotic type. The public function should expose an interface with a raft::device_csr_matrix_view and the csr_linear_operator utility should be built inside of the function.
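
One possible layering, sketched with placeholder types: the public entry point accepts the matrix view, and the operator is constructed in a `detail` namespace behind it. Everything in the `sketch` namespace is hypothetical, and the body of `randomized_svd_impl` is a placeholder that only returns the shape of U rather than computing anything.

```cpp
#include <utility>

namespace sketch {

// Hypothetical stand-in for raft::device_csr_matrix_view.
struct csr_view {
  int n_rows;
  int n_cols;
};

namespace detail {

// Internal operator type; callers of the public API never name it.
struct csr_linear_operator {
  csr_view A;
  int rows() const { return A.n_rows; }
  int cols() const { return A.n_cols; }
};

// Placeholder for the operator-based algorithm: returns the shape of U
// (rows x n_components) instead of running an SVD.
template <typename OperatorT>
std::pair<int, int> randomized_svd_impl(OperatorT const& op, int n_components)
{
  return {op.rows(), n_components};
}

}  // namespace detail

// Public entry point: takes the view and builds the operator internally.
inline std::pair<int, int> sparse_randomized_svd(csr_view A, int n_components)
{
  detail::csr_linear_operator op{A};
  return detail::randomized_svd_impl(op, n_components);
}

}  // namespace sketch
```

The operator-based overload can stay available internally for other operator types without being part of the public interface.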

Comment on lines +138 to +140
rmm::device_uvector<ValueTypeT> Q_copy(m * k, stream);
raft::copy(Q_copy.data(), Q.data_handle(), m * k, stream);
raft::linalg::qrGetQ(handle, Q_copy.data(), Q.data_handle(), m, k, stream);

Suggested change
- rmm::device_uvector<ValueTypeT> Q_copy(m * k, stream);
- raft::copy(Q_copy.data(), Q.data_handle(), m * k, stream);
- raft::linalg::qrGetQ(handle, Q_copy.data(), Q.data_handle(), m, k, stream);
+ raft::linalg::qrGetQ(handle, Q.data_handle(), Q.data_handle(), m, k, stream);

qrGetQ already does a copy internally. Passing Q.data_handle() as both input and output would let the operation work in place (the copy itself could even be skipped when src == dst). This needs to be double-checked, though.

Additionally, m * k has an integer overflow risk.
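
To illustrate the overflow: with `int m` and `int k`, the product is evaluated in `int` before it is converted to the buffer size, so it overflows once `m * k` exceeds 2147483647. A minimal sketch of a widened computation follows; `checked_element_count` is a hypothetical helper, and the sketch assumes a platform with 64-bit `std::size_t`.

```cpp
#include <cstddef>
#include <limits>
#include <stdexcept>

// Hypothetical helper: number of elements in an m x k buffer, computed in
// std::size_t so the multiplication is not carried out in 32-bit int.
inline std::size_t checked_element_count(int m, int k)
{
  if (m < 0 || k < 0) {
    throw std::invalid_argument("dimensions must be non-negative");
  }
  auto const rows = static_cast<std::size_t>(m);
  auto const cols = static_cast<std::size_t>(k);
  // Guard the widened product as well, for platforms with a narrow size_t.
  if (cols != 0 && rows > std::numeric_limits<std::size_t>::max() / cols) {
    throw std::overflow_error("m * k overflows std::size_t");
  }
  return rows * cols;
}
```

For example, m = 100000 and k = 50000 gives 5,000,000,000 elements, which fits in a 64-bit `std::size_t` but not in a 32-bit `int`.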

Comment on lines +154 to +156
rmm::device_uvector<ValueTypeT> Q_copy(m * k, stream);
raft::copy(Q_copy.data(), Q.data_handle(), m * k, stream);
raft::linalg::qrGetQ(handle, Q_copy.data(), Q.data_handle(), m, k, stream);

Similar comment here.

Labels: feature request (New feature or request), non-breaking (Non-breaking change)

3 participants