@@ -14,10 +14,10 @@ using torch::headeronly::ScalarType;
 
 // Entry point into RNNT Loss
 std::tuple<Tensor, Tensor> compute(
-    const Tensor& logits,
-    const Tensor& targets,
-    const Tensor& logit_lengths,
-    const Tensor& target_lengths,
+    Tensor logits,
+    Tensor targets,
+    Tensor logit_lengths,
+    Tensor target_lengths,
     int64_t blank,
     double clamp,
     bool fused_log_softmax = true) {
@@ -148,23 +148,13 @@ std::tuple<Tensor, Tensor> compute(
   return std::make_tuple(costs, gradients);
 }
 
-void boxed_rnnt_loss(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  STD_TORCH_CHECK(num_args == 7, "num_args must be 7");
-  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
-  std::tuple<Tensor, Tensor> res = compute(
-      /*logits*/ torch::stable::detail::to<Tensor>(stack[0]),
-      /*targets*/ torch::stable::detail::to<Tensor>(stack[1]),
-      /*logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2]),
-      /*target_lengths*/ torch::stable::detail::to<Tensor>(stack[3]),
-      /*blank*/ float(torch::stable::detail::to<int64_t>(stack[4])),
-      /*clamp*/ torch::stable::detail::to<double>(stack[5]),
-      /*fused_log_softmax*/ torch::stable::detail::to<bool>(stack[6]));
-  stack[0] = torch::stable::detail::from(std::get<0>(res));
-  stack[1] = torch::stable::detail::from(std::get<1>(res));
+STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def(
+      "rnnt_loss_forward(Tensor logits, Tensor targets, Tensor logit_lengths, Tensor target_lengths, int blank, float clamp, bool fused_log_softmax) -> (Tensor, Tensor)");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
-  m.impl("rnnt_loss_forward", &boxed_rnnt_loss);
+  m.impl("rnnt_loss_forward", TORCH_BOX(&compute));
 }
 
 } // namespace gpu
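
Side note: TORCH_BOX(&compute) stands in for the boxed stack adapter that the deleted boxed_rnnt_loss spelled out by hand. A rough, illustrative sketch of what such an adapter does (not the actual macro expansion, which is generated from compute's signature; the name sketch_boxed_compute is made up here) is:

// Illustrative only: roughly what a TORCH_BOX-style adapter for compute() does.
// Each StableIValue on the stack is converted to the matching C++ argument,
// compute() is called, and the two result Tensors are written back to the stack.
void sketch_boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  STD_TORCH_CHECK(num_args == 7, "num_args must be 7");
  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
  std::tuple<Tensor, Tensor> res = compute(
      torch::stable::detail::to<Tensor>(stack[0]),   // logits
      torch::stable::detail::to<Tensor>(stack[1]),   // targets
      torch::stable::detail::to<Tensor>(stack[2]),   // logit_lengths
      torch::stable::detail::to<Tensor>(stack[3]),   // target_lengths
      torch::stable::detail::to<int64_t>(stack[4]),  // blank
      torch::stable::detail::to<double>(stack[5]),   // clamp
      torch::stable::detail::to<bool>(stack[6]));    // fused_log_softmax
  stack[0] = torch::stable::detail::from(std::get<0>(res));  // costs
  stack[1] = torch::stable::detail::from(std::get<1>(res));  // gradients
}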