@@ -100,8 +100,8 @@ class GenericLlmRequest
100100 RequestIdType, TensorPtr&, BeamTokens const &, TStream const &, std::optional<RequestIdType>)>;
101101 using RequestPtr = std::shared_ptr<GenericLlmRequest>;
102102 using MillisecondsType = std::chrono::milliseconds;
103+ using CacheSaltIDType = runtime::CacheSaltIDType;
103104
104- // 49 parameters, 56 items in initialization list
105105 GenericLlmRequest (RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> const & inputTokens,
106106 runtime::SamplingConfig const & samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt ,
107107 std::optional<SizeType32> padId = std::nullopt , std::optional<TensorPtr> embeddingBias = std::nullopt ,
@@ -137,7 +137,8 @@ class GenericLlmRequest
137137 std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt ,
138138 std::optional<SizeType32> languageAdapterUid = std::nullopt ,
139139 std::optional<MillisecondsType> allottedTimeMs = std::nullopt ,
140- std::optional<executor::ContextPhaseParams> const & contextPhaseParams = std::nullopt )
140+ std::optional<executor::ContextPhaseParams> const & contextPhaseParams = std::nullopt ,
141+ std::optional<CacheSaltIDType> cacheSaltID = std::nullopt )
141142 : mRequestId (requestId)
142143 , mPromptLen (inputTokens->size ())
143144 , mMaxNewTokens(maxNewTokens)
@@ -194,6 +195,7 @@ class GenericLlmRequest
194195 , mGuidedDecodingParams (std::move(guidedDecodingParams))
195196 , mLanguageAdapterUid (languageAdapterUid)
196197 , mAllottedTimeMs (allottedTimeMs)
198+ , mCacheSaltID (cacheSaltID)
197199 {
198200 if (mEncoderTokens .has_value () || encoderInputFeatures.has_value ())
199201 {
@@ -203,7 +205,6 @@ class GenericLlmRequest
203205 initialize (*inputTokens, returnLogProbs);
204206 }
205207
206- // 32 parameters, 39 items in initialization list
207208 GenericLlmRequest (RequestIdType requestId, SizeType32 maxNewTokens, VecTokens const & inputTokens,
208209 runtime::SamplingConfig const & samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt ,
209210 std::optional<SizeType32> padId = std::nullopt , std::optional<TensorPtr> embeddingBias = std::nullopt ,
@@ -221,7 +222,8 @@ class GenericLlmRequest
221222 bool returnEncoderOutput = false , std::optional<RequestIdType> clientId = std::nullopt ,
222223 executor::PriorityType priority = executor::Request::kDefaultPriority , SizeType32 numReturnSequences = 1 ,
223224 std::optional<SizeType32> languageAdapterUid = std::nullopt ,
224- std::optional<executor::ContextPhaseParams> const & contextPhaseParams = std::nullopt )
225+ std::optional<executor::ContextPhaseParams> const & contextPhaseParams = std::nullopt ,
226+ std::optional<CacheSaltIDType> cacheSaltID = std::nullopt )
225227 : mRequestId (requestId)
226228 , mPromptLen (inputTokens.size())
227229 , mMaxNewTokens (maxNewTokens)
@@ -261,6 +263,7 @@ class GenericLlmRequest
261263 , mContextPhaseParams (contextPhaseParams)
262264 , mNumReturnSequences (numReturnSequences)
263265 , mLanguageAdapterUid (languageAdapterUid)
266+ , mCacheSaltID (cacheSaltID)
264267 {
265268 if (mEncoderTokens .has_value ())
266269 {
@@ -269,7 +272,6 @@ class GenericLlmRequest
269272 initialize (inputTokens, returnLogProbs);
270273 }
271274
272- // 29 items in initialization list
273275 GenericLlmRequest (RequestIdType requestId, executor::Request const & req)
274276 : mRequestId (requestId)
275277 , mPromptLen (req.getInputTokenIds().size())
@@ -300,6 +302,7 @@ class GenericLlmRequest
300302 , mGuidedDecodingParams (req.getGuidedDecodingParams())
301303 , mLanguageAdapterUid (req.getLanguageAdapterUid())
302304 , mAllottedTimeMs (req.getAllottedTimeMs())
305+ , mCacheSaltID (req.getCacheSaltID())
303306 {
304307 if (req.getRequestType () == executor::RequestType::REQUEST_TYPE_GENERATION_ONLY)
305308 {
@@ -1764,6 +1767,11 @@ class GenericLlmRequest
17641767 return mLanguageAdapterUid ;
17651768 }
17661769
1770+ [[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID () const
1771+ {
1772+ return mCacheSaltID ;
1773+ }
1774+
17671775 std::vector<SizeType32> getLanguageAdapterRouting (
17681776 SizeType32 const reqNumLanguages, SizeType32 const inputLength) const
17691777 {
@@ -2042,6 +2050,9 @@ class GenericLlmRequest
20422050
20432051 bool mUseDraftModel {false };
20442052
2053+ // Cache salt id for each request.
2054+ std::optional<CacheSaltIDType> mCacheSaltID {std::nullopt };
2055+
20452056private:
20462057 void initialize (VecTokens const & inputTokens, bool outputLogProbs)
20472058 {
@@ -2222,7 +2233,8 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
22222233 std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt ,
22232234 std::optional<SizeType32> languageAdapterUid = std::nullopt ,
22242235 std::optional<MillisecondsType> allottedTimeMs = std::nullopt ,
2225- std::optional<executor::ContextPhaseParams> const & contextPhaseParams = std::nullopt )
2236+ std::optional<executor::ContextPhaseParams> const & contextPhaseParams = std::nullopt ,
2237+ std::optional<CacheSaltIDType> cacheSaltID = std::nullopt )
22262238 : Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
22272239 std::move (embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
22282240 std::move(promptEmbeddingTable), promptVocabSize, std::move(multimodalHashes),
@@ -2234,7 +2246,8 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
22342246 std::move(encoderInputTokens), returnEncoderOutput, clientId, priority, std::move(encoderInputFeatures),
22352247 std::move(encoderOutputLength), std::move(crossAttentionMask), llmRequestType,
22362248 std::move(inputTokenExtraIds), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks),
2237- returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams)
2249+ returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams,
2250+ cacheSaltID)
22382251 {
22392252 }
22402253
@@ -2272,7 +2285,8 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
22722285 std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt ,
22732286 std::optional<SizeType32> languageAdapterUid = std::nullopt ,
22742287 std::optional<MillisecondsType> allottedTimeMs = std::nullopt ,
2275- std::optional<executor::ContextPhaseParams> const & contextPhaseParams = std::nullopt )
2288+ std::optional<executor::ContextPhaseParams> const & contextPhaseParams = std::nullopt ,
2289+ std::optional<CacheSaltIDType> cacheSaltID = std::nullopt )
22762290 : Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
22772291 samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList),
22782292 std::move(stopWordsList),
@@ -2302,7 +2316,7 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
23022316 inputTokenExtraIds ? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds)))
23032317 : std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt ),
23042318 numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics,
2305- std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams)
2319+ std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID )
23062320 {
23072321 }
23082322
@@ -2324,14 +2338,15 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
23242338 bool returnEncoderOutput = false , std::optional<RequestIdType> clientId = std::nullopt ,
23252339 executor::PriorityType priority = executor::Request::kDefaultPriority , SizeType32 numReturnSequences = 1 ,
23262340 std::optional<SizeType32> languageAdapterUid = std::nullopt ,
2327- std::optional<executor::ContextPhaseParams> const & contextPhaseParams = std::nullopt )
2341+ std::optional<executor::ContextPhaseParams> const & contextPhaseParams = std::nullopt ,
2342+ std::optional<CacheSaltIDType> cacheSaltID = std::nullopt )
23282343 : Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId,
23292344 std::move (embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
23302345 std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
23312346 lookaheadConfig, returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens),
23322347 std::move(draftLogits), excludeInputFromOutput, std::move(logitsPostProcessor),
23332348 applyLogitsPostProcessorBatched, std::move(encoderInputTokens), returnEncoderOutput, clientId, priority,
2334- numReturnSequences, languageAdapterUid, contextPhaseParams)
2349+ numReturnSequences, languageAdapterUid, contextPhaseParams, cacheSaltID )
23352350 {
23362351 }
23372352
0 commit comments