Moves model loading and inference logic out of main into a TdtModel struct. The load constructor takes impl AsRef&lt;Path&gt; for the model directory, and transcribe encapsulates the full preprocessing/encoding/TDT-decoding loop.
Implement frame-synchronous beam search (BEAM_SIZE=4, DUR_BEAM_K=2) following NeMo's BeamTDTInfer approach. Add log_softmax helper. Rename transcribe → transcribe_greedy. Wire main to use transcribe_beam.
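The log_softmax helper mentioned above could look like the following minimal sketch (the actual signature in the repo may differ; this is the standard numerically stable max-shift formulation):

```rust
// Numerically stable log-softmax over a slice of logits:
// subtract the max before exponentiating to avoid overflow.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let log_sum = logits.iter().map(|&x| (x - max).exp()).sum::<f32>().ln() + max;
    logits.iter().map(|&x| x - log_sum).collect()
}
```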
Track call count, average batch size, and avg/total duration for the decoder and joint networks via a CallStats struct with a custom Debug impl. Both transcribe_greedy and transcribe_beam return stats alongside the transcript. main runs each decoder twice and prints stats only for the second (post-warmup) run.
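A plausible shape for the CallStats accumulator and its custom Debug impl (field and method names here are assumptions, not the repo's actual code):

```rust
use std::fmt;
use std::time::Duration;

// Accumulates per-network call statistics: call count, summed batch
// sizes, and total wall time across calls.
struct CallStats {
    calls: u64,
    total_batch: u64,
    total_time: Duration,
}

impl CallStats {
    fn record(&mut self, batch: u64, elapsed: Duration) {
        self.calls += 1;
        self.total_batch += batch;
        self.total_time += elapsed;
    }
}

// Custom Debug prints derived averages rather than raw fields.
impl fmt::Debug for CallStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let n = self.calls.max(1) as f64;
        write!(
            f,
            "{} calls, avg_batch={:.1}, avg={:.2}ms, total={:.1}ms",
            self.calls,
            self.total_batch as f64 / n,
            self.total_time.as_secs_f64() * 1e3 / n,
            self.total_time.as_secs_f64() * 1e3
        )
    }
}
```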
Replace the per-token decoder loop with a single batched call, building [n,1] token and [2,n,640] state tensors, then slicing the [n,hidden] / [2,n,640] outputs back into individual beams. Reduces kernel-launch overhead and enables GPU parallelism; avg_batch now reflects BEAM_SIZE.
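The pack/slice step can be illustrated with a std-only sketch using flat buffers in place of the real tensor type (layout and helper names are assumptions; the repo operates on actual [2, n, 640] tensors):

```rust
const H: usize = 640; // per-layer state width from the commit message

// Pack n per-beam states (each a flattened [2, H]) into one
// [2, n, H] buffer laid out row-major: index = l*n*H + b*H + h.
fn pack_states(states: &[Vec<f32>]) -> Vec<f32> {
    let n = states.len();
    let mut out = vec![0.0f32; 2 * n * H];
    for (b, s) in states.iter().enumerate() {
        for l in 0..2 {
            out[l * n * H + b * H..l * n * H + (b + 1) * H]
                .copy_from_slice(&s[l * H..(l + 1) * H]);
        }
    }
    out
}

// Slice beam b's [2, H] state back out of the batched [2, n, H] output.
fn unpack_state(batched: &[f32], n: usize, b: usize) -> Vec<f32> {
    let mut s = vec![0.0f32; 2 * H];
    for l in 0..2 {
        s[l * H..(l + 1) * H]
            .copy_from_slice(&batched[l * n * H + b * H..l * n * H + (b + 1) * H]);
    }
    s
}
```

Packing once and slicing after the batched call is what lets n per-beam decoder invocations collapse into a single kernel launch.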
…n transcribe_beam
Each while-iteration now issues one joint call over all b active hypotheses ([b, enc_dim, 1] × [b, hidden, 1]) and one decoder call over all N token expansions ([N, 1] + [2, N, 640] states), collapsing what was previously b sequential joint calls and b sequential decoder calls into two batched calls. dec_out shape is [batch, hidden, 1] (not [batch, hidden]), so the batched gather uses Array3 with index [[0, h, 0]] rather than Array2 [[0, h]].
[beam][joint] avg_batch≈3.2 (up to BEAM_SIZE=4)
[beam][decoder] avg_batch≈12.8 (up to b×BEAM_SIZE=16)
Adds BeamConfig (--beam-size, --dur-beam-k) via clap 4 derive API, replacing the hard-coded BEAM_SIZE and DUR_BEAM_K constants.
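With clap 4's derive API, the BeamConfig declaration would look roughly like this (defaults mirror the old constants; exact doc comments and field layout are assumptions):

```rust
use clap::Parser; // clap 4, "derive" feature

/// Hypothetical shape of the flags described above.
#[derive(Parser, Debug)]
struct BeamConfig {
    /// Number of hypotheses kept after each prune (was BEAM_SIZE).
    #[arg(long, default_value_t = 4)]
    beam_size: usize,

    /// Top-k durations expanded per hypothesis (was DUR_BEAM_K).
    #[arg(long, default_value_t = 2)]
    dur_beam_k: usize,
}
```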
…and ALSD
At each prune, sort by score, then drop any candidate whose key was already seen, guaranteeing the survivor is always the best-scoring one. The key is (tokens, last_frame) for beam and (tokens, current_frame, symbols_this_frame) for ALSD, where the extra field is needed because hypotheses at the same frame with different symbol counts follow different future paths.
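The sort-then-dedup prune can be sketched like this (Hyp fields and the beam-variant key are taken from the description above; everything else is illustrative):

```rust
use std::collections::HashSet;

#[derive(Clone)]
struct Hyp {
    tokens: Vec<u32>,
    last_frame: usize,
    score: f32,
}

// Sort candidates best-first, keep the first hypothesis per key,
// then truncate to the beam width. Because the list is sorted,
// the survivor for any key is always its best-scoring candidate.
fn prune(mut cands: Vec<Hyp>, beam_size: usize) -> Vec<Hyp> {
    cands.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
    let mut seen: HashSet<(Vec<u32>, usize)> = HashSet::new();
    cands.retain(|h| seen.insert((h.tokens.clone(), h.last_frame)));
    cands.truncate(beam_size);
    cands
}
```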
…ssor/encoder timing, nn/host split
Replace the pair of CallStats return values with a single DecodingStats that covers all four stages (preprocessor, encoder, decoder, joint). The summary line now shows total elapsed, RTFx, nn time (sum of all model calls), and host time (elapsed - nn: search logic, batching, tensor prep).
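The summary arithmetic is simple; a sketch with assumed names (the repo's actual DecodingStats API may differ):

```rust
use std::time::Duration;

// nn time = sum of the four per-stage totals; host time = whatever
// wall time remains; RTFx = signal seconds per wall-clock second.
fn summary(elapsed: Duration, stage_totals: &[Duration], signal_secs: f64) -> (f64, f64, f64) {
    let nn: Duration = stage_totals.iter().copied().sum();
    let host = elapsed.saturating_sub(nn);
    let rtfx = signal_secs / elapsed.as_secs_f64();
    (rtfx, nn.as_secs_f64(), host.as_secs_f64())
}
```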
…output
- Accept positional wav paths/dirs on the CLI (dirs recurse, sorted)
- Run a high-quality reference beam (beam_size=10, dur_beam_k=5) per file as silent ground truth; show transcript + duration on the file header line
- Print green ✓ / red ✗ (with ref/got lines) instead of assert_eq!
- Indent greedy/beam/alsd blocks under the file name
- Clean SentencePiece ▁ markers in displayed transcripts
- Warmup all four decoders on the first file before timing
… display options
- Select algorithm via --decoder greedy|beam|alsd (default: greedy)
- Run only reference + one decoder per file; accumulate RTFx and exact-match count
- Per-file line: filename, ✓/✗, signal duration, decoding time, RTFx, transcript
- --stats: show per-sub-model timing breakdown under each file
- --no-details: suppress per-file lines and header; show a progress bar with ETA
- Summary line (always shown): algo+params, N/total exact, overall RTFx
- Rename dur-beam-k -> beam-dur-k (and alsd equivalent) for consistent prefix nesting
- Add progress_bar 1.4.0 dependency
Add --write-gt flag to run the ground-truth beam decoder (beam_size=10, beam_dur_k=5) and write cleaned transcripts to .txt files beside each wav. The normal evaluation loop now reads those .txt files as reference instead of re-running the GT decoder on every pass, making parameter search faster.
Sweep 19 hardcoded decoder configs (greedy, beam, alsd variants) over all WAVs, printing EPR and RTFx as TSV to stdout for easy pasting into Sheets. Two stacked indicatif progress bars on stderr track configs and files; mp.suspend() prevents bar redraws from overwriting TSV rows. Also replaces progress_bar crate with indicatif for write_gt and the normal decode loop.
Output now includes pre%, enc%, dec%, joint%, host% columns showing each component's share of total wall time, making it easy to spot where time is spent across decoder configs.
The decoder is not a faithful implementation of alignment-length synchronous decoding (ALSD): it does not enforce the alignment-length synchronization invariant, and uses a per-frame symbol cap from TSD. Rename to FBSD (Frame-asynchronous Beam Search Decoding) to reflect what it actually does.
The published ALSD algorithm (Saon et al., ICASSP 2020) iterates over alignment steps rather than frames. All hypotheses in the beam advance by exactly one alignment event per step (either one token or one blank). Completed hypotheses (current_frame >= T) are drained to a final list at each step; the loop terminates after T + U_max steps or when the beam empties, whichever comes first.
Key differences from FBSD:
- Outer loop bounded by T + u_max alignment steps (not open-ended)
- No per-frame symbol cap; the step bound naturally limits token chains
- Finer-grained pruning: after every single token or blank emission
- Dedup key is (tokens, frame); no symbols_this_frame dimension
Also adds alsd_* configs to --param-search (beam_size × beam_dur_k grid, u_max=50) and exposes --decoder alsd for interactive use.
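The control flow above can be sketched as a runnable skeleton with the emission step stubbed out behind a closure (real decoding would run the joint network there; Hyp fields and the function shape are assumptions):

```rust
#[derive(Clone)]
struct Hyp {
    tokens: Vec<u32>,
    frame: usize,
    score: f32,
}

// ALSD outer loop: one alignment event per hypothesis per step,
// bounded by T + U_max steps, draining finished hypotheses each step.
fn alsd<F>(t_frames: usize, u_max: usize, beam_size: usize, mut emit: F) -> Vec<Hyp>
where
    F: FnMut(&Hyp) -> Vec<Hyp>, // one-token-or-one-blank expansions
{
    let mut beam = vec![Hyp { tokens: vec![], frame: 0, score: 0.0 }];
    let mut finals = Vec::new();
    for _step in 0..t_frames + u_max {
        let mut cands: Vec<Hyp> = beam.iter().flat_map(|h| emit(h)).collect();
        cands.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
        cands.truncate(beam_size); // prune after every alignment event
        // drain completed hypotheses (frame >= T) to the final list
        let (done, active): (Vec<Hyp>, Vec<Hyp>) =
            cands.into_iter().partition(|h| h.frame >= t_frames);
        finals.extend(done);
        beam = active;
        if beam.is_empty() {
            break; // all hypotheses finished before the step bound
        }
    }
    finals
}
```

The (tokens, frame) dedup is omitted here for brevity; it would slot into the prune step exactly as in the earlier commit.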
TDT predicts a duration for every step regardless of whether the token is blank or non-blank. All decoders were ignoring the duration for non-blank emissions, effectively treating every token as d=0. Fix:
- greedy: advance frame_ix by argmax(dur) after token emission
- beam/fbsd/alsd: compute best_dur = argmax(dur_log_probs) per hyp; add dur_log_probs[best_dur] to each non-blank child's score and advance current_frame by best_dur
- fbsd: reset symbols_this_frame to 0 when best_dur > 0 (frame changed)
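The greedy part of the fix amounts to one argmax and an index bump; a sketch with assumed names (the duration head indexes into the model's duration set, e.g. [0, 1, 2, 3, 4]):

```rust
// Index of the maximum element (the duration head's argmax).
fn argmax(xs: &[f32]) -> usize {
    xs.iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}

// After a token emission, advance the frame index by the predicted
// duration instead of staying on the same frame (the old d=0 behavior).
fn advance(frame_ix: usize, dur_log_probs: &[f32], durations: &[usize]) -> usize {
    frame_ix + durations[argmax(dur_log_probs)]
}
```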
NeMo sets _SOS = blank_index. With blank_as_pad=True the blank token maps to a zero-vector embedding, giving a neutral start. All four decoders were passing token 0 (<unk>) instead, feeding a learned non-zero embedding at the first decoder step.
…for profiling
Makes TdtModel's Runnables private; exposes run_preprocessor, run_encoder, run_decoder, run_joint as #[inline(never)] pub(crate) methods so profiler stack traces show named frames instead of anonymous run() calls.