94 changes: 93 additions & 1 deletion flair/embeddings/transformer.py
@@ -222,6 +222,58 @@ def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths:
return result


def map_tokens_to_subtokens(subtoken_offsets, token_offsets, verbose: bool = False, subtokens=None, tokens=None):
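    """Map each subtoken to the flair token that contains it, based on character offsets.

    For example, for the text "BEST DENTIST EVER -" used in test_token_subtoken_mapping, the
    subtoken offsets [(0, 0), (0, 7), (8, 12), (12, 14), (14, 17), (17, 20), (20, 25), (25, 27), (27, 34), (0, 0)]
    and token offsets [(0, 7), (8, 12), (13, 20), (21, 25), (26, 27), (27, 34)] yield the mapping
    [None, 0, 1, 2, 2, 2, 3, 4, 5, None]: zero-length special subtokens map to None, and the
    subtokens "▁D", "ENT" and "IST" all map to token 2 ("DENTIST").

    Args:
        subtoken_offsets: (start, end) character offsets of each subtoken, as returned by a fast tokenizer.
        token_offsets: (start, end) character offsets of each flair token in the same text.
        verbose: if True, print the matching decisions (pass subtokens and tokens for readable output).
        subtokens: optional subtoken strings, used only for verbose output.
        tokens: optional token strings, used only for verbose output.

    Returns:
        A list with one entry per subtoken: the index of the matched flair token, or None if the subtoken could not be mapped.
    """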

mapping: list[Optional[int]] = []
for subtoken_id, subtoken in enumerate(subtoken_offsets):

        # subtokens of length 0 (e.g. special tokens such as [CLS] and [SEP]) should not be mapped to anything
if subtoken[0] == subtoken[1]:
mapping.append(None)
continue

mapping_found = False

if verbose and subtokens:
print(f"trying to match {subtokens[subtoken_id]} ({subtoken})")

        # check if the subtoken is wholly contained within a token; if so, map it to this token.
        # The start check allows a one-character tolerance, since subtoken offsets may include the preceding whitespace
for token_id, token in enumerate(token_offsets):

if verbose and tokens:
print(f" ... does {tokens[token_id]} (#{token_id}, {token}) match?")

if token[0] - 1 <= subtoken[0] and token[1] >= subtoken[1]:
if verbose:
print(" ... yes!")
mapping.append(token_id)
mapping_found = True
break

if mapping_found:
continue

        # if the subtoken is not wholly contained within a single token, it may span several tokens.
        # In this case, map it to the first token that starts at or after the subtoken start
for token_id, token in enumerate(token_offsets):
if verbose and tokens:
print(f" ... does {tokens[token_id]} (#{token_id}, {token}) partially match?")
if token[0] >= subtoken[0]:
if verbose:
print(" ... yes!")
mapping.append(token_id)
mapping_found = True
break

if mapping_found:
continue

# if a subtoken cannot be mapped, the mapping is None
mapping.append(None)

return mapping


def _legacy_reconstruct_word_ids(
embedding: "TransformerBaseEmbeddings", flair_tokens: list[list[str]]
) -> list[list[Optional[int]]]:
@@ -354,6 +406,8 @@ def __init__(
feature_extractor: Optional[FeatureExtractionMixin] = None,
needs_manual_ocr: Optional[bool] = None,
use_context_separator: bool = True,
use_raw_text_as_input: bool = False,
**kwargs,
) -> None:
self.name = name
super().__init__()
@@ -374,6 +428,7 @@ def __init__(
self.feature_extractor = feature_extractor
self.use_context_separator = use_context_separator
self.cls_pooling = cls_pooling
self.use_raw_text_as_input = use_raw_text_as_input

tokenizer_params = list(inspect.signature(self.tokenizer.__call__).parameters.keys())
self.tokenizer_needs_ocr_boxes = "boxes" in tokenizer_params
Expand Down Expand Up @@ -417,6 +472,7 @@ def to_args(self):
"feature_extractor": self.feature_extractor,
"use_context_separator": self.use_context_separator,
"cls_pooling": self.cls_pooling,
"use_raw_text_as_input": self.use_raw_text_as_input,
}
if hasattr(self, "needs_manual_ocr"):
args["needs_manual_ocr"] = self.needs_manual_ocr
Expand Down Expand Up @@ -568,6 +624,17 @@ def __build_transformer_model_inputs(
tokenizer_kwargs["is_split_into_words"] = True
tokenizer_kwargs["text"] = [[t.text for t in tokens] for tokens in flair_tokens]

        # if we use raw text as input, pass the reconstructed sentence strings to the tokenizer and request
        # offset mappings, so that subtokens can later be aligned to flair tokens via their character positions
if self.use_raw_text_as_input:
tokenizer_kwargs["is_split_into_words"] = False
tokenizer_kwargs["return_offsets_mapping"] = True

# reconstruct text of sentences and preserve whitespace_after information
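            # (e.g. the tokens "and", "/", "or" with whitespace_after values 0, 0 and 1 are joined into the string "and/or ")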
tokenizer_kwargs["text"] = [
"".join([t.text if t.whitespace_after == 0 else t.text + " " * t.whitespace_after for t in tokens])
for tokens in flair_tokens
]

batch_encoding = self.tokenizer(
**tokenizer_kwargs,
stride=self.stride,
Expand Down Expand Up @@ -627,12 +694,35 @@ def __build_transformer_model_inputs(
if "bbox" in batch_encoding:
model_kwargs["bbox"] = batch_encoding["bbox"].to(device, non_blocking=True)

        # If token-level embeddings are needed, derive mappings between subtokens and flair tokens
if self.token_embedding or self.needs_manual_ocr:
assert sentence_lengths is not None # for type checking
model_kwargs["token_lengths"] = torch.tensor(sentence_lengths, device=device)

if self.tokenizer.is_fast:
word_ids_list = [batch_encoding.word_ids(i) for i in range(input_ids.size()[0])]

if self.use_raw_text_as_input:
word_ids_list = []
assert flair_tokens # assert that this is not None for mypy type checking
for sentence_no, sentence_tokens in enumerate(flair_tokens):

subtoken_offsets = batch_encoding["offset_mapping"][sentence_no]

offset = 0
token_offsets = []
for token in sentence_tokens:
token_offsets.append((offset, offset + len(token.text)))
offset += len(token.text) + token.whitespace_after

mapping = map_tokens_to_subtokens(
subtoken_offsets=subtoken_offsets,
token_offsets=token_offsets,
)

word_ids_list.append(mapping)

else:
word_ids_list = [batch_encoding.word_ids(i) for i in range(input_ids.size()[0])]
else:
word_ids_list = _legacy_reconstruct_word_ids(
self,
@@ -1053,6 +1143,7 @@ def __init__(
transformers_model_kwargs: dict[str, Any] = {},
peft_config=None,
peft_gradient_checkpointing_kwargs: Optional[dict[str, Any]] = {},
use_raw_text_as_input: bool = False,
**kwargs,
) -> None:
"""Instantiate transformers embeddings.
@@ -1099,6 +1190,7 @@ def __init__(
logging.set_verbosity_error()

self.tokenizer: PreTrainedTokenizer
self.use_raw_text_as_input = use_raw_text_as_input
self.feature_extractor: Optional[FeatureExtractionMixin]

if tokenizer_data is None:
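
A minimal usage sketch of the new option (the model name here is only an illustrative choice):

from flair.data import Sentence
from flair.embeddings import TransformerWordEmbeddings

# tokenize the raw sentence text and align subtokens to flair tokens via offset mappings
embeddings = TransformerWordEmbeddings("xlm-roberta-base", use_raw_text_as_input=True)

sentence = Sentence("So don't be afraid")
embeddings.embed(sentence)
for token in sentence:
    print(token.text, token.get_embedding().shape)
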
76 changes: 76 additions & 0 deletions tests/embeddings/test_transformer_word_embeddings.py
@@ -4,10 +4,12 @@
import pytest
import torch
from PIL import Image
from torch import tensor
from transformers.utils import is_detectron2_available

from flair.data import BoundingBox, Dictionary, Sentence
from flair.embeddings import TransformerJitWordEmbeddings, TransformerWordEmbeddings
from flair.embeddings.transformer import map_tokens_to_subtokens
from flair.models import SequenceTagger
from tests.embedding_test_utils import BaseEmbeddingsTest

@@ -323,3 +325,77 @@ def test_onnx_export_works(self, results_base_path):
for sent_a, sent_b in zip(normal_sentences, onnx_sentences):
for token_a, token_b in zip(sent_a, sent_b):
assert torch.isclose(token_a.get_embedding(), token_b.get_embedding(), atol=1e-6).all()

def test_token_subtoken_mapping(self):
### Test Case 1: Normal text
# text = "BEST DENTIST EVER -"

# Token and subtoken offsets
# tokens = ["[FLERT]", "BEST", "DENTIST", "EVER", "-", "[FLERT]"]
token_offsets = [(0, 7), (8, 12), (13, 20), (21, 25), (26, 27), (27, 34)]

# subtokens = ["[CLS]", "[FLERT]", "▁BEST", "▁D", "ENT", "IST", "▁EVER", "▁-", "[FLERT]", "[SEP]", ]
subtoken_offsets = tensor(
[[0, 0], [0, 7], [8, 12], [12, 14], [14, 17], [17, 20], [20, 25], [25, 27], [27, 34], [0, 0]]
)

mapping = map_tokens_to_subtokens(subtoken_offsets=subtoken_offsets, token_offsets=token_offsets)

assert [None, 0, 1, 2, 2, 2, 3, 4, 5, None] == mapping

### Test Case 2: Differing tokenizations
# text = "So don't be afraid"

# Token and subtoken offsets
# tokens = ["[FLERT]", "So", "do", "n't", "be", "afraid", "[FLERT]"]
token_offsets = [(0, 7), (8, 10), (11, 13), (13, 16), (17, 19), (20, 26), (26, 33)]

# subtokens = ["[CLS]", "[FLERT]", "▁So", "▁don", "'", "t", "▁be", "▁afraid", "[FLERT]", "[SEP]"]
subtoken_offsets = tensor(
[[0, 0], [0, 7], [8, 10], [10, 14], [14, 15], [15, 16], [16, 19], [19, 26], [26, 33], [0, 0]]
)

mapping = map_tokens_to_subtokens(subtoken_offsets=subtoken_offsets, token_offsets=token_offsets)

assert [None, 0, 1, 2, 3, 3, 4, 5, 6, None] == mapping

### Test Case 3: Text with punctuation and no whitespaces
# text = "this and/or that,"

# Token and subtoken offsets
# tokens = ["[FLERT]", "this", "and", "/", "or", "that", ",", "[FLERT]"]
token_offsets = [(0, 7), (8, 12), (13, 16), (16, 17), (17, 19), (20, 24), (24, 25), (25, 32)]

# subtokens = ["[CLS]", "[FLERT]", "▁this", "▁and", "/", "or", "▁that", ",", "[FLERT]", "[SEP]"]
subtoken_offsets = tensor(
[[0, 0], [0, 7], [8, 12], [12, 16], [16, 17], [17, 19], [19, 24], [24, 25], [25, 32], [0, 0]]
)

mapping = map_tokens_to_subtokens(subtoken_offsets=subtoken_offsets, token_offsets=token_offsets)

assert [None, 0, 1, 2, 3, 4, 5, 6, 7, None] == mapping

### Test Case 4: Suboptimal tokenization caused by limited vocabulary without whitespace
# text = "number of public-diplomacy officers"

# Token and subtoken offsets
# tokens = ['number', 'of', 'public', '-', 'diplomacy', 'officers']
token_offsets = [(0, 6), (7, 9), (10, 16), (16, 17), (17, 26), (27, 35)]

# new_subtokens = ['[CLS]', '▁number', '▁of', '▁public', '-', 'diploma', 'cy', '▁officers', '[SEP]']
# old_subtokens = ['[CLS]', '▁number', '▁of', '▁public', '▁-', '▁diplomacy', '▁officers', '[SEP]']
subtoken_offsets = tensor([[0, 0], [0, 6], [6, 9], [9, 16], [16, 17], [17, 24], [24, 26], [26, 35], [0, 0]])

        mapping = map_tokens_to_subtokens(subtoken_offsets=subtoken_offsets, token_offsets=token_offsets)

        assert [None, 0, 1, 2, 3, 4, 4, 5, None] == mapping

        ### Test Case 5: Suboptimal tokenization in which two tokenizer words become one subtoken ("got" "ta" -> "gotta")
# text = "I gotta have it"

# Token and subtoken offsets
# tokens = ['I', 'got', 'ta', 'have', 'it']
token_offsets = [(0, 1), (2, 5), (5, 7), (8, 12), (13, 15)]

# new subtokens = ['[CLS]', '▁I', '▁gotta', '▁have', '▁it', '[SEP]']
# old subtokens = ['[CLS]', '▁I', '▁got', '▁ta', '▁have', '▁it', '[SEP]']
subtoken_offsets = tensor([[0, 0], [0, 1], [1, 7], [7, 12], [12, 15], [0, 0]])

        mapping = map_tokens_to_subtokens(subtoken_offsets=subtoken_offsets, token_offsets=token_offsets)

        assert [None, 0, 1, 3, 4, None] == mapping
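
For reference, the subtoken offsets used as fixtures above mirror the offset mappings that a Hugging Face fast tokenizer returns; a minimal sketch of producing them (the checkpoint name is an assumption, any fast tokenizer works):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
encoding = tokenizer("So don't be afraid", return_offsets_mapping=True)

# one (start, end) character span per subtoken; special tokens such as <s> and </s> come out as (0, 0)
print(encoding["offset_mapping"])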