Commit 25580a9

Phi3V Python API and tests
Signed-off-by: Prabod Rathnayaka <[email protected]>
1 parent f2e3480 commit 25580a9

File tree: 4 files changed, +419 −1 lines changed

python/sparknlp/annotator/cv/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -16,4 +16,4 @@
 from sparknlp.annotator.cv.convnext_for_image_classification import *
 from sparknlp.annotator.cv.vision_encoder_decoder_for_image_captioning import *
 from sparknlp.annotator.cv.clip_for_zero_shot_classification import *
-from sparknlp.annotator.cv.blip_for_question_answering import *
+from sparknlp.annotator.cv.phi3_vision_for_multimodal import *
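
With the re-export in place, the new annotator can be imported like the other cv annotators. A minimal sketch (assuming a running Spark NLP session and that the wildcard import propagates through sparknlp.annotator as for the other vision classes):

from sparknlp.annotator import Phi3Vision

# "answer" will carry DOCUMENT annotations produced from IMAGE input
phi3 = Phi3Vision.pretrained() \
    .setInputCols(["image_assembler"]) \
    .setOutputCol("answer")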
python/sparknlp/annotator/cv/phi3_vision_for_multimodal.py

Lines changed: 329 additions & 0 deletions
@@ -0,0 +1,329 @@
# Copyright 2017-2024 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sparknlp.common import *

class Phi3Vision(AnnotatorModel,
                 HasBatchedAnnotateImage,
                 HasImageFeatureProperties,
                 HasEngine,
                 HasCandidateLabelsProperties,
                 HasRescaleFactor):
"""BLIPForQuestionAnswering can load BLIP models for visual question answering.
24+
The model consists of a vision encoder, a text encoder as well as a text decoder.
25+
The vision encoder will encode the input image, the text encoder will encode the input question together
26+
with the encoding of the image, and the text decoder will output the answer to the question.
27+
28+
Pretrained models can be loaded with :meth:`.pretrained` of the companion
29+
object:
30+
31+
>>> visualQAClassifier = BLIPForQuestionAnswering.pretrained() \\
32+
... .setInputCols(["image_assembler"]) \\
33+
... .setOutputCol("answer")
34+
35+
The default model is ``"blip_vqa_base"``, if no name is
36+
provided.
37+
38+
For available pretrained models please see the `Models Hub
39+
<https://sparknlp.org/models?task=Question+Answering>`__.
40+
41+
To see which models are compatible and how to import them see
42+
`Import Transformers into Spark NLP 🚀
43+
<https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.
44+
45+
====================== ======================
46+
Input Annotation types Output Annotation type
47+
====================== ======================
48+
``IMAGE`` ``DOCUMENT``
49+
====================== ======================
50+
51+
Parameters
52+
----------
53+
batchSize
54+
Batch size. Large values allows faster processing but requires more
55+
memory, by default 2
56+
configProtoBytes
57+
ConfigProto from tensorflow, serialized into byte array.
58+
maxSentenceLength
59+
Max sentence length to process, by default 50
60+
61+
Examples
62+
--------
63+
>>> import sparknlp
64+
>>> from sparknlp.base import *
65+
>>> from sparknlp.annotator import *
66+
>>> from pyspark.ml import Pipeline
67+
>>> image_df = SparkSessionForTest.spark.read.format("image").load(path=images_path)
68+
>>> test_df = image_df.withColumn("text", lit("What's this picture about?"))
69+
>>> imageAssembler = ImageAssembler() \\
70+
... .setInputCol("image") \\
71+
... .setOutputCol("image_assembler")
72+
>>> visualQAClassifier = BLIPForQuestionAnswering.pretrained() \\
73+
... .setInputCols("image_assembler") \\
74+
... .setOutputCol("answer") \\
75+
... .setSize(384)
76+
>>> pipeline = Pipeline().setStages([
77+
... imageAssembler,
78+
... visualQAClassifier
79+
... ])
80+
>>> result = pipeline.fit(test_df).transform(test_df)
81+
>>> result.select("image_assembler.origin", "answer.result").show(false)
82+
+--------------------------------------+------+
83+
|origin |result|
84+
+--------------------------------------+------+
85+
|[file:///content/images/cat_image.jpg]|[cats]|
86+
+--------------------------------------+------+
87+
"""
88+
89+
name = "Phi3Vision"
90+
91+
inputAnnotatorTypes = [AnnotatorType.IMAGE]
92+
93+
outputAnnotatorType = AnnotatorType.DOCUMENT
94+
95+
configProtoBytes = Param(Params._dummy(),
96+
"configProtoBytes",
97+
"ConfigProto from tensorflow, serialized into byte array. Get with "
98+
"config_proto.SerializeToString()",
99+
TypeConverters.toListInt)
100+
101+
minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated",
102+
typeConverter=TypeConverters.toInt)
103+
104+
maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text",
105+
typeConverter=TypeConverters.toInt)
106+
107+
doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise",
108+
typeConverter=TypeConverters.toBoolean)
109+
110+
temperature = Param(Params._dummy(), "temperature", "The value used to module the next token probabilities",
111+
typeConverter=TypeConverters.toFloat)
112+
113+
topK = Param(Params._dummy(), "topK",
114+
"The number of highest probability vocabulary tokens to keep for top-k-filtering",
115+
typeConverter=TypeConverters.toInt)
116+
117+
topP = Param(Params._dummy(), "topP",
118+
"If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation",
119+
typeConverter=TypeConverters.toFloat)
120+
121+
repetitionPenalty = Param(Params._dummy(), "repetitionPenalty",
122+
"The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details",
123+
typeConverter=TypeConverters.toFloat)
124+
125+
noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize",
126+
"If set to int > 0, all ngrams of that size can only occur once",
127+
typeConverter=TypeConverters.toInt)
128+
129+
ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds",
130+
"A list of token ids which are ignored in the decoder's output",
131+
typeConverter=TypeConverters.toListInt)
132+
beamSize = Param(Params._dummy(), "beamSize",
133+
"The Number of beams for beam search.",
134+
typeConverter=TypeConverters.toInt)
135+
    def setMaxSentenceSize(self, value):
        """Sets maximum sentence length that the annotator will process, by
        default 50.

        Parameters
        ----------
        value : int
            Maximum sentence length that the annotator will process
        """
        return self._set(maxSentenceLength=value)

    def setIgnoreTokenIds(self, value):
        """A list of token ids which are ignored in the decoder's output.

        Parameters
        ----------
        value : List[int]
            The words to be filtered out
        """
        return self._set(ignoreTokenIds=value)

    def setConfigProtoBytes(self, b):
        """Sets configProto from tensorflow, serialized into byte array.

        Parameters
        ----------
        b : List[int]
            ConfigProto from tensorflow, serialized into byte array
        """
        return self._set(configProtoBytes=b)

    def setMinOutputLength(self, value):
        """Sets minimum length of the sequence to be generated.

        Parameters
        ----------
        value : int
            Minimum length of the sequence to be generated
        """
        return self._set(minOutputLength=value)

    def setMaxOutputLength(self, value):
        """Sets maximum length of output text.

        Parameters
        ----------
        value : int
            Maximum length of output text
        """
        return self._set(maxOutputLength=value)

    def setDoSample(self, value):
        """Sets whether or not to use sampling; use greedy decoding otherwise.

        Parameters
        ----------
        value : bool
            Whether or not to use sampling; use greedy decoding otherwise
        """
        return self._set(doSample=value)

    def setTemperature(self, value):
        """Sets the value used to modulate the next token probabilities.

        Parameters
        ----------
        value : float
            The value used to modulate the next token probabilities
        """
        return self._set(temperature=value)

    def setTopK(self, value):
        """Sets the number of highest probability vocabulary tokens to keep for
        top-k-filtering.

        Parameters
        ----------
        value : int
            Number of highest probability vocabulary tokens to keep
        """
        return self._set(topK=value)

    def setTopP(self, value):
        """Sets the top cumulative probability for vocabulary tokens.

        If set to float < 1, only the most probable tokens with probabilities
        that add up to ``topP`` or higher are kept for generation.

        Parameters
        ----------
        value : float
            Cumulative probability for vocabulary tokens
        """
        return self._set(topP=value)

    def setRepetitionPenalty(self, value):
        """Sets the parameter for repetition penalty. 1.0 means no penalty.

        Parameters
        ----------
        value : float
            The repetition penalty

        References
        ----------
        See `Ctrl: A Conditional Transformer Language Model For Controllable
        Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
        """
        return self._set(repetitionPenalty=value)

    def setNoRepeatNgramSize(self, value):
        """Sets size of n-grams that can only occur once.

        If set to int > 0, all n-grams of that size can only occur once.

        Parameters
        ----------
        value : int
            N-gram size that can only occur once
        """
        return self._set(noRepeatNgramSize=value)

    def setBeamSize(self, value):
        """Sets the number of beams for beam search, by default 1.

        Parameters
        ----------
        value : int
            Number of beams for beam search
        """
        return self._set(beamSize=value)

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.cv.Phi3Vision",
                 java_model=None):
        super(Phi3Vision, self).__init__(
            classname=classname,
            java_model=java_model
        )
        self._setDefault(
            batchSize=2,
            minOutputLength=0,
            maxOutputLength=200,
            maxSentenceLength=50,  # matches the documented default
            doSample=False,
            temperature=1.0,
            topK=50,
            topP=1.0,
            repetitionPenalty=1.0,
            noRepeatNgramSize=0,
            ignoreTokenIds=[],
            beamSize=1,
        )

    @staticmethod
    def loadSavedModel(folder, spark_session, use_openvino=False):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession
        use_openvino : bool, optional
            Whether to use the OpenVINO engine, by default False

        Returns
        -------
        Phi3Vision
            The restored model
        """
        from sparknlp.internal import _Phi3VisionLoader
        jModel = _Phi3VisionLoader(folder, spark_session._jsparkSession, use_openvino)._java_obj
        return Phi3Vision(java_model=jModel)

    @staticmethod
    def pretrained(name="phi3v", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default "phi3v"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLP's repositories otherwise.

        Returns
        -------
        Phi3Vision
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(Phi3Vision, name, lang, remote_loc)
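
Since Phi3Vision exposes the standard Spark NLP text-generation Params, the greedy-decoding defaults set in __init__ (doSample=False, beamSize=1) can be overridden through the setters. A sketch of an end-to-end run (the image directory is a placeholder; the prompt pattern follows the docstring example):

import sparknlp
from sparknlp.base import ImageAssembler
from sparknlp.annotator import Phi3Vision
from pyspark.ml import Pipeline
from pyspark.sql.functions import lit

spark = sparknlp.start()

# Load images and attach a prompt column; the path is hypothetical.
image_df = spark.read.format("image").load("file:///content/images")
test_df = image_df.withColumn("text", lit("What's this picture about?"))

image_assembler = ImageAssembler() \
    .setInputCol("image") \
    .setOutputCol("image_assembler")

# Override the greedy-decoding defaults with sampling-based generation.
phi3 = Phi3Vision.pretrained() \
    .setInputCols(["image_assembler"]) \
    .setOutputCol("answer") \
    .setMaxOutputLength(100) \
    .setDoSample(True) \
    .setTemperature(0.7) \
    .setTopK(40) \
    .setRepetitionPenalty(1.1)

pipeline = Pipeline(stages=[image_assembler, phi3])
pipeline.fit(test_df).transform(test_df) \
    .select("image_assembler.origin", "answer.result") \
    .show(truncate=False)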

python/sparknlp/internal/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -363,6 +363,15 @@ def __init__(self, path, jspark, use_openvino=False):
             use_openvino,
         )
 
+class _Phi3VisionLoader(ExtendedJavaWrapper):
+    def __init__(self, path, jspark, use_openvino=False):
+        super(_Phi3VisionLoader, self).__init__(
+            "com.johnsnowlabs.nlp.annotators.cv.Phi3Vision.loadSavedModel",
+            path,
+            jspark,
+            use_openvino
+        )
+
 class _RoBertaLoader(ExtendedJavaWrapper):
     def __init__(self, path, jspark, use_openvino=False):
         super(_RoBertaLoader, self).__init__(
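
_Phi3VisionLoader simply forwards to the Scala companion's Phi3Vision.loadSavedModel, and the Python Phi3Vision.loadSavedModel shown above wraps the returned Java model. A sketch of the load path for a locally exported model (the export directory is a placeholder, assumed to contain a model exported with the usual Spark NLP conventions):

import sparknlp
from sparknlp.annotator import Phi3Vision

spark = sparknlp.start()

# Placeholder directory produced by a prior export step; use_openvino
# selects the OpenVINO engine when such an export is present.
phi3 = Phi3Vision.loadSavedModel("/tmp/phi3v_spark_nlp", spark, use_openvino=False) \
    .setInputCols(["image_assembler"]) \
    .setOutputCol("answer")

# The wrapped annotator can be persisted like any Spark ML stage for reuse.
phi3.write().overwrite().save("/tmp/phi3v_annotator")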
