@@ -103,23 +103,21 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
103103 " __call__" ,
104104 [](CreateNewDecoderRequests& self, tr::ModelConfig const & modelConfig, tr::WorldConfig const & worldConfig,
105105 executor::DecodingConfig const & decodingConfig, RequestVector const & contextRequests,
106- tr::BufferManager const & bufferManager, nvinfer1::DataType logitsType,
107- DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
108- tensorrt_llm::runtime::CudaStream const & runtimeStream,
106+ nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
107+ runtime::decoder::DecoderState& decoderState, tensorrt_llm::runtime::CudaStream const & runtimeStream,
109108 tensorrt_llm::runtime::CudaStream const & decoderStream, SizeType32 maxSequenceLength,
110109 SizeType32 beamWidth)
111110 {
112111 OptionalRef<MedusaBuffers const > medusaBuffers = std::nullopt ;
113- auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self (modelConfig,
114- worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState ,
115- runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
112+ auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
113+ = self (modelConfig, worldConfig, decodingConfig, contextRequests, logitsType, inputBuffers,
114+ decoderState, runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
116115
117116 return std::tuple{runtime::Torch::tensor (batchSlots), std::move (samplingConfigs),
118117 std::move (lookaheadPrompt), std::move (lookaheadAlgoConfigs)};
119118 },
120119 nb::arg (" model_config" ), nb::arg (" world_config" ), nb::arg (" decoding_config" ), nb::arg (" context_requests" ),
121- nb::arg (" buffer_manager" ), nb::arg (" logits_type" ), nb::arg (" decoder_input_buffers" ),
122- nb::arg (" decoder_state" ), nb::arg (" runtime_stream" ), nb::arg (" decoder_stream" ),
123- nb::arg (" max_sequence_length" ), nb::arg (" beam_width" ))
120+ nb::arg (" logits_type" ), nb::arg (" decoder_input_buffers" ), nb::arg (" decoder_state" ),
121+ nb::arg (" runtime_stream" ), nb::arg (" decoder_stream" ), nb::arg (" max_sequence_length" ), nb::arg (" beam_width" ))
124122 .def (" name" , [](CreateNewDecoderRequests const &) { return CreateNewDecoderRequests::name; });
125123}
0 commit comments