[ROCm] Use AITER sampling implementation instead of torch native #11257
Summary of Changes
Hello @b8zhong, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request integrates the AITER sampling backend for ROCm devices, aiming to significantly boost performance over the existing PyTorch implementation. The changes include dynamic backend selection, robust error handling, and a fallback for specific sampling parameters, leading to measurable improvements in processing speed and responsiveness.
Code Review
This pull request integrates the AITER sampling implementation for ROCm devices, which provides a noticeable performance improvement as demonstrated by the benchmarks. The changes are well-structured, adding the necessary environment variables, server arguments, and backend logic with appropriate fallbacks for unsupported configurations. My review includes a few minor suggestions to improve code clarity and maintainability in the new sampling logic, such as simplifying redundant checks and making control flow more explicit.
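To make the gating described in the review concrete, below is a minimal sketch of one way such a selection could be wired. The environment-variable name (SGLANG_SAMPLING_BACKEND), the accepted values, and the resolve_sampling_backend helper are illustrative assumptions, not the identifiers this PR actually adds.

```python
# Illustrative sketch only: env-var name, accepted values, and helper name are assumptions.
import importlib.util
import os

import torch


def resolve_sampling_backend(requested: str | None = None) -> str:
    """Pick the sampling backend, falling back to torch-native when AITER is unusable."""
    # An explicit server argument wins over the environment variable.
    choice = requested or os.environ.get("SGLANG_SAMPLING_BACKEND", "auto")

    if choice == "torch_native":
        return "torch_native"

    # Only use AITER when running on a ROCm build of torch and the package imports.
    on_rocm = torch.version.hip is not None
    aiter_available = importlib.util.find_spec("aiter") is not None
    if choice in ("auto", "aiter") and on_rocm and aiter_available:
        return "aiter"

    # Unknown values, CUDA builds, or a missing package all fall back safely.
    return "torch_native"
```

In a setup like this, the helper runs once when the sampler is constructed, while the per-request fallback for unsupported sampling parameters is handled separately inside the sampling call itself.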
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Motivation
Note: This is a ROCm-only change; I tested it on MI355X.
AITER provides a sampling operator that is more performant than the torch-native implementation. Note that this does not affect the greedy case: when top_p = 1 and top_k = 1, we still use torch.argsort. There is another AITER operator for that path as well, but I found it has some correctness issues (or possibly a usage error on my part), so it was not integrated in this PR.
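As a rough, hedged sketch of that dispatch (not the PR's code): the greedy-equivalent case stays on torch.argsort, and everything else is routed to the AITER operator, which is represented here by a plain torch stand-in (_topk_topp_stand_in) because the real AITER API is not reproduced.

```python
# Rough sketch of the dispatch, not the PR's code. The non-greedy branch calls a plain
# torch stand-in; in the PR it is the AITER sampling operator (ROCm) instead.
import torch


def _topk_topp_stand_in(probs: torch.Tensor, top_k: torch.Tensor, top_p: torch.Tensor) -> torch.Tensor:
    """Torch-native top-k/top-p sampling used here only as a placeholder for the AITER op."""
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(sorted_probs)
    keep = ((cum - sorted_probs) < top_p.unsqueeze(-1)) & (ranks < top_k.unsqueeze(-1))
    keep[..., 0] = True  # always keep the most likely token
    filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
    choice = torch.multinomial(filtered / filtered.sum(-1, keepdim=True), num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)


def sample(probs: torch.Tensor, top_k: torch.Tensor, top_p: torch.Tensor) -> torch.Tensor:
    """probs: [batch, vocab] probabilities; top_k / top_p: per-request [batch] tensors."""
    if bool(torch.all(top_k == 1) & torch.all(top_p == 1.0)):
        # Greedy-equivalent case is untouched by the PR: it still goes through
        # torch.argsort, exactly as the torch-native sampler does today.
        return torch.argsort(probs, dim=-1, descending=True)[..., 0]
    # Everything else goes to the AITER operator on ROCm (stand-in here).
    return _topk_topp_stand_in(probs, top_k, top_p)
```

Calling sample(probs, top_k, top_p) with already-softmaxed probabilities mirrors the shape of the dispatch; on a ROCm build, the stand-in call is where the AITER kernel would be invoked instead.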
LM-Eval Results
Before
Command
lm_eval \
  --model local-completions \
  --tasks gsm8k_platinum \
  --model_args model=amd/Llama-3.1-8B-Instruct-FP8-KV,base_url=http://localhost:30000/v1/completions \
  --trust_remote_code \
  --num_fewshot 8 \
  --batch_size 256 \
  --gen_kwargs "do_sample=True,temperature=0.7,top_p=0.9,top_k=50,max_new_tokens=256"
Results
After
Command
lm_eval \
  --model local-completions \
  --tasks gsm8k_platinum \
  --model_args model=amd/Llama-3.1-8B-Instruct-FP8-KV,base_url=http://localhost:30000/v1/completions \
  --trust_remote_code \
  --num_fewshot 8 \
  --batch_size 256 \
  --gen_kwargs "do_sample=True,temperature=0.7,top_p=0.9,top_k=50,max_new_tokens=256"
Results
Generally, the results are in line with the existing sampling behaviour.
Benchmarking and Profiling
Before
Command
python3 -m sglang.bench_serving \
  --backend sglang \
  --host localhost \
  --port 30000 \
  --num-prompts 4096 \
  --max-concurrency 64 \
  --flush-cache \
  --extra-request-body '{"sampling_params": {"top_p": 0.95, "top_k": 50}}'
Results
After
Command
python3 -m sglang.bench_serving \
  --backend sglang \
  --host localhost \
  --port 30000 \
  --num-prompts 4096 \
  --max-concurrency 64 \
  --flush-cache \
  --extra-request-body '{"sampling_params": {"top_p": 0.95, "top_k": 50}}'
Results
The change improves throughput and latency by roughly 6%. Further testing of individual sampling calls across various top_p and top_k values shows the AITER path outperforming the torch-native path by a decent margin in nearly all combinations.
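For readers who want to repeat that per-combination check, a small timing harness along the following lines can sweep top_p/top_k values. The batch size, vocabulary size, and value grid are placeholder assumptions, and the harness only times a torch-native sampler; the AITER side of the comparison still requires a ROCm GPU with the real operator installed.

```python
# Timing-sweep sketch (batch/vocab sizes and the value grid are assumptions).
# Times only a torch-native sampler as the baseline; it is not the AITER operator.
import itertools
import time

import torch


def torch_native_sample(probs: torch.Tensor, k: int, p: float) -> torch.Tensor:
    """Reference torch-native top-k/top-p sampler used purely as the timing baseline."""
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    keep = (torch.cumsum(sorted_probs, dim=-1) - sorted_probs) < p
    keep[..., k:] = False  # top-k cut-off
    keep[..., 0] = True    # always keep the best token
    filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
    choice = torch.multinomial(filtered / filtered.sum(-1, keepdim=True), 1)
    return sorted_idx.gather(-1, choice).squeeze(-1)


def main() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    probs = torch.softmax(torch.randn(64, 32_000, device=device), dim=-1)
    iters = 20
    for k, p in itertools.product([20, 50, 500], [0.8, 0.95]):
        for _ in range(5):  # warm-up
            torch_native_sample(probs, k, p)
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            torch_native_sample(probs, k, p)
        if device == "cuda":
            torch.cuda.synchronize()
        ms = (time.perf_counter() - start) / iters * 1e3
        print(f"top_k={k:4d} top_p={p:.2f}: {ms:.3f} ms/iter")


if __name__ == "__main__":
    main()
```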