Commit 8265a17

Use conditional keras_nlp imports
1 parent fb05c82 commit 8265a17

File tree

keras_cv/models/feature_extractor/clip/clip_model.py
keras_cv/models/feature_extractor/clip/clip_processor.py
keras_cv/models/feature_extractor/clip/clip_tokenizer.py
keras_cv/utils/conditional_imports.py

4 files changed, +101 -87 lines changed

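For context on what "conditional keras_nlp imports" means in practice: the pattern applied across the files below imports keras_nlp inside a try/except at module load, and only raises when a CLIP symbol that needs it is actually constructed. A minimal standalone sketch of the idea (illustrative only; the names mirror the diffs below, not the full keras_cv modules):

# Sketch of the conditional-import pattern this commit applies.
try:
    import keras_nlp  # optional dependency
except ImportError:
    keras_nlp = None  # keep the module importable without keras_nlp


def assert_keras_nlp_installed(symbol_name):
    # Fail with an actionable message only when the symbol is actually used.
    if keras_nlp is None:
        raise ImportError(
            f"{symbol_name} requires the `keras_nlp` package. "
            "Please install the package using `pip install keras_nlp`."
        )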

keras_cv/models/feature_extractor/clip/clip_model.py

Lines changed: 2 additions & 5 deletions
@@ -26,6 +26,7 @@
     CLIPTextEncoder,
 )
 from keras_cv.models.task import Task
+from keras_cv.utils.conditional_imports import assert_keras_nlp_installed
 from keras_cv.utils.python_utils import classproperty
 
 try:
@@ -98,11 +99,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        if keras_nlp is None:
-            raise ValueError(
-                "ClipTokenizer requires keras-nlp. Please install "
-                "using pip `pip install -U keras-nlp && pip install -U keras`"
-            )
+        assert_keras_nlp_installed("CLIP")
         self.embed_dim = embed_dim
         self.image_resolution = image_resolution
         self.vision_layers = vision_layers

keras_cv/models/feature_extractor/clip/clip_processor.py

Lines changed: 7 additions & 2 deletions
@@ -11,13 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from keras_nlp.layers import StartEndPacker
 
 from keras_cv.api_export import keras_cv_export
 from keras_cv.backend import keras
 from keras_cv.backend import ops
 from keras_cv.models.feature_extractor.clip.clip_tokenizer import CLIPTokenizer
+from keras_cv.utils.conditional_imports import assert_keras_nlp_installed
 
+try:
+    import keras_nlp
+except ImportError:
+    keras_nlp = None
 
 @keras_cv_export("keras_cv.models.feature_extractor.CLIPProcessor")
 class CLIPProcessor:
@@ -45,6 +49,7 @@ class CLIPProcessor:
     """
 
     def __init__(self, input_resolution, vocabulary, merges, **kwargs):
+        assert_keras_nlp_installed("CLIPProcessor")
         self.input_resolution = input_resolution
         self.vocabulary = vocabulary
         self.merges = merges
@@ -54,7 +59,7 @@ def __init__(self, input_resolution, vocabulary, merges, **kwargs):
             merges=self.merges,
             unsplittable_tokens=["</w>"],
         )
-        self.packer = StartEndPacker(
+        self.packer = keras_nlp.layers.StartEndPacker(
             start_value=self.tokenizer.token_to_id("<|startoftext|>"),
             end_value=self.tokenizer.token_to_id("<|endoftext|>"),
             pad_value=None,

keras_cv/models/feature_extractor/clip/clip_tokenizer.py

Lines changed: 78 additions & 80 deletions
@@ -16,10 +16,9 @@
 import tensorflow_text as tf_text
 
 try:
-    import keras_nlp
     from keras_nlp.tokenizers import BytePairTokenizer
 except ImportError:
-    keras_nlp = None
+    BytePairTokenizer = None
 
 # As python and TF handles special spaces differently, we need to
 # manually handle special spaces during string split.
@@ -104,83 +103,82 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result
 
 
-class CLIPTokenizer(BytePairTokenizer):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        if keras_nlp is None:
-            raise ValueError(
-                "ClipTokenizer requires keras-nlp. Please install "
-                "using pip `pip install -U keras-nlp && pip install -U keras`"
+if BytePairTokenizer:
+    class CLIPTokenizer(BytePairTokenizer):
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+
+        def _bpe_merge_and_update_cache(self, tokens):
+            """Process unseen tokens and add to cache."""
+            words = self._transform_bytes(tokens)
+            tokenized_words = self._bpe_merge(words)
+
+            # For each word, join all its token by a whitespace,
+            # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
+            tokenized_words = tf.strings.reduce_join(
+                tokenized_words,
+                axis=1,
+            )
+            self.cache.insert(tokens, tokenized_words)
+
+        def tokenize(self, inputs):
+            self._check_vocabulary()
+            if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+                inputs = tf.convert_to_tensor(inputs)
+
+            if self.add_prefix_space:
+                inputs = tf.strings.join([" ", inputs])
+
+            scalar_input = inputs.shape.rank == 0
+            if scalar_input:
+                inputs = tf.expand_dims(inputs, 0)
+
+            raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
+            token_row_splits = raw_tokens.row_splits
+            flat_tokens = raw_tokens.flat_values
+            # Check cache.
+            cache_lookup = self.cache.lookup(flat_tokens)
+            cache_mask = cache_lookup == ""
+
+            has_unseen_words = tf.math.reduce_any(
+                (cache_lookup == "") & (flat_tokens != "")
+            )
+
+            def process_unseen_tokens():
+                unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
+                self._bpe_merge_and_update_cache(unseen_tokens)
+                return self.cache.lookup(flat_tokens)
+
+            # If `has_unseen_words == True`, it means not all tokens are in cache,
+            # we will process the unseen tokens. Otherwise return the cache lookup.
+            tokenized_words = tf.cond(
+                has_unseen_words,
+                process_unseen_tokens,
+                lambda: cache_lookup,
+            )
+            tokens = tf.strings.split(tokenized_words, sep=" ")
+            if self.compute_dtype != tf.string:
+                # Encode merged tokens.
+                tokens = self.token_to_id_map.lookup(tokens)
+
+            # Unflatten to match input.
+            tokens = tf.RaggedTensor.from_row_splits(
+                tokens.flat_values,
+                tf.gather(tokens.row_splits, token_row_splits),
             )
 
-    def _bpe_merge_and_update_cache(self, tokens):
-        """Process unseen tokens and add to cache."""
-        words = self._transform_bytes(tokens)
-        tokenized_words = self._bpe_merge(words)
-
-        # For each word, join all its token by a whitespace,
-        # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
-        tokenized_words = tf.strings.reduce_join(
-            tokenized_words,
-            axis=1,
-        )
-        self.cache.insert(tokens, tokenized_words)
-
-    def tokenize(self, inputs):
-        self._check_vocabulary()
-        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
-            inputs = tf.convert_to_tensor(inputs)
-
-        if self.add_prefix_space:
-            inputs = tf.strings.join([" ", inputs])
-
-        scalar_input = inputs.shape.rank == 0
-        if scalar_input:
-            inputs = tf.expand_dims(inputs, 0)
-
-        raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
-        token_row_splits = raw_tokens.row_splits
-        flat_tokens = raw_tokens.flat_values
-        # Check cache.
-        cache_lookup = self.cache.lookup(flat_tokens)
-        cache_mask = cache_lookup == ""
-
-        has_unseen_words = tf.math.reduce_any(
-            (cache_lookup == "") & (flat_tokens != "")
-        )
-
-        def process_unseen_tokens():
-            unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
-            self._bpe_merge_and_update_cache(unseen_tokens)
-            return self.cache.lookup(flat_tokens)
-
-        # If `has_unseen_words == True`, it means not all tokens are in cache,
-        # we will process the unseen tokens. Otherwise return the cache lookup.
-        tokenized_words = tf.cond(
-            has_unseen_words,
-            process_unseen_tokens,
-            lambda: cache_lookup,
-        )
-        tokens = tf.strings.split(tokenized_words, sep=" ")
-        if self.compute_dtype != tf.string:
-            # Encode merged tokens.
-            tokens = self.token_to_id_map.lookup(tokens)
-
-        # Unflatten to match input.
-        tokens = tf.RaggedTensor.from_row_splits(
-            tokens.flat_values,
-            tf.gather(tokens.row_splits, token_row_splits),
-        )
-
-        # Convert to a dense output if `sequence_length` is set.
-        if self.sequence_length:
-            output_shape = tokens.shape.as_list()
-            output_shape[-1] = self.sequence_length
-            tokens = tokens.to_tensor(shape=output_shape)
-
-        # Convert to a dense output if input in scalar
-        if scalar_input:
-            tokens = tf.squeeze(tokens, 0)
-            tf.ensure_shape(tokens, shape=[self.sequence_length])
-
-        return tokens
+            # Convert to a dense output if `sequence_length` is set.
+            if self.sequence_length:
+                output_shape = tokens.shape.as_list()
+                output_shape[-1] = self.sequence_length
+                tokens = tokens.to_tensor(shape=output_shape)
+
+            # Convert to a dense output if input in scalar
+            if scalar_input:
+                tokens = tf.squeeze(tokens, 0)
+                tf.ensure_shape(tokens, shape=[self.sequence_length])
+
+            return tokens
+
+else:
+    CLIPTokenizer = None
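Why the whole class definition moves under `if BytePairTokenizer:` rather than keeping a check inside `__init__` as the other files do: Python evaluates a class statement's bases at import time, so `class CLIPTokenizer(BytePairTokenizer)` with the dependency missing would fail before any runtime check could run. A hypothetical, standalone illustration of that failure mode (not keras_cv code):

OptionalBase = None  # stands in for BytePairTokenizer when keras_nlp is absent

try:
    class BrokenTokenizer(OptionalBase):  # bases are evaluated at class-definition time
        pass
except TypeError as exc:
    print(f"import-time failure: {exc}")  # raised before __init__ could ever check

# Guarding the definition keeps the module importable without the dependency:
if OptionalBase is not None:
    class GuardedTokenizer(OptionalBase):
        pass
else:
    GuardedTokenizer = None  # mirrors `CLIPTokenizer = None` in the diff above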

keras_cv/utils/conditional_imports.py

Lines changed: 14 additions & 0 deletions
@@ -33,6 +33,11 @@
 except ImportError:
     pycocotools = None
 
+try:
+    import keras_nlp
+except ImportError:
+    keras_nlp = None
+
 
 def assert_cv2_installed(symbol_name):
     if cv2 is None:
@@ -70,3 +75,12 @@ def assert_pycocotools_installed(symbol_name):
             "Please install the package using "
             "`pip install pycocotools`."
         )
+
+
+def assert_keras_nlp_installed(symbol_name):
+    if keras_nlp is None:
+        raise ImportError(
+            f"{symbol_name} requires the `keras_nlp` package. "
+            "Please install the package using "
+            "`pip install keras_nlp`."
+        )
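For the end-user effect: with keras_nlp missing, the guarded CLIP symbols now raise this uniform ImportError when they are constructed, instead of keras_cv failing at import time. A hypothetical check (assumes keras_nlp is not installed in the environment):

from keras_cv.utils.conditional_imports import assert_keras_nlp_installed

try:
    assert_keras_nlp_installed("CLIP")  # same call clip_model.py now makes in __init__
except ImportError as exc:
    # Expected message, per the diff above:
    # "CLIP requires the `keras_nlp` package. Please install the package
    # using `pip install keras_nlp`."
    print(exc)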
