
Refactor nemo examples #2006

Draft
kali wants to merge 20 commits into main from refactor-nemo-examples

Conversation

kali (Collaborator) commented on Mar 5, 2026

No description provided.

kali and others added 20 commits March 2, 2026 15:56
Moves model loading and inference logic out of main into a TdtModel
struct. The load constructor takes impl AsRef<Path> for the model
directory, and transcribe encapsulates the full preprocessing/encoding/
TDT decoding loop.
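As a rough illustration of the shape of this refactor, here is a minimal Rust sketch; the field layout and the stubbed transcribe body are assumptions for illustration, not tract's actual API:

```rust
use std::path::{Path, PathBuf};

// Hypothetical skeleton of the TdtModel refactor described above.
// Real fields (preprocessor/encoder/decoder/joint Runnables) are elided.
struct TdtModel {
    model_dir: PathBuf,
}

impl TdtModel {
    // `load` takes `impl AsRef<Path>` so both &str and PathBuf callers work.
    fn load(dir: impl AsRef<Path>) -> TdtModel {
        TdtModel { model_dir: dir.as_ref().to_path_buf() }
    }

    // Stand-in for the full preprocessing/encoding/TDT decoding loop.
    fn transcribe(&self, _samples: &[f32]) -> String {
        String::new()
    }
}
```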
Implement frame-synchronous beam search (BEAM_SIZE=4, DUR_BEAM_K=2)
following NeMo's BeamTDTInfer approach. Add log_softmax helper. Rename
transcribe → transcribe_greedy. Wire main to use transcribe_beam.
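A log_softmax helper like the one mentioned here fits in a few lines; this is a generic max-subtracted version for numerical stability (a sketch, not necessarily the PR's exact code):

```rust
// Log-softmax over a logit slice: subtract the max before exponentiating
// so large logits do not overflow, then subtract the log normalizer.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_z = logits.iter().map(|&x| (x - max).exp()).sum::<f32>().ln() + max;
    logits.iter().map(|&x| x - log_z).collect()
}
```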
Track call count, average batch size, and avg/total duration for the
decoder and joint networks via a CallStats struct with a custom Debug
impl. Both transcribe_greedy and transcribe_beam return stats alongside
the transcript. main runs each decoder twice and prints stats only for
the second (post-warmup) run.
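A minimal sketch of a CallStats-style accumulator with a custom Debug impl; the field names and the exact summary format here are assumptions:

```rust
use std::{fmt, time::Duration};

// Accumulates call count, total batch items, and total duration for one
// sub-model; Debug prints the derived averages.
#[derive(Default)]
struct CallStats {
    calls: usize,
    batch_items: usize,
    total: Duration,
}

impl CallStats {
    fn record(&mut self, batch: usize, elapsed: Duration) {
        self.calls += 1;
        self.batch_items += batch;
        self.total += elapsed;
    }
}

impl fmt::Debug for CallStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let avg_batch = self.batch_items as f64 / self.calls.max(1) as f64;
        let avg_ms = self.total.as_secs_f64() * 1000.0 / self.calls.max(1) as f64;
        write!(f, "{} calls, avg_batch≈{:.1}, avg {:.2}ms, total {:?}",
               self.calls, avg_batch, avg_ms, self.total)
    }
}
```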
Replace the per-token decoder loop with a single batched call, building
[n,1] token and [2,n,640] state tensors, then slicing the [n,hidden] /
[2,n,640] outputs back into individual beams. Reduces kernel-launch
overhead and enables GPU parallelism; avg_batch now reflects BEAM_SIZE.
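Slicing the batched [n, hidden] output back into per-beam rows reduces, in spirit, to row-stride arithmetic; a flat-buffer sketch, using plain slices in place of tract tensors:

```rust
// Split a row-major [n, hidden] buffer into n per-beam vectors.
fn split_rows(out: &[f32], n: usize, hidden: usize) -> Vec<Vec<f32>> {
    (0..n).map(|i| out[i * hidden..(i + 1) * hidden].to_vec()).collect()
}
```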
…n transcribe_beam

Each while-iteration now issues one joint call over all b active hypotheses
([b, enc_dim, 1] × [b, hidden, 1]) and one decoder call over all N token
expansions ([N,1] + [2,N,640] states), collapsing what was previously b
sequential joint calls and b sequential decoder calls into two batched calls.

dec_out shape is [batch, hidden, 1] (not [batch, hidden]), so the batched
gather uses Array3 with index [[0, h, 0]] rather than Array2 [[0, h]].

[beam][joint]   avg_batch≈3.2 (up to BEAM_SIZE=4)
[beam][decoder] avg_batch≈12.8 (up to b×BEAM_SIZE=16)
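The [batch, hidden, 1] vs [batch, hidden] distinction is a stride question: with a trailing axis of length 1, element (b, h, 0) sits at flat offset b*hidden*1 + h*1 + 0, which is why the gather needs three indices. A flat-buffer sketch, with plain slices standing in for ndarray's Array3:

```rust
// Three-index gather into a row-major [batch, hidden, 1] buffer.
// The trailing axis contributes stride 1 and index 0, so it vanishes
// from the offset, but the index tuple still needs all three entries.
fn gather3(dec_out: &[f32], hidden: usize, b: usize, h: usize) -> f32 {
    dec_out[b * hidden * 1 + h * 1 + 0]
}
```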
Adds BeamConfig (--beam-size, --dur-beam-k) via clap 4 derive API,
replacing the hard-coded BEAM_SIZE and DUR_BEAM_K constants.
…and ALSD

At each prune, sort by score then drop any candidate whose key was
already seen — guaranteeing the survivor is always the best-scoring one.
Key is (tokens, last_frame) for beam and (tokens, current_frame,
symbols_this_frame) for ALSD, where the extra field is needed because
hypotheses at the same frame with different symbol counts follow
different future paths.
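The sort-then-dedup prune is generic over the key type, so one helper can serve both the beam key and the ALSD key; a sketch, with a HashSet standing in for whatever the PR actually uses:

```rust
use std::collections::HashSet;

// Sort candidates by score descending, keep only the first (best-scoring)
// candidate per key, then truncate to the beam width.
fn prune_dedup<K: Eq + std::hash::Hash + Clone>(
    mut cands: Vec<(K, f32)>,
    beam: usize,
) -> Vec<(K, f32)> {
    cands.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    let mut seen = HashSet::new();
    cands.retain(|(k, _)| seen.insert(k.clone()));
    cands.truncate(beam);
    cands
}
```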
…ssor/encoder timing, nn/host split

Replace the pair of CallStats return values with a single DecodingStats
that covers all four stages (preprocessor, encoder, decoder, joint).
The summary line now shows total elapsed, RTFx, nn time (sum of all
model calls), and host time (elapsed - nn: search logic, batching,
tensor prep).
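The summary-line arithmetic is simple but worth pinning down; a sketch, with the function names invented for illustration:

```rust
use std::time::Duration;

// Host time = wall-clock elapsed minus the sum of all model-call durations
// (search logic, batching, tensor prep happen on the host).
fn host_time(elapsed: Duration, nn_stage_totals: &[Duration]) -> Duration {
    let nn: Duration = nn_stage_totals.iter().sum();
    elapsed.saturating_sub(nn)
}

// RTFx: seconds of audio decoded per second of wall time.
fn rtfx(signal_secs: f64, elapsed: Duration) -> f64 {
    signal_secs / elapsed.as_secs_f64()
}
```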
…output

- Accept positional wav paths/dirs on the CLI (dirs recurse, sorted)
- Run a high-quality reference beam (beam_size=10, dur_beam_k=5) per file
  as silent ground truth; show transcript + duration on the file header line
- Print green ✓ / red ✗ (with ref/got lines) instead of assert_eq!
- Indent greedy/beam/alsd blocks under the file name
- Clean SentencePiece ▁ markers in displayed transcripts
- Warm up all four decoders on the first file before timing
… display options

- Select algorithm via --decoder greedy|beam|alsd (default: greedy)
- Run only reference + one decoder per file; accumulate RTFx and exact-match count
- Per-file line: filename, ✓/✗, signal duration, decoding time, RTFx, transcript
- --stats: show per-sub-model timing breakdown under each file
- --no-details: suppress per-file lines and header; show a progress bar with ETA
- Summary line (always shown): algo+params, N/total exact, overall RTFx
- Rename dur-beam-k -> beam-dur-k (and alsd equivalent) for consistent prefix nesting
- Add progress_bar 1.4.0 dependency
Add --write-gt flag to run the ground-truth beam decoder (beam_size=10,
beam_dur_k=5) and write cleaned transcripts to .txt files beside each wav.
The normal evaluation loop now reads those .txt files as reference instead
of re-running the GT decoder on every pass, making parameter search faster.
Sweep 19 hardcoded decoder configs (greedy, beam, alsd variants) over
all WAVs, printing EPR and RTFx as TSV to stdout for easy pasting into
Sheets. Two stacked indicatif progress bars on stderr track configs and
files; mp.suspend() prevents bar redraws from overwriting TSV rows.

Also replaces progress_bar crate with indicatif for write_gt and the
normal decode loop.
Output now includes pre%, enc%, dec%, joint%, host% columns showing each
component's share of total wall time, making it easy to spot where time
is spent across decoder configs.
The decoder is not a faithful implementation of alignment-length
synchronous decoding (ALSD): it does not enforce the alignment-length
synchronization invariant, and uses a per-frame symbol cap from TSD.
Rename to FBSD (Frame-asynchronous Beam Search Decoding) to reflect
what it actually does.
The published ALSD algorithm (Saon et al., ICASSP 2020) iterates over
alignment steps rather than frames. All hypotheses in the beam advance
by exactly one alignment event per step (either one token or one blank).
Completed hypotheses (current_frame >= T) are drained to a final list
at each step; the loop terminates after T + U_max steps or when the
beam empties, whichever comes first.

Key differences from FBSD:
- Outer loop bounded by T + u_max alignment steps (not open-ended)
- No per-frame symbol cap; the step bound naturally limits token chains
- Finer-grained pruning: after every single token or blank emission
- Dedup key is (tokens, frame); no symbols_this_frame dimension

Also adds alsd_* configs to --param-search (beam_size x beam_dur_k grid,
u_max=50) and exposes --decoder alsd for interactive use.
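The loop structure described above can be sketched as follows, with the model calls stubbed out. The stub only ever emits blanks, so it demonstrates the T + u_max step bound and the draining of completed hypotheses, not real decoding:

```rust
// Hypothesis state: emitted tokens, current frame, accumulated score.
#[derive(Clone)]
struct Hyp {
    tokens: Vec<u32>,
    frame: usize,
    score: f32,
}

// Alignment-step-synchronous skeleton: every surviving hypothesis advances
// by exactly one alignment event per step; completed hypotheses
// (frame >= t_frames) are drained into `finals` at each step.
fn alsd_skeleton(t_frames: usize, u_max: usize) -> Vec<Hyp> {
    let mut beam = vec![Hyp { tokens: vec![], frame: 0, score: 0.0 }];
    let mut finals = Vec::new();
    for _step in 0..t_frames + u_max {
        let (done, active): (Vec<_>, Vec<_>) =
            beam.into_iter().partition(|h| h.frame >= t_frames);
        finals.extend(done);
        if active.is_empty() {
            break;
        }
        // Stubbed expansion: emit a blank, advancing one frame. A real
        // decoder would also expand non-blank tokens and then prune.
        beam = active
            .into_iter()
            .map(|mut h| {
                h.frame += 1;
                h
            })
            .collect();
    }
    finals
}
```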
TDT predicts a duration for every step regardless of whether the token
is blank or non-blank. All decoders were ignoring the duration for
non-blank emissions, effectively treating every token as d=0. Fix:

- greedy: advance frame_ix by argmax(dur) after token emission
- beam/fbsd/alsd: compute best_dur = argmax(dur_log_probs) per hyp;
  add dur_log_probs[best_dur] to each non-blank child's score and
  advance current_frame by best_dur
- fbsd: reset symbols_this_frame to 0 when best_dur > 0 (frame changed)
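The per-hypothesis duration update reduces to an argmax plus a score/frame adjustment; a sketch, where the DURATIONS table mapping duration-head index to frame advance is an assumption:

```rust
// Assumed duration table: head index i advances the frame pointer by
// DURATIONS[i] frames (TDT models typically predict small durations).
const DURATIONS: [usize; 5] = [0, 1, 2, 3, 4];

// best_dur = argmax over the duration head's log-probs; add its log-prob
// to the child's score and advance the frame pointer accordingly.
fn apply_duration(score: f32, frame: usize, dur_log_probs: &[f32]) -> (f32, usize) {
    let best = dur_log_probs
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap();
    (score + dur_log_probs[best], frame + DURATIONS[best])
}
```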
NeMo sets _SOS = blank_index. With blank_as_pad=True the blank token
maps to a zero-vector embedding, giving a neutral start. All four
decoders were passing token 0 (<unk>) instead, feeding a learned
non-zero embedding at the first decoder step.
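A toy illustration of why the start token matters; the embedding table here is fabricated, with the blank row zeroed as blank_as_pad implies:

```rust
// Seed the first decoder step with the blank row (_SOS = blank_index),
// which is all zeros under blank_as_pad, i.e. a neutral start. Row 0
// (<unk>) would instead contribute a learned non-zero vector.
fn sos_embedding(table: &[Vec<f32>], blank_index: usize) -> Vec<f32> {
    table[blank_index].clone()
}
```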
…for profiling

Makes TdtModel Runnables private; exposes run_preprocessor, run_encoder,
run_decoder, run_joint as #[inline(never)] pub(crate) methods so profiler
stack traces show named frames instead of anonymous run() calls.
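The mechanism is just an attribute on an otherwise ordinary method; a sketch with a stand-in body:

```rust
// #[inline(never)] keeps this function as a distinct, named frame in
// profiler stack traces instead of being inlined into its caller.
// The body is a placeholder; the real method would invoke a Runnable.
#[inline(never)]
fn run_encoder(features: Vec<f32>) -> Vec<f32> {
    features
}
```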
