Skip to content

Commit 73a8c90

Browse files
committed
rebase
Signed-off-by: junq <[email protected]>
2 parents: fd5d5cb + 2923eb8 — commit 73a8c90

File tree

34 files changed

+1251
-424
lines changed

34 files changed

+1251
-424
lines changed

cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,27 +75,19 @@ class CreateNewDecoderRequests : Algorithm
7575
std::vector<executor::LookaheadDecodingConfig>>
7676
operator()(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
7777
executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
78-
runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
79-
runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream,
80-
SizeType32 maxSequenceLength, SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers) const;
78+
nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
79+
CudaStream const& runtimeStream, CudaStream const& decoderStream, SizeType32 maxSequenceLength,
80+
SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers) const;
8181

8282
[[nodiscard]] std::tuple<std::vector<runtime::ITensor::SharedConstPtr>,
8383
std::vector<executor::LookaheadDecodingConfig>>
8484
createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds,
8585
executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState,
86-
runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType,
87-
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
86+
nvinfer1::DataType logitsType, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
8887
runtime::CudaStream const& runtimeStream, runtime::CudaStream const& decoderStream,
8988
SizeType32 maxSequenceLength, OptionalRef<MedusaBuffers const> medusaBuffers) const;
9089

9190
private:
92-
//! @brief Initialize the decoder at `batchSlot` with a new `request`. Exposed only for static batching via
93-
//! GptDecoderBatched::newBatch()
94-
static void newRequest(SizeType32 batchSlot, runtime::decoder_batch::Request const& request,
95-
SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
96-
runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream,
97-
SizeType32 maxSequenceLength);
98-
9991
//! @brief Setups decoder internal tensors for new speculative decoding request
10092
static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
10193
SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,

cpp/include/tensorrt_llm/runtime/decoderState.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ class DecoderState
173173
//! @brief Workspace for beam search in streaming mode.
174174
[[nodiscard]] BeamSearchBuffers const& getBeamSearchBuffers() const;
175175

176+
//! @brief Set the beam width for a specific request in the batch.
177+
//! @param batchIdx The index of the request in the batch.
178+
//! @param beamWidth The beam width for the specified request.
179+
void setBeamWidth(SizeType32 batchIdx, SizeType32 beamWidth);
180+
176181
//! @brief Cache indirection input for beam search.
177182
[[nodiscard]] TensorPtr getCacheIndirectionInput() const;
178183

cpp/include/tensorrt_llm/runtime/request.h

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,16 @@ class Request
3131
using TensorPtr = ITensor::SharedPtr;
3232
using BufferPtr = IBuffer::SharedPtr;
3333

34-
explicit Request(TensorConstPtr ids, SizeType32 inputLen, std::optional<SizeType32> maxNewTokens = std::nullopt,
35-
std::optional<SizeType32> endId = std::nullopt)
36-
: ids{std::move(ids)}
37-
, inputLen(inputLen)
38-
, maxNewTokens{maxNewTokens}
39-
, endId{endId}
34+
explicit Request(SizeType32 inputLen)
35+
: inputLen(inputLen)
4036
{
4137
}
4238

4339
//! Mandatory parameters
44-
TensorConstPtr ids; // The input sequence of token ids, [inputSeqLen], on gpu
4540
SizeType32 inputLen; // Input length without draft tokens, increasing with generation steps
4641

4742
// optional parameters
48-
std::optional<SizeType32> maxNewTokens; // maximum number of tokens to generate for this request
49-
std::optional<SizeType32> endId; // end token id
5043
SizeType32 generatedTokensPerEngineStep{1}; //
51-
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
52-
TensorPtr badWordsList; // [2, badWordsLength] on gpu
53-
TensorPtr stopWordsList; // [2, stopWordsLength] on gpu
5444

5545
//! Optional parameters for speculative decoding
5646
BufferPtr draftTokens; // [generatedTokensPerEngineStep - 1] on gpu

cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp

Lines changed: 210 additions & 176 deletions
Large diffs are not rendered by default.

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1866,9 +1866,9 @@ void TrtGptModelInflightBatching::setupDecoderStep(
18661866
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");
18671867

18681868
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
1869-
= (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests,
1870-
mRuntime->getBufferManager(), logitsType, inputBuffers, *mDecoderState, mRuntime->getStream(),
1871-
*mDecoder->getDecoderStream(), getMaxSequenceLen(), mOperatingBeamWidth, buffers.mMedusaBuffers);
1869+
= (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests, logitsType,
1870+
inputBuffers, *mDecoderState, mRuntime->getStream(), *mDecoder->getDecoderStream(), getMaxSequenceLen(),
1871+
mOperatingBeamWidth, buffers.mMedusaBuffers);
18721872

18731873
auto const localBatchSize = batchSlots->getSize();
18741874
if (localBatchSize > 0)

cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,23 +103,21 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
103103
"__call__",
104104
[](CreateNewDecoderRequests& self, tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig,
105105
executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
106-
tr::BufferManager const& bufferManager, nvinfer1::DataType logitsType,
107-
DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
108-
tensorrt_llm::runtime::CudaStream const& runtimeStream,
106+
nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
107+
runtime::decoder::DecoderState& decoderState, tensorrt_llm::runtime::CudaStream const& runtimeStream,
109108
tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
110109
SizeType32 beamWidth)
111110
{
112111
OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt;
113-
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig,
114-
worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
115-
runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
112+
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
113+
= self(modelConfig, worldConfig, decodingConfig, contextRequests, logitsType, inputBuffers,
114+
decoderState, runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
116115

117116
return std::tuple{runtime::Torch::tensor(batchSlots), std::move(samplingConfigs),
118117
std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
119118
},
120119
nb::arg("model_config"), nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("context_requests"),
121-
nb::arg("buffer_manager"), nb::arg("logits_type"), nb::arg("decoder_input_buffers"),
122-
nb::arg("decoder_state"), nb::arg("runtime_stream"), nb::arg("decoder_stream"),
123-
nb::arg("max_sequence_length"), nb::arg("beam_width"))
120+
nb::arg("logits_type"), nb::arg("decoder_input_buffers"), nb::arg("decoder_state"),
121+
nb::arg("runtime_stream"), nb::arg("decoder_stream"), nb::arg("max_sequence_length"), nb::arg("beam_width"))
124122
.def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
125123
}

cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,23 +105,21 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
105105
"__call__",
106106
[](CreateNewDecoderRequests& self, tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig,
107107
executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
108-
tr::BufferManager const& bufferManager, nvinfer1::DataType logitsType,
109-
DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
110-
tensorrt_llm::runtime::CudaStream const& runtimeStream,
108+
nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
109+
runtime::decoder::DecoderState& decoderState, tensorrt_llm::runtime::CudaStream const& runtimeStream,
111110
tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
112111
SizeType32 beamWidth)
113112
{
114113
OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt;
115-
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig,
116-
worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
117-
runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
114+
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
115+
= self(modelConfig, worldConfig, decodingConfig, contextRequests, logitsType, inputBuffers,
116+
decoderState, runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
118117

119118
return std::tuple{runtime::Torch::tensor(batchSlots), std::move(samplingConfigs),
120119
std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
121120
},
122121
py::arg("model_config"), py::arg("world_config"), py::arg("decoding_config"), py::arg("context_requests"),
123-
py::arg("buffer_manager"), py::arg("logits_type"), py::arg("decoder_input_buffers"),
124-
py::arg("decoder_state"), py::arg("runtime_stream"), py::arg("decoder_stream"),
125-
py::arg("max_sequence_length"), py::arg("beam_width"))
122+
py::arg("logits_type"), py::arg("decoder_input_buffers"), py::arg("decoder_state"),
123+
py::arg("runtime_stream"), py::arg("decoder_stream"), py::arg("max_sequence_length"), py::arg("beam_width"))
126124
.def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
127125
}

cpp/tensorrt_llm/runtime/decoderState.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,11 @@ void DecoderState::setGenerationSteps(std::vector<SizeType32> const& generationS
644644
mJointDecodingInput->generationSteps = generationSteps;
645645
}
646646

647+
void DecoderState::setBeamWidth(SizeType32 batchIdx, SizeType32 beamWidth)
648+
{
649+
mJointDecodingInput->beamWidths.at(batchIdx) = beamWidth;
650+
}
651+
647652
DecodingInput& DecoderState::getJointDecodingInput() const
648653
{
649654
return *mJointDecodingInput;

cpp/tests/runtime/gptDecoderBatchedTest.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,14 @@ void newRequests(std::vector<std::shared_ptr<tb::LlmRequest>> const& requests, T
104104
SizeType32 maxSequenceLength, tb::DecoderInputBuffers& inputBuffers, decoder::DecoderState& decoderState)
105105
{
106106
auto const& decoderStream = *decoder.getDecoderStream();
107-
auto const bufferManager = BufferManager{std::make_shared<CudaStream>(runtimeStream.get())};
108107

109108
auto batchSlotsRange = BufferRange<SizeType32>(*batchSlots);
110109
auto const localBatchSize = batchSlots->getSize();
111110

112111
tb::CreateNewDecoderRequests createNewDecoderRequests(false, false, false);
113-
auto [lookaheadPrompt, lookaheadAlgoConfigs] = createNewDecoderRequests.createDecoderRequests(requests,
114-
inputBuffers.inputsIds, decodingConfig, decoderState, bufferManager, logitsType, modelConfig, worldConfig,
115-
runtimeStream, decoderStream, maxSequenceLength, std::nullopt);
112+
auto [lookaheadPrompt, lookaheadAlgoConfigs]
113+
= createNewDecoderRequests.createDecoderRequests(requests, inputBuffers.inputsIds, decodingConfig, decoderState,
114+
logitsType, modelConfig, worldConfig, runtimeStream, decoderStream, maxSequenceLength, std::nullopt);
116115

117116
std::vector<SamplingConfig> samplingConfigs;
118117
samplingConfigs.reserve(requests.size());

examples/auto_deploy/build_and_run_ad.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,6 @@ class PromptConfig(BaseModel):
4141
"In simple words and in a single sentence, explain the concept of gravity: ",
4242
"How to fix slicing in golf? ",
4343
"Where is the capital of Iceland? ",
44-
"How big is the universe? ",
45-
"In simple words and in a single sentence, explain the concept of gravity: ",
46-
"How to fix slicing in golf? ",
47-
"Where is the capital of Iceland? ",
4844
]
4945
)
5046
sp_kwargs: Dict[str, Any] = Field(

0 commit comments

Comments (0)