
Preserve input memory location / dtype for NN Descent#1928

Open
jinsolp wants to merge 11 commits into rapidsai:main from jinsolp:nnd-keep-input-data-mem

Conversation

@jinsolp
Contributor

@jinsolp jinsolp commented Mar 18, 2026

Closes #1901

Previous Code

  • We almost always allocated device-side fp16 arrays. This was for...
    • allowing wmma usage
    • allowing data modification for CosineExpanded preprocessing

Current PR Changes

  • No logical changes apart from removing the dispatch of fp32 input to either fp32 or fp16 distance computation. That dispatch is gone; we now default to computing in the input type (e.g. fp32 stays fp32). The one exception is when compress_to_fp16=True and the input type is fp32, in which case we convert to fp16 to exploit wmma.
  • Reducing redundant memory:
    • We only allocate device-side arrays matching the input dtype if the input is not device-accessible (for fp32 input with compress_to_fp16=True, we allocate half-type arrays).
    • Remove the preprocessing step for the CosineExpanded metric (so that we no longer allocate an additional device-side data array) and do the computation inside the calculate_metric function instead.
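The dtype rule described above can be sketched as follows. This is a hypothetical illustration, not the actual cuVS implementation; the function and string tags are assumptions for the sake of the example:

```cpp
#include <string>

// Hypothetical sketch of the new rule: distances are computed in the input
// dtype, except that fp32 input is compressed to fp16 when the user opts in
// via compress_to_fp16 (to exploit wmma tensor-core instructions).
std::string distance_compute_dtype(const std::string& input_dtype,
                                   bool compress_to_fp16) {
  if (input_dtype == "fp32" && compress_to_fp16) {
    return "fp16";  // the single exception: opt-in compression of fp32
  }
  return input_dtype;  // default: keep fp32 as fp32, fp16 as fp16, ...
}
```

In other words, there is no longer any automatic fp32-vs-fp16 dispatch; only the explicit opt-in changes the compute type.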

Peak memory usage Changes

  • food data (5M x 384) = 7.25GiB

  • sports data (13M x 284) = 18.55GiB

  • Notice how for FP32->FP16 Device (meaning the data is already on device), the previous code allocated a new fp16 array, resulting in higher GPU memory usage. This PR instead converts to fp16 on-the-fly (at the cost of some conversion time) rather than allocating new fp16 memory.

[Image: performance_metrics]

Performance Changes

  • Conversion Overhead: On-the-fly conversion introduces negligible overhead.
  • Cosine Metric: Now reads L2 norms inside the calculate_metric function, aligning with the access pattern used by the L2 distance metric. This adds minimal overhead (e.g. previously 18.2937s vs. 18.7598s now for 5M x 384 data).

@jinsolp jinsolp self-assigned this Mar 18, 2026
@jinsolp jinsolp requested review from a team as code owners March 18, 2026 01:54
@jinsolp jinsolp added breaking Introduces a breaking change improvement Improves an existing functionality labels Mar 18, 2026
@jinsolp jinsolp marked this pull request as draft March 18, 2026 01:55
@copy-pr-bot

copy-pr-bot bot commented Mar 18, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@jinsolp jinsolp marked this pull request as ready for review March 20, 2026 00:15
@jinsolp jinsolp changed the title [WIP] Preserve input memory location / dtype for NN Descent Preserve input memory location / dtype for NN Descent Mar 20, 2026
size_t max_iterations;
float termination_threshold;
bool return_distances;
bool compress_to_fp16;
Member


I think we should remove this. If the user wants fp16, they should just convert the data themselves. The problem is that this flips the ownership model (albeit only temporarily, it still leads to unexpected behavior when we have to copy the data in fp16 form). Better if the user just converts it themselves. We could offer this for every index type, but it's not really necessary when the user could just convert the dtype and call the index-building process with it. Then they wouldn't have to deal with the additional copy in device memory at all.

Contributor Author


That could be an option, but if we do so, downstream ML algorithms will experience a slowdown.
For example, UMAP and HDBSCAN only support fp32, so users would be forced into a 2x slowdown in the knn computation step because NN Descent would always use fp32 distance computation.
Are we okay with this?

* performance and memory usage.
* - `NND_DIST_COMP_FP16`: Use fp16 distance computation.
*
* @deprecated To be removed in 26.08. Use cuvsNNDescentIndexParams_v6 with compress_to_fp16
Member


Is the compress_to_fp16 the only thing that's different between the old API and the new one? If that's the case, I suggest we remove the compress_to_fp16 option altogether and never copy the dataset. I think setting the distance type is useful, but I don't think copying the dataset is useful.

Contributor Author


Essentially the existing distance types (NND_DIST_COMP_x) and the new compress_to_fp16 do the same thing. I think the name could be a bit misleading, but it means "compress fp32 to fp16 to use fp16 distance computation". I'll change the name to use_fp16_dist_comp.

The reason it's changed from having 3 different dist comp options is because now the default behavior would be to use the original dtype.
Previously with the three distance computation types:

  • NND_DIST_COMP_AUTO: if fp32, dispatch to fp32 or fp16 computation depending on dim; no effect for other input dtypes.
  • NND_DIST_COMP_FP32: force fp32 input to fp32 distance computation
  • NND_DIST_COMP_FP16: force fp32 input to fp16 distance computation

Since now we want fp32 input to always compute distance in fp32, having the AUTO and the FP32 options doesn't make sense. So I decided to use a single boolean instead to decide whether to use fp32 distance computation OR fp16 distance computation for fp32 input.
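The collapse of the three-way enum into a single boolean can be sketched like this (hypothetical mapping function; the enum values follow the deprecated C API quoted above):

```cpp
// Old three-way option (deprecated in favor of a single boolean).
enum NNDDistComp { NND_DIST_COMP_AUTO, NND_DIST_COMP_FP32, NND_DIST_COMP_FP16 };

// Hypothetical sketch of the migration: only the FP16 option still changes
// behavior, since fp32 input now defaults to fp32 distance computation
// (so AUTO and FP32 both collapse into "use the input dtype").
bool to_use_fp16_dist_comp(NNDDistComp old_option) {
  return old_option == NND_DIST_COMP_FP16;
}
```

This makes the single-boolean design explicit: two of the three old options map to the same default behavior, so a bool carries exactly the remaining distinction.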


Labels

breaking Introduces a breaking change improvement Improves an existing functionality

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

Compute distances in NN Descent kernels in native types

2 participants