|
16 | 16 | import tensorflow_text as tf_text |
17 | 17 |
|
18 | 18 | try: |
19 | | - import keras_nlp |
20 | 19 | from keras_nlp.tokenizers import BytePairTokenizer |
21 | 20 | except ImportError: |
22 | | - keras_nlp = None |
| 21 | + BytePairTokenizer = None |
23 | 22 |
|
24 | 23 | # As Python and TF handle special spaces differently, we need to |
25 | 24 | # manually handle special spaces during string split. |
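The hunk above narrows the fallback to the one symbol the module actually uses: if `keras_nlp` is not installed, `BytePairTokenizer` itself becomes `None`. Together with the class-level guard in the next hunk, the module ends up with a standard optional-dependency layout; a minimal standalone sketch of that layout, using the same names as the diff:

```python
# Optional-dependency pattern: import the base class if available,
# otherwise expose a None sentinel that downstream code can check.
try:
    from keras_nlp.tokenizers import BytePairTokenizer
except ImportError:
    BytePairTokenizer = None

if BytePairTokenizer:
    class CLIPTokenizer(BytePairTokenizer):
        """Only defined when keras-nlp is importable."""
else:
    CLIPTokenizer = None
```

With this shape, importing the module never fails; only code that actually touches `CLIPTokenizer` has to care whether keras-nlp is present.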
@@ -104,83 +103,82 @@ def remove_strings_from_inputs(tensor, string_to_remove): |
104 | 103 | return result |
105 | 104 |
|
106 | 105 |
|
107 | | -class CLIPTokenizer(BytePairTokenizer): |
108 | | - def __init__(self, **kwargs): |
109 | | - super().__init__(**kwargs) |
110 | | - if keras_nlp is None: |
111 | | - raise ValueError( |
112 | | - "ClipTokenizer requires keras-nlp. Please install " |
113 | | - "using pip `pip install -U keras-nlp && pip install -U keras`" |
| 106 | +if BytePairTokenizer: |
| 107 | + class CLIPTokenizer(BytePairTokenizer): |
| 108 | + def __init__(self, **kwargs): |
| 109 | + super().__init__(**kwargs) |
| 110 | + |
| 111 | + def _bpe_merge_and_update_cache(self, tokens): |
| 112 | + """Process unseen tokens and add to cache.""" |
| 113 | + words = self._transform_bytes(tokens) |
| 114 | + tokenized_words = self._bpe_merge(words) |
| 115 | + |
| 116 | + # For each word, join all its tokens with a whitespace, |
| 117 | + # e.g., ["dragon", "fly"] => "dragon fly" for hashing purposes. |
| 118 | + tokenized_words = tf.strings.reduce_join( |
| 119 | + tokenized_words, |
| 120 | + axis=1, |
| 121 | + ) |
| 122 | + self.cache.insert(tokens, tokenized_words) |
| 123 | + |
| 124 | + def tokenize(self, inputs): |
| 125 | + self._check_vocabulary() |
| 126 | + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): |
| 127 | + inputs = tf.convert_to_tensor(inputs) |
| 128 | + |
| 129 | + if self.add_prefix_space: |
| 130 | + inputs = tf.strings.join([" ", inputs]) |
| 131 | + |
| 132 | + scalar_input = inputs.shape.rank == 0 |
| 133 | + if scalar_input: |
| 134 | + inputs = tf.expand_dims(inputs, 0) |
| 135 | + |
| 136 | + raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) |
| 137 | + token_row_splits = raw_tokens.row_splits |
| 138 | + flat_tokens = raw_tokens.flat_values |
| 139 | + # Check cache. |
| 140 | + cache_lookup = self.cache.lookup(flat_tokens) |
| 141 | + cache_mask = cache_lookup == "" |
| 142 | + |
| 143 | + has_unseen_words = tf.math.reduce_any( |
| 144 | + (cache_lookup == "") & (flat_tokens != "") |
| 145 | + ) |
| 146 | + |
| 147 | + def process_unseen_tokens(): |
| 148 | + unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) |
| 149 | + self._bpe_merge_and_update_cache(unseen_tokens) |
| 150 | + return self.cache.lookup(flat_tokens) |
| 151 | + |
| 152 | + # If `has_unseen_words == True`, not all tokens are in the cache, |
| 153 | + # so process the unseen tokens. Otherwise, return the cache lookup. |
| 154 | + tokenized_words = tf.cond( |
| 155 | + has_unseen_words, |
| 156 | + process_unseen_tokens, |
| 157 | + lambda: cache_lookup, |
| 158 | + ) |
| 159 | + tokens = tf.strings.split(tokenized_words, sep=" ") |
| 160 | + if self.compute_dtype != tf.string: |
| 161 | + # Encode merged tokens. |
| 162 | + tokens = self.token_to_id_map.lookup(tokens) |
| 163 | + |
| 164 | + # Unflatten to match input. |
| 165 | + tokens = tf.RaggedTensor.from_row_splits( |
| 166 | + tokens.flat_values, |
| 167 | + tf.gather(tokens.row_splits, token_row_splits), |
114 | 168 | ) |
115 | 169 |
|
116 | | - def _bpe_merge_and_update_cache(self, tokens): |
117 | | - """Process unseen tokens and add to cache.""" |
118 | | - words = self._transform_bytes(tokens) |
119 | | - tokenized_words = self._bpe_merge(words) |
120 | | - |
121 | | - # For each word, join all its token by a whitespace, |
122 | | - # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose. |
123 | | - tokenized_words = tf.strings.reduce_join( |
124 | | - tokenized_words, |
125 | | - axis=1, |
126 | | - ) |
127 | | - self.cache.insert(tokens, tokenized_words) |
128 | | - |
129 | | - def tokenize(self, inputs): |
130 | | - self._check_vocabulary() |
131 | | - if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): |
132 | | - inputs = tf.convert_to_tensor(inputs) |
133 | | - |
134 | | - if self.add_prefix_space: |
135 | | - inputs = tf.strings.join([" ", inputs]) |
136 | | - |
137 | | - scalar_input = inputs.shape.rank == 0 |
138 | | - if scalar_input: |
139 | | - inputs = tf.expand_dims(inputs, 0) |
140 | | - |
141 | | - raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) |
142 | | - token_row_splits = raw_tokens.row_splits |
143 | | - flat_tokens = raw_tokens.flat_values |
144 | | - # Check cache. |
145 | | - cache_lookup = self.cache.lookup(flat_tokens) |
146 | | - cache_mask = cache_lookup == "" |
147 | | - |
148 | | - has_unseen_words = tf.math.reduce_any( |
149 | | - (cache_lookup == "") & (flat_tokens != "") |
150 | | - ) |
151 | | - |
152 | | - def process_unseen_tokens(): |
153 | | - unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) |
154 | | - self._bpe_merge_and_update_cache(unseen_tokens) |
155 | | - return self.cache.lookup(flat_tokens) |
156 | | - |
157 | | - # If `has_unseen_words == True`, it means not all tokens are in cache, |
158 | | - # we will process the unseen tokens. Otherwise return the cache lookup. |
159 | | - tokenized_words = tf.cond( |
160 | | - has_unseen_words, |
161 | | - process_unseen_tokens, |
162 | | - lambda: cache_lookup, |
163 | | - ) |
164 | | - tokens = tf.strings.split(tokenized_words, sep=" ") |
165 | | - if self.compute_dtype != tf.string: |
166 | | - # Encode merged tokens. |
167 | | - tokens = self.token_to_id_map.lookup(tokens) |
168 | | - |
169 | | - # Unflatten to match input. |
170 | | - tokens = tf.RaggedTensor.from_row_splits( |
171 | | - tokens.flat_values, |
172 | | - tf.gather(tokens.row_splits, token_row_splits), |
173 | | - ) |
174 | | - |
175 | | - # Convert to a dense output if `sequence_length` is set. |
176 | | - if self.sequence_length: |
177 | | - output_shape = tokens.shape.as_list() |
178 | | - output_shape[-1] = self.sequence_length |
179 | | - tokens = tokens.to_tensor(shape=output_shape) |
180 | | - |
181 | | - # Convert to a dense output if input in scalar |
182 | | - if scalar_input: |
183 | | - tokens = tf.squeeze(tokens, 0) |
184 | | - tf.ensure_shape(tokens, shape=[self.sequence_length]) |
185 | | - |
186 | | - return tokens |
| 170 | + # Convert to a dense output if `sequence_length` is set. |
| 171 | + if self.sequence_length: |
| 172 | + output_shape = tokens.shape.as_list() |
| 173 | + output_shape[-1] = self.sequence_length |
| 174 | + tokens = tokens.to_tensor(shape=output_shape) |
| 175 | + |
| 176 | + # Convert to a dense output if the input is scalar. |
| 177 | + if scalar_input: |
| 178 | + tokens = tf.squeeze(tokens, 0) |
| 179 | + tf.ensure_shape(tokens, shape=[self.sequence_length]) |
| 180 | + |
| 181 | + return tokens |
| 182 | + |
| 183 | +else: |
| 184 | + CLIPTokenizer = None |
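Since the class no longer raises inside `__init__` when keras-nlp is missing, callers only see the `None` sentinel. A hedged sketch of the check a caller might perform (the helper name is illustrative, not from this PR; the error message mirrors the one removed above):

```python
def build_clip_tokenizer(tokenizer_cls, **kwargs):
    """Instantiate the tokenizer, failing loudly if keras-nlp is absent.

    `tokenizer_cls` is the (possibly None) CLIPTokenizer symbol exported by
    the module in this diff; kwargs are forwarded to BytePairTokenizer.
    """
    if tokenizer_cls is None:
        raise ImportError(
            "CLIPTokenizer requires keras-nlp. Please install it using "
            "`pip install -U keras-nlp && pip install -U keras`."
        )
    return tokenizer_cls(**kwargs)

# Usage: tok = build_clip_tokenizer(CLIPTokenizer, vocabulary=..., merges=...)
```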