Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,9 +997,13 @@ def __init__(
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
probabilities = scores.softmax(dim=-1)
entropy = torch.distributions.Categorical(logits=scores).entropy()
# `softmax(-inf)` yields NaN when all scores are masked. We treat such rows as having zero probability mass
# to keep eta warping stable and preserve the fully masked state.
safe_probabilities = torch.nan_to_num(probabilities, nan=0.0)
safe_log_probabilities = safe_probabilities.clamp_min(torch.finfo(scores.dtype).tiny).log()
entropy = -(safe_probabilities * safe_log_probabilities).sum(dim=-1)
eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]
indices_to_remove = probabilities < eta
indices_to_remove = safe_probabilities < eta

# Keep the words with the 'min_tokens_to_keep'-highest probabilities
top_k = min(self.min_tokens_to_keep, scores.size(-1)) # Safety check
Expand Down
6 changes: 6 additions & 0 deletions tests/generation/test_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,12 @@ def test_eta_dist_warper(self):
# first batch should keep 2 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])

# eta warper should keep fully masked rows stable (all -inf) instead of erroring due to NaN entropy.
fully_masked_scores = torch.full((1, vocab_size), -float("inf"), device=torch_device, dtype=torch.float)
masked_out = eta_warp(input_ids, fully_masked_scores)
self.assertFalse(torch.isnan(masked_out).any())
self.assertTrue(torch.isneginf(masked_out).all())

def test_no_repeat_ngram_dist_processor(self):
vocab_size = 3
batch_size = 2
Expand Down
Loading