Skip to content

Fix EtaLogitsWarper on fully masked logits#45413

Open
ezylopx5 wants to merge 1 commit intohuggingface:mainfrom
ezylopx5:fix/eta-warper-all-inf
Open

Fix EtaLogitsWarper on fully masked logits#45413
ezylopx5 wants to merge 1 commit intohuggingface:mainfrom
ezylopx5:fix/eta-warper-all-inf

Conversation

@ezylopx5
Copy link
Copy Markdown

I ran into an edge case in eta sampling where EtaLogitsWarper crashes if a row is fully masked (scores == -inf for all tokens).

The previous entropy computation used Categorical(logits=scores).entropy(), which fails on that input. This change computes entropy from NaN-safe probabilities instead, so fully masked rows stay fully masked without raising.

I also added a regression assertion to the existing eta warper test for the all--inf case.

Local test run:
python -m pytest tests/generation/test_logits_process.py -x -q (45 passed).

AI-assisted: I used AI help for drafting, but I personally reproduced the bug, reviewed the diff, and ran local tests.

Signed-off-by: HarshRathva <harshrathvaai@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant