Skip to content

Commit 0242a6e

Browse files
committed
rebase
Signed-off-by: junq <[email protected]>
2 parents 018b022 + 504bb7f commit 0242a6e

File tree

3,698 files changed

+23624
-12149
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

3,698 files changed

+23624
-12149
lines changed

benchmarks/cpp/disaggServerBenchmark.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,8 @@ texec::Request makeExecutorContextRequest(Sample const& sample, SizeType32 const
542542
std::nullopt, // kvCacheRetentionConfig
543543
std::nullopt, // logitsPostProcessorName
544544
std::nullopt, // logitsPostProcessor
545-
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
545+
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt,
546+
std::nullopt); // cacheSaltID
546547
request.setRequestType(tensorrt_llm::executor::RequestType::REQUEST_TYPE_CONTEXT_ONLY);
547548
return request;
548549
}

benchmarks/cpp/gptManagerBenchmark.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,8 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
837837
std::nullopt, // kvCacheRetentionConfig
838838
std::nullopt, // logitsPostProcessorName
839839
std::nullopt, // logitsPostProcessor
840-
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
840+
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt,
841+
std::nullopt); // cacheSaltID
841842
}
842843

843844
void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngineDir,

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ using UniqueToken = tensorrt_llm::runtime::UniqueToken;
6969
using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
7070
using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
7171
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
72+
using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType;
7273

7374
// Type alias for multimodal hash key (hash array + start offset)
7475
using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
@@ -115,6 +116,7 @@ struct BlockKey
115116
// Extra keys for multimodal data (similar to VLLM's approach)
116117
// Each extra key is a pair of (mm_hash, start_offset_in_block)
117118
std::vector<MmKey> extraKeys;
119+
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt;
118120

119121
BlockKey() = default;
120122

@@ -129,24 +131,25 @@ struct BlockKey
129131
}
130132

131133
explicit BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens,
132-
std::vector<MmKey> extraKeys = {})
134+
std::vector<MmKey> extraKeys = {}, std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
133135
: usesExtraIds{usesExtraIds}
134136
, loraTaskId{loraTaskId}
135137
, uniqueTokens{std::move(uniqueTokens)}
136138
, extraKeys{std::move(extraKeys)}
139+
, cacheSaltID{cacheSaltID}
137140
{
138141
}
139142

140143
bool operator==(BlockKey const& other) const noexcept
141144
{
142145
return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId
143-
&& uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys);
146+
&& uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID);
144147
}
145148

146149
int partialMatch(BlockKey const& other) const noexcept
147150
{
148151
SizeType32 numMatched{0};
149-
if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys)
152+
if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID)
150153
{
151154
auto [matchEnd, otherMatchEnd] = std::mismatch(
152155
uniqueTokens.begin(), uniqueTokens.end(), other.uniqueTokens.begin(), other.uniqueTokens.end());

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ class GenericLlmRequest
100100
RequestIdType, TensorPtr&, BeamTokens const&, TStream const&, std::optional<RequestIdType>)>;
101101
using RequestPtr = std::shared_ptr<GenericLlmRequest>;
102102
using MillisecondsType = std::chrono::milliseconds;
103+
using CacheSaltIDType = runtime::CacheSaltIDType;
103104

104-
// 49 parameters, 56 items in initialization list
105105
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> const& inputTokens,
106106
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
107107
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
@@ -137,7 +137,8 @@ class GenericLlmRequest
137137
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
138138
std::optional<SizeType32> languageAdapterUid = std::nullopt,
139139
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
140-
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
140+
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
141+
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
141142
: mRequestId(requestId)
142143
, mPromptLen(inputTokens->size())
143144
, mMaxNewTokens(maxNewTokens)
@@ -194,6 +195,7 @@ class GenericLlmRequest
194195
, mGuidedDecodingParams(std::move(guidedDecodingParams))
195196
, mLanguageAdapterUid(languageAdapterUid)
196197
, mAllottedTimeMs(allottedTimeMs)
198+
, mCacheSaltID(cacheSaltID)
197199
{
198200
if (mEncoderTokens.has_value() || encoderInputFeatures.has_value())
199201
{
@@ -203,7 +205,6 @@ class GenericLlmRequest
203205
initialize(*inputTokens, returnLogProbs);
204206
}
205207

206-
// 32 parameters, 39 items in initialization list
207208
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, VecTokens const& inputTokens,
208209
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
209210
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
@@ -221,7 +222,8 @@ class GenericLlmRequest
221222
bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
222223
executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1,
223224
std::optional<SizeType32> languageAdapterUid = std::nullopt,
224-
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
225+
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
226+
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
225227
: mRequestId(requestId)
226228
, mPromptLen(inputTokens.size())
227229
, mMaxNewTokens(maxNewTokens)
@@ -261,6 +263,7 @@ class GenericLlmRequest
261263
, mContextPhaseParams(contextPhaseParams)
262264
, mNumReturnSequences(numReturnSequences)
263265
, mLanguageAdapterUid(languageAdapterUid)
266+
, mCacheSaltID(cacheSaltID)
264267
{
265268
if (mEncoderTokens.has_value())
266269
{
@@ -269,7 +272,6 @@ class GenericLlmRequest
269272
initialize(inputTokens, returnLogProbs);
270273
}
271274

272-
// 29 items in initialization list
273275
GenericLlmRequest(RequestIdType requestId, executor::Request const& req)
274276
: mRequestId(requestId)
275277
, mPromptLen(req.getInputTokenIds().size())
@@ -300,6 +302,7 @@ class GenericLlmRequest
300302
, mGuidedDecodingParams(req.getGuidedDecodingParams())
301303
, mLanguageAdapterUid(req.getLanguageAdapterUid())
302304
, mAllottedTimeMs(req.getAllottedTimeMs())
305+
, mCacheSaltID(req.getCacheSaltID())
303306
{
304307
if (req.getRequestType() == executor::RequestType::REQUEST_TYPE_GENERATION_ONLY)
305308
{
@@ -1764,6 +1767,11 @@ class GenericLlmRequest
17641767
return mLanguageAdapterUid;
17651768
}
17661769

1770+
[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const
1771+
{
1772+
return mCacheSaltID;
1773+
}
1774+
17671775
std::vector<SizeType32> getLanguageAdapterRouting(
17681776
SizeType32 const reqNumLanguages, SizeType32 const inputLength) const
17691777
{
@@ -2042,6 +2050,9 @@ class GenericLlmRequest
20422050

20432051
bool mUseDraftModel{false};
20442052

2053+
// Cache salt id for each request.
2054+
std::optional<CacheSaltIDType> mCacheSaltID{std::nullopt};
2055+
20452056
private:
20462057
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
20472058
{
@@ -2222,7 +2233,8 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
22222233
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
22232234
std::optional<SizeType32> languageAdapterUid = std::nullopt,
22242235
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
2225-
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
2236+
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
2237+
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
22262238
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
22272239
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
22282240
std::move(promptEmbeddingTable), promptVocabSize, std::move(multimodalHashes),
@@ -2234,7 +2246,8 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
22342246
std::move(encoderInputTokens), returnEncoderOutput, clientId, priority, std::move(encoderInputFeatures),
22352247
std::move(encoderOutputLength), std::move(crossAttentionMask), llmRequestType,
22362248
std::move(inputTokenExtraIds), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks),
2237-
returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams)
2249+
returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams,
2250+
cacheSaltID)
22382251
{
22392252
}
22402253

@@ -2272,7 +2285,8 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
22722285
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
22732286
std::optional<SizeType32> languageAdapterUid = std::nullopt,
22742287
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
2275-
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
2288+
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
2289+
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
22762290
: Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
22772291
samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList),
22782292
std::move(stopWordsList),
@@ -2302,7 +2316,7 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
23022316
inputTokenExtraIds ? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds)))
23032317
: std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt),
23042318
numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics,
2305-
std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams)
2319+
std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID)
23062320
{
23072321
}
23082322

@@ -2324,14 +2338,15 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
23242338
bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
23252339
executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1,
23262340
std::optional<SizeType32> languageAdapterUid = std::nullopt,
2327-
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
2341+
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
2342+
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
23282343
: Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId,
23292344
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
23302345
std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
23312346
lookaheadConfig, returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens),
23322347
std::move(draftLogits), excludeInputFromOutput, std::move(logitsPostProcessor),
23332348
applyLogitsPostProcessorBatched, std::move(encoderInputTokens), returnEncoderOutput, clientId, priority,
2334-
numReturnSequences, languageAdapterUid, contextPhaseParams)
2349+
numReturnSequences, languageAdapterUid, contextPhaseParams, cacheSaltID)
23352350
{
23362351
}
23372352

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ class Request
670670
/// @param allottedTimeMs The allotted time in milliseconds after which the request is cancelled with a timedOut
671671
/// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism
672672
/// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled.
673-
// 34 parameters
673+
/// @param cacheSaltID Salt ID for KV cache blocks, limiting KV cache reuse to requests with the same cache salt ID.
674674
Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
675675
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
676676
std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
@@ -697,7 +697,8 @@ class Request
697697
std::optional<EagleConfig> eagleConfig = std::nullopt, std::optional<Tensor> skipCrossAttnBlocks = std::nullopt,
698698
std::optional<GuidedDecodingParams> guidedDecodingParams = std::nullopt,
699699
std::optional<SizeType32> languageAdapterUid = std::nullopt,
700-
std::optional<MillisecondsType> allottedTimeMs = std::nullopt);
700+
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
701+
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt);
701702

702703
/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
703704
static auto constexpr kBatchedPostProcessorName = "batched";
@@ -745,6 +746,7 @@ class Request
745746
[[nodiscard]] std::optional<GuidedDecodingParams> getGuidedDecodingParams() const;
746747
[[nodiscard]] std::optional<SizeType32> getLanguageAdapterUid() const;
747748
[[nodiscard]] std::optional<MillisecondsType> getAllottedTimeMs() const;
749+
[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const;
748750
[[nodiscard]] std::optional<std::vector<std::string>> getAdditionalOutputNames() const;
749751

750752
void setStreaming(bool streaming);
@@ -780,6 +782,7 @@ class Request
780782
void setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams);
781783
void setLanguageAdapterUid(SizeType32 languageAdapterUid);
782784
void setAllottedTimeMs(MillisecondsType allottedTimeMs);
785+
void setCacheSaltID(CacheSaltIDType cacheSaltID);
783786

784787
private:
785788
friend class Serialization;

cpp/include/tensorrt_llm/executor/types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ using RandomSeedType = std::uint64_t;
5858
using VecLogProbs = std::vector<FloatType>;
5959
using StreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
6060
using MillisecondsType = std::chrono::milliseconds;
61+
using CacheSaltIDType = std::uint64_t;
6162
using LogitsPostProcessor
6263
= std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr const&, std::optional<IdType>)>;
6364
using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;

cpp/include/tensorrt_llm/runtime/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ using TokenIdType = std::int32_t;
4444
using LoraTaskIdType = std::uint64_t;
4545
using TokenExtraIdType = std::uint64_t;
4646
using VecTokenExtraIds = std::vector<TokenExtraIdType>;
47+
using CacheSaltIDType = std::uint64_t;
4748

4849
struct UniqueToken
4950
{

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,34 @@ class DataResponder::Impl
204204
}
205205
}
206206

207+
void sendResponse(std::vector<size_t> const& blockHashes, std::map<RequestIdType, Response>::iterator it)
208+
{
209+
auto reqId = mCurrentRequest.value();
210+
auto count = --mRemainSendCount[reqId];
211+
TLLM_CHECK(count >= 0);
212+
if (count == 0)
213+
{
214+
mRemainSendCount.erase(reqId);
215+
216+
// TODO(zhengd): pass the hashes directly instead of update llmRequest
217+
auto llmRequest = it->second.mRequest;
218+
llmRequest->setRequestedBlockHashes(std::move(blockHashes));
219+
220+
if (common::getEnvParallelCacheSend())
221+
{
222+
// TODO: Use a thread pool and check for thread safety.
223+
std::thread(&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
224+
.detach();
225+
}
226+
else
227+
{
228+
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
229+
}
230+
removeResponse(it);
231+
}
232+
mCurrentRequest = std::nullopt;
233+
}
234+
207235
void response() noexcept
208236
{
209237
try
@@ -237,40 +265,22 @@ class DataResponder::Impl
237265
auto it = getCurrentResponse();
238266
if (it != mReadyResponses.end())
239267
{
240-
auto reqId = mCurrentRequest.value();
241-
auto count = --mRemainSendCount[reqId];
242-
TLLM_CHECK(count >= 0);
243-
if (count == 0)
268+
sendResponse(blockHashes, it);
269+
}
270+
else
271+
{
272+
auto it = getCurrentResponse();
273+
while (it == mReadyResponses.end())
244274
{
245-
mRemainSendCount.erase(reqId);
246-
247-
// TODO(zhengd): pass the hashes directly instead of update llmRequest
248-
auto llmRequest = it->second.mRequest;
249-
llmRequest->setRequestedBlockHashes(std::move(blockHashes));
250-
251-
if (common::getEnvParallelCacheSend())
252-
{
253-
// TODO: Use a thread pool and check for thread safety.
254-
std::thread(
255-
&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
256-
.detach();
257-
}
258-
else
275+
std::unique_lock lk(mCondMutex);
276+
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
277+
if (mTerminate)
259278
{
260-
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
279+
break;
261280
}
262-
removeResponse(it);
281+
it = getCurrentResponse();
263282
}
264-
mCurrentRequest = std::nullopt;
265-
}
266-
else
267-
{
268-
TLLM_CHECK_WITH_INFO(!mCurrentRequest.has_value(),
269-
"This executor does not have a prepared KV cache for request ID: %zu, and the "
270-
"mReadyResponses size is: %zu. mpi rank :%d ",
271-
mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank());
272-
std::unique_lock lk(mCondMutex);
273-
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
283+
sendResponse(blockHashes, it);
274284
}
275285
}
276286
}

0 commit comments

Comments
 (0)