seqio/vocabularies.py (86 changes: 65 additions & 21 deletions)
@@ -99,7 +99,7 @@ def encode(self, s: Union[Sequence[int], str]) -> Sequence[int]:
   def _decode(self, ids):
     raise NotImplementedError
 
-  def decode(self, ids: Iterable[int]):
+  def decode(self, ids: Iterable[int]) -> str:
     """Detokenizes int32 iterable to a string, up through first EOS."""
     # A `tf.Tensor` is `Iterable` so it's valid to pass into this function.
     # However, iterating over a 1D EagerTensor will create a scalar EagerTensor
@@ -109,16 +109,46 @@ def decode(self, ids: Iterable[int]):
       ids: tf.Tensor = ids
       ids = ids.numpy().tolist()
 
-    clean_ids = list(ids)
-
-    if self.unk_id is not None:
-      vocab_size = self._base_vocab_size
-      clean_ids = [self.unk_id if i >= vocab_size else i for i in clean_ids]
-
-    if self.eos_id is not None and self.eos_id in clean_ids:
-      clean_ids = clean_ids[: clean_ids.index(self.eos_id) + 1]
+    unk_id = self.unk_id
+    eos_id = self.eos_id
+    vocab_size = self._base_vocab_size if unk_id is not None else None
+
+    clean_ids: list[int] = []
+    # is_int: bool | None = None
+    if vocab_size is not None:
+      for i in ids:
+        # if is_int is None:
+        #   is_int = isinstance(i, int)
+        # if is_int:
+        #   i = int(i)
+        i = int(i)
+        if i >= vocab_size:
+          i = unk_id
+        clean_ids.append(i)
+        if i == eos_id:
+          break
+    else:
+      for i in ids:
+        # if is_int is None:
+        #   is_int = isinstance(i, int)
+        # if is_int:
+        #   i = int(i)
+        i = int(i)
+        clean_ids.append(i)
+        if i == eos_id:
+          break
+    # clean_ids = []
+    # for i in ids:
+    #   if vocab_size is not None and i >= vocab_size:
+    #     i = unk_id
+    #   clean_ids.append(i)
+    #   if i == eos_id:
+    #     break
 
-    return self._decode(clean_ids)
+    try:
+      return self._decode(clean_ids)
+    except TypeError as e:
+      raise TypeError(f"{type(clean_ids[0])}") from e
 
   @abc.abstractmethod
   def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
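A condensed sketch of what the new single-pass loop in `decode` computes, with the commented-out scaffolding omitted; the helper name and the toy values below are illustrative only, not part of the PR:

    from typing import Iterable, List, Optional

    def clean_for_decode(ids: Iterable[int],
                         vocab_size: Optional[int],
                         unk_id: int,
                         eos_id: Optional[int]) -> List[int]:
      """Maps out-of-vocabulary ids to UNK and stops after the first EOS."""
      clean_ids: List[int] = []
      for i in ids:
        i = int(i)  # normalize numpy integer scalars to plain Python ints
        if vocab_size is not None and i >= vocab_size:
          i = unk_id
        clean_ids.append(i)
        if i == eos_id:
          break  # EOS itself is kept, matching the old slice-based truncation
      return clean_ids

    # Toy values: vocab_size=10, unk_id=2, eos_id=1.
    assert clean_for_decode([5, 12, 1, 7], vocab_size=10, unk_id=2, eos_id=1) == [5, 2, 1]

Unlike the old list-comprehension-plus-`index` version, this stops consuming `ids` at the first EOS rather than rewriting and then rescanning the whole sequence.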
@@ -415,6 +445,11 @@ def __init__(
     self._normalizer_spec_overrides = normalizer_spec_overrides
     self._reverse_extra_ids = reverse_extra_ids
     self._model: Optional[_ModelContext] = None
+    self._cached_unk_id: Optional[int] = None
+    self._cached_bos_id: Optional[int] = None
+    self._cached_eos_id: Optional[int] = None
+    self._cached_pad_id: Optional[int] = None
+    self._cached_piece_size: Optional[int] = None
     self._use_fast_tokenizer = use_fast_tokenizer
 
     super().__init__(extra_ids=extra_ids)
@@ -458,19 +493,24 @@ def _model_context(
         normalizer_spec_overrides_serialized,
         self._reverse_extra_ids,
     )
+    self._cached_unk_id = self._model.tokenizer.unk_id()
+    self._cached_bos_id = self._model.tokenizer.bos_id()
+    self._cached_eos_id = self._model.tokenizer.eos_id()
+    self._cached_pad_id = self._model.tokenizer.pad_id()
+    self._cached_piece_size = self._model.tokenizer.GetPieceSize()
     return self._model
 
   @property
   def bos_id(self) -> Optional[int]:
-    return self.tokenizer.bos_id()
+    return self._cached_bos_id if self._model else self.tokenizer.bos_id()
 
   @property
   def eos_id(self) -> Optional[int]:
-    return self.tokenizer.eos_id()
+    return self._cached_eos_id if self._model else self.tokenizer.eos_id()
 
   @property
   def unk_id(self) -> Optional[int]:
-    return self.tokenizer.unk_id()
+    return self._cached_unk_id if self._model else self.tokenizer.unk_id()
 
   @property
   def sp_model(self) -> Optional[bytes]:
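The `_cached_*` fields are filled in the one place the SentencePiece model actually gets built, so hot properties like `eos_id` and `unk_id` (read on every `decode` call) no longer have to go through `self.tokenizer` once the model is resident. A minimal standalone sketch of that pattern; the class names and values below are illustrative, not seqio APIs:

    from typing import Optional

    class _ToyTokenizer:
      """Stand-in for the SentencePiece processor; the id is made up."""

      def eos_id(self) -> int:
        return 1

    class _CachedIdsSketch:
      """Toy version of the special-id caching added in this PR."""

      def __init__(self):
        self._model: Optional[_ToyTokenizer] = None
        self._cached_eos_id: Optional[int] = None

      @property
      def tokenizer(self) -> _ToyTokenizer:
        # Stands in for _model_context(): build once, then record the special
        # ids so later property accesses can skip the tokenizer call.
        if self._model is None:
          self._model = _ToyTokenizer()
          self._cached_eos_id = self._model.eos_id()
        return self._model

      @property
      def eos_id(self) -> Optional[int]:
        # Same shape as the PR: cached value when the model already exists,
        # otherwise fall back to the tokenizer (which also fills the cache).
        return self._cached_eos_id if self._model else self.tokenizer.eos_id()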
@@ -495,7 +535,11 @@ def tf_tokenizer(self):
 
   @property
   def vocab_size(self):
-    return self._base_vocab_size
+    return (
+        self._cached_piece_size
+        if self._model
+        else self.tokenizer.GetPieceSize()
+    )
 
   @property
   def _base_vocab_size(self):
@@ -504,7 +548,11 @@ def _base_vocab_size(self):
     Returns:
       an integer, the vocabulary size
     """
-    return self.tokenizer.GetPieceSize()
+    return (
+        self._cached_piece_size
+        if self._model
+        else self.tokenizer.GetPieceSize()
+    )
 
   def _encode(self, s: str) -> Sequence[int]:
     """Encode a python string as a list of integers.
@@ -517,7 +565,7 @@ def _encode(self, s: str) -> Sequence[int]:
     """
     return self.tokenizer.EncodeAsIds(s)
 
-  def _decode(self, ids):
+  def _decode(self, ids: Sequence[int]) -> str:
     """Decode a list of integers to a python string.
 
     Args:
@@ -526,11 +574,7 @@ def _decode(self, ids):
     Returns:
       a string
     """
-    # convert all the extra ids (sentinels) to UNK=2
-    unk_id = self.tokenizer.unk_id()
-    piece_size = self.tokenizer.GetPieceSize()
-    ids = [unk_id if i >= piece_size else int(i) for i in ids]
-    return self.tokenizer.DecodeIds(ids)
+    return self.tokenizer.DecodeIds(list(ids))
 
   def _encode_tf(self, s):
     """Encode a tf.Scalar string to a tf.Tensor.