diff --git a/seqio/vocabularies.py b/seqio/vocabularies.py
index 596f7f6d..ab5a0eb7 100644
--- a/seqio/vocabularies.py
+++ b/seqio/vocabularies.py
@@ -99,7 +99,7 @@ def encode(self, s: Union[Sequence[int], str]) -> Sequence[int]:
   def _decode(self, ids):
     raise NotImplementedError
 
-  def decode(self, ids: Iterable[int]):
+  def decode(self, ids: Iterable[int]) -> str:
     """Detokenizes int32 iterable to a string, up through first EOS."""
     # A `tf.Tensor` is `Iterable` so it's valid to pass into this function.
     # However, iterating over a 1D EagerTensor will create a scalar EagerTensor
@@ -109,16 +109,46 @@ def decode(self, ids: Iterable[int]):
       ids: tf.Tensor = ids
       ids = ids.numpy().tolist()
 
-    clean_ids = list(ids)
-
-    if self.unk_id is not None:
-      vocab_size = self._base_vocab_size
-      clean_ids = [self.unk_id if i >= vocab_size else i for i in clean_ids]
-
-    if self.eos_id is not None and self.eos_id in clean_ids:
-      clean_ids = clean_ids[: clean_ids.index(self.eos_id) + 1]
+    unk_id = self.unk_id
+    eos_id = self.eos_id
+    vocab_size = self._base_vocab_size if unk_id is not None else None
+
+    clean_ids: list[int] = []
+    # is_int: bool | None = None
+    if vocab_size is not None:
+      for i in ids:
+        # if is_int is None:
+        #   is_int = isinstance(i, int)
+        # if is_int:
+        #   i = int(i)
+        i = int(i)
+        if i >= vocab_size:
+          i = unk_id
+        clean_ids.append(i)
+        if i == eos_id:
+          break
+    else:
+      for i in ids:
+        # if is_int is None:
+        #   is_int = isinstance(i, int)
+        # if is_int:
+        #   i = int(i)
+        i = int(i)
+        clean_ids.append(i)
+        if i == eos_id:
+          break
+    # clean_ids = []
+    # for i in ids:
+    #   if vocab_size is not None and i >= vocab_size:
+    #     i = unk_id
+    #   clean_ids.append(i)
+    #   if i == eos_id:
+    #     break
 
-    return self._decode(clean_ids)
+    try:
+      return self._decode(clean_ids)
+    except TypeError as e:
+      raise TypeError(f"{type(clean_ids[0])}") from e
 
   @abc.abstractmethod
   def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
@@ -415,6 +445,11 @@ def __init__(
     self._normalizer_spec_overrides = normalizer_spec_overrides
     self._reverse_extra_ids = reverse_extra_ids
     self._model: Optional[_ModelContext] = None
+    self._cached_unk_id: Optional[int] = None
+    self._cached_bos_id: Optional[int] = None
+    self._cached_eos_id: Optional[int] = None
+    self._cached_pad_id: Optional[int] = None
+    self._cached_piece_size: Optional[int] = None
     self._use_fast_tokenizer = use_fast_tokenizer
 
     super().__init__(extra_ids=extra_ids)
@@ -458,19 +493,24 @@ def _model_context(
         normalizer_spec_overrides_serialized,
         self._reverse_extra_ids,
     )
+    self._cached_unk_id = self._model.tokenizer.unk_id()
+    self._cached_bos_id = self._model.tokenizer.bos_id()
+    self._cached_eos_id = self._model.tokenizer.eos_id()
+    self._cached_pad_id = self._model.tokenizer.pad_id()
+    self._cached_piece_size = self._model.tokenizer.GetPieceSize()
     return self._model
 
   @property
   def bos_id(self) -> Optional[int]:
-    return self.tokenizer.bos_id()
+    return self._cached_bos_id if self._model else self.tokenizer.bos_id()
 
   @property
   def eos_id(self) -> Optional[int]:
-    return self.tokenizer.eos_id()
+    return self._cached_eos_id if self._model else self.tokenizer.eos_id()
 
   @property
   def unk_id(self) -> Optional[int]:
-    return self.tokenizer.unk_id()
+    return self._cached_unk_id if self._model else self.tokenizer.unk_id()
 
   @property
   def sp_model(self) -> Optional[bytes]:
@@ -495,7 +535,11 @@ def tf_tokenizer(self):
 
   @property
   def vocab_size(self):
-    return self._base_vocab_size
+    return (
+        self._cached_piece_size
+        if self._model
+        else self.tokenizer.GetPieceSize()
+    )
 
   @property
   def _base_vocab_size(self):
@@ -504,7 +548,11 @@ def _base_vocab_size(self):
     Returns:
       an integer, the vocabulary size
     """
-    return self.tokenizer.GetPieceSize()
+    return (
+        self._cached_piece_size
+        if self._model
+        else self.tokenizer.GetPieceSize()
+    )
 
   def _encode(self, s: str) -> Sequence[int]:
     """Encode a python string as a list of integers.
@@ -517,7 +565,7 @@ def _encode(self, s: str) -> Sequence[int]:
     """
     return self.tokenizer.EncodeAsIds(s)
 
-  def _decode(self, ids):
+  def _decode(self, ids: Sequence[int]) -> str:
     """Decode a list of integers to a python string.
 
     Args:
@@ -526,11 +574,7 @@ def _decode(self, ids):
     Returns:
       a string
     """
-    # convert all the extra ids (sentinels) to UNK=2
-    unk_id = self.tokenizer.unk_id()
-    piece_size = self.tokenizer.GetPieceSize()
-    ids = [unk_id if i >= piece_size else int(i) for i in ids]
-    return self.tokenizer.DecodeIds(ids)
+    return self.tokenizer.DecodeIds(list(ids))
 
   def _encode_tf(self, s):
     """Encode a tf.Scalar string to a tf.Tensor.
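Note on the decode() change (not part of the diff): the old implementation cleaned the ids in three passes (a list copy, an UNK remapping via a list comprehension, and an EOS search plus slice), while the new code does the same work in a single pass that also stops iterating at the first EOS. Below is a minimal standalone sketch of that single-pass behavior; the VOCAB_SIZE, UNK_ID, and EOS_ID constants are illustrative assumptions, not values read from any real SentencePiece model.

from typing import Iterable, List

VOCAB_SIZE = 32000  # assumed base vocabulary size
UNK_ID = 2          # assumed unknown-token id
EOS_ID = 1          # assumed end-of-sequence id


def clean_ids_single_pass(ids: Iterable[int]) -> List[int]:
  """Maps out-of-vocabulary ids to UNK and truncates after the first EOS."""
  clean: List[int] = []
  for i in ids:
    i = int(i)  # coerce NumPy scalars to plain Python ints, as the diff does
    if i >= VOCAB_SIZE:
      i = UNK_ID
    clean.append(i)
    if i == EOS_ID:
      break
  return clean


# 40000 is out of range and becomes UNK; everything after the first EOS is dropped.
assert clean_ids_single_pass([5, 40000, 1, 7, 8]) == [5, 2, 1]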