1+ # Copyright 2017-2024 John Snow Labs
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+ from sparknlp .common import *
16+
class Phi3Vision(AnnotatorModel,
                 HasBatchedAnnotateImage,
                 HasImageFeatureProperties,
                 HasEngine,
                 HasCandidateLabelsProperties,
                 HasRescaleFactor):
    """Phi3Vision can load Phi-3 Vision models for visual question answering.
    The model consists of a vision encoder, a text encoder as well as a text decoder.
    The vision encoder will encode the input image, the text encoder will encode the input question together
    with the encoding of the image, and the text decoder will output the answer to the question.

    Pretrained models can be loaded with :meth:`.pretrained` of the companion
    object:

    >>> visualQAClassifier = Phi3Vision.pretrained() \\
    ...     .setInputCols(["image_assembler"]) \\
    ...     .setOutputCol("answer")

    The default model is ``"phi3v"``, if no name is provided.

    For available pretrained models please see the `Models Hub
    <https://sparknlp.org/models?task=Question+Answering>`__.

    To see which models are compatible and how to import them see
    `Import Transformers into Spark NLP 🚀
    <https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``IMAGE``              ``DOCUMENT``
    ====================== ======================

    Parameters
    ----------
    batchSize
        Batch size. Large values allows faster processing but requires more
        memory, by default 2
    configProtoBytes
        ConfigProto from tensorflow, serialized into byte array.
    minOutputLength
        Minimum length of the sequence to be generated, by default 0
    maxOutputLength
        Maximum length of output text, by default 200
    doSample
        Whether or not to use sampling; use greedy decoding otherwise, by
        default False
    temperature
        The value used to module the next token probabilities, by default 1
    topK
        The number of highest probability vocabulary tokens to keep for
        top-k-filtering, by default 50
    topP
        If set to float < 1, only the most probable tokens with probabilities
        that add up to ``topP`` or higher are kept for generation, by default 1
    repetitionPenalty
        The parameter for repetition penalty. 1.0 means no penalty, by
        default 1.0
    noRepeatNgramSize
        If set to int > 0, all ngrams of that size can only occur once, by
        default 0
    ignoreTokenIds
        A list of token ids which are ignored in the decoder's output, by
        default []
    beamSize
        The number of beams for beam search, by default 1

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline
    >>> from pyspark.sql.functions import lit
    >>> image_df = spark.read.format("image").load(path=images_path)
    >>> test_df = image_df.withColumn("text", lit("What's this picture about?"))
    >>> imageAssembler = ImageAssembler() \\
    ...     .setInputCol("image") \\
    ...     .setOutputCol("image_assembler")
    >>> visualQAClassifier = Phi3Vision.pretrained() \\
    ...     .setInputCols("image_assembler") \\
    ...     .setOutputCol("answer")
    >>> pipeline = Pipeline().setStages([
    ...     imageAssembler,
    ...     visualQAClassifier
    ... ])
    >>> result = pipeline.fit(test_df).transform(test_df)
    >>> result.select("image_assembler.origin", "answer.result").show(truncate=False)
    """

    name = "Phi3Vision"

    inputAnnotatorTypes = [AnnotatorType.IMAGE]

    outputAnnotatorType = AnnotatorType.DOCUMENT

    configProtoBytes = Param(Params._dummy(),
                             "configProtoBytes",
                             "ConfigProto from tensorflow, serialized into byte array. Get with "
                             "config_proto.SerializeToString()",
                             TypeConverters.toListInt)

    # NOTE(review): declared so that setMaxSentenceSize() below does not raise
    # AttributeError (it previously referenced an undeclared param). Confirm the
    # Scala-side Phi3Vision model also exposes maxSentenceLength before relying
    # on it at transform time.
    maxSentenceLength = Param(Params._dummy(), "maxSentenceLength",
                              "Maximum sentence length that the annotator will process",
                              typeConverter=TypeConverters.toInt)

    minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated",
                            typeConverter=TypeConverters.toInt)

    maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text",
                            typeConverter=TypeConverters.toInt)

    doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise",
                     typeConverter=TypeConverters.toBoolean)

    temperature = Param(Params._dummy(), "temperature", "The value used to module the next token probabilities",
                        typeConverter=TypeConverters.toFloat)

    topK = Param(Params._dummy(), "topK",
                 "The number of highest probability vocabulary tokens to keep for top-k-filtering",
                 typeConverter=TypeConverters.toInt)

    topP = Param(Params._dummy(), "topP",
                 "If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation",
                 typeConverter=TypeConverters.toFloat)

    repetitionPenalty = Param(Params._dummy(), "repetitionPenalty",
                              "The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details",
                              typeConverter=TypeConverters.toFloat)

    noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize",
                              "If set to int > 0, all ngrams of that size can only occur once",
                              typeConverter=TypeConverters.toInt)

    ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds",
                           "A list of token ids which are ignored in the decoder's output",
                           typeConverter=TypeConverters.toListInt)

    beamSize = Param(Params._dummy(), "beamSize",
                     "The Number of beams for beam search.",
                     typeConverter=TypeConverters.toInt)

    def setMaxSentenceSize(self, value):
        """Sets Maximum sentence length that the annotator will process, by
        default 50.

        Parameters
        ----------
        value : int
            Maximum sentence length that the annotator will process
        """
        return self._set(maxSentenceLength=value)

    def setIgnoreTokenIds(self, value):
        """A list of token ids which are ignored in the decoder's output.

        Parameters
        ----------
        value : List[int]
            The words to be filtered out
        """
        return self._set(ignoreTokenIds=value)

    def setConfigProtoBytes(self, b):
        """Sets configProto from tensorflow, serialized into byte array.

        Parameters
        ----------
        b : List[int]
            ConfigProto from tensorflow, serialized into byte array
        """
        return self._set(configProtoBytes=b)

    def setMinOutputLength(self, value):
        """Sets minimum length of the sequence to be generated.

        Parameters
        ----------
        value : int
            Minimum length of the sequence to be generated
        """
        return self._set(minOutputLength=value)

    def setMaxOutputLength(self, value):
        """Sets maximum length of output text.

        Parameters
        ----------
        value : int
            Maximum length of output text
        """
        return self._set(maxOutputLength=value)

    def setDoSample(self, value):
        """Sets whether or not to use sampling, use greedy decoding otherwise.

        Parameters
        ----------
        value : bool
            Whether or not to use sampling; use greedy decoding otherwise
        """
        return self._set(doSample=value)

    def setTemperature(self, value):
        """Sets the value used to module the next token probabilities.

        Parameters
        ----------
        value : float
            The value used to module the next token probabilities
        """
        return self._set(temperature=value)

    def setTopK(self, value):
        """Sets the number of highest probability vocabulary tokens to keep for
        top-k-filtering.

        Parameters
        ----------
        value : int
            Number of highest probability vocabulary tokens to keep
        """
        return self._set(topK=value)

    def setTopP(self, value):
        """Sets the top cumulative probability for vocabulary tokens.

        If set to float < 1, only the most probable tokens with probabilities
        that add up to ``topP`` or higher are kept for generation.

        Parameters
        ----------
        value : float
            Cumulative probability for vocabulary tokens
        """
        return self._set(topP=value)

    def setRepetitionPenalty(self, value):
        """Sets the parameter for repetition penalty. 1.0 means no penalty.

        Parameters
        ----------
        value : float
            The repetition penalty

        References
        ----------
        See `Ctrl: A Conditional Transformer Language Model For Controllable
        Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
        """
        return self._set(repetitionPenalty=value)

    def setNoRepeatNgramSize(self, value):
        """Sets size of n-grams that can only occur once.

        If set to int > 0, all ngrams of that size can only occur once.

        Parameters
        ----------
        value : int
            N-gram size can only occur once
        """
        return self._set(noRepeatNgramSize=value)

    def setBeamSize(self, value):
        """Sets the number of beam size for beam search, by default `1`.

        Parameters
        ----------
        value : int
            Number of beam size for beam search
        """
        return self._set(beamSize=value)

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.cv.Phi3Vision",
                 java_model=None):
        super(Phi3Vision, self).__init__(
            classname=classname,
            java_model=java_model
        )
        # Defaults mirror the Scala annotator; greedy decoding (doSample=False,
        # beamSize=1) with no repetition constraints.
        self._setDefault(
            batchSize=2,
            minOutputLength=0,
            maxOutputLength=200,
            doSample=False,
            temperature=1,
            topK=50,
            topP=1,
            repetitionPenalty=1.0,
            noRepeatNgramSize=0,
            ignoreTokenIds=[],
            beamSize=1,
        )

    @staticmethod
    def loadSavedModel(folder, spark_session, use_openvino=False):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession
        use_openvino : bool, optional
            Whether to use the OpenVINO engine for inference, by default False

        Returns
        -------
        Phi3Vision
            The restored model
        """
        from sparknlp.internal import _Phi3VisionLoader
        jModel = _Phi3VisionLoader(folder, spark_session._jsparkSession, use_openvino)._java_obj
        return Phi3Vision(java_model=jModel)

    @staticmethod
    def pretrained(name="phi3v", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default "phi3v"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLPs repositories otherwise.

        Returns
        -------
        Phi3Vision
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(Phi3Vision, name, lang, remote_loc)