
Commit 11a7fd9

danilojslprabod authored and committed

Introducing BertForMultipleChoice transformer (#14435)

* [SPARKNLP-1084] Introducing BertForMultipleChoice
* [SPARKNLP-1084] Introducing BertForMultipleChoice transformer

1 parent 4be100e commit 11a7fd9

File tree

10 files changed: +860 -13 lines changed


python/sparknlp/annotator/classifier_dl/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -54,4 +54,4 @@
 from sparknlp.annotator.classifier_dl.mpnet_for_token_classification import *
 from sparknlp.annotator.classifier_dl.albert_for_zero_shot_classification import *
 from sparknlp.annotator.classifier_dl.camembert_for_zero_shot_classification import *
-
+from sparknlp.annotator.classifier_dl.bert_for_multiple_choice import *
python/sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
# Copyright 2017-2024 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sparknlp.common import *


class BertForMultipleChoice(AnnotatorModel,
                            HasCaseSensitiveProperties,
                            HasBatchedAnnotate,
                            HasEngine,
                            HasMaxSentenceLengthLimit):
    """BertForMultipleChoice can load BERT models with a multiple-choice classification head on top
    (a linear layer on top of the pooled output and a softmax), e.g. for RocStories/SWAG tasks.

    Pretrained models can be loaded with :meth:`.pretrained` of the companion
    object:

    >>> spanClassifier = BertForMultipleChoice.pretrained() \\
    ...     .setInputCols(["document_question", "document_context"]) \\
    ...     .setOutputCol("answer")

    The default model is ``"bert_base_uncased_multiple_choice"``, if no name is
    provided.

    For available pretrained models please see the `Models Hub
    <https://sparknlp.org/models?task=Multiple+Choice>`__.

    To see which models are compatible and how to import them see
    `Import Transformers into Spark NLP 🚀
    <https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``DOCUMENT, DOCUMENT`` ``CHUNK``
    ====================== ======================

    Parameters
    ----------
    batchSize
        Batch size. Larger values allow faster processing but require more
        memory, by default 4
    caseSensitive
        Whether to ignore case in tokens for embeddings matching, by default
        False
    maxSentenceLength
        Max sentence length to process, by default 512

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline
    >>> documentAssembler = MultiDocumentAssembler() \\
    ...     .setInputCols(["question", "context"]) \\
    ...     .setOutputCols(["document_question", "document_context"])
    >>> questionAnswering = BertForMultipleChoice.pretrained() \\
    ...     .setInputCols(["document_question", "document_context"]) \\
    ...     .setOutputCol("answer") \\
    ...     .setCaseSensitive(False)
    >>> pipeline = Pipeline().setStages([
    ...     documentAssembler,
    ...     questionAnswering
    ... ])
    >>> data = spark.createDataFrame([["The Eiffel Tower is located in which country?", "Germany, France, Italy"]]).toDF("question", "context")
    >>> result = pipeline.fit(data).transform(data)
    >>> result.select("answer.result").show(truncate=False)
    +--------------------+
    |result              |
    +--------------------+
    |[France]            |
    +--------------------+
    """
    name = "BertForMultipleChoice"

    inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.DOCUMENT]

    outputAnnotatorType = AnnotatorType.CHUNK

    choicesDelimiter = Param(Params._dummy(),
                             "choicesDelimiter",
                             "Delimiter character used to split the choices",
                             TypeConverters.toString)

    def setChoicesDelimiter(self, value):
        """Sets the delimiter character used to split the choices.

        Parameters
        ----------
        value : str
            Delimiter character used to split the choices
        """
        return self._set(choicesDelimiter=value)

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.BertForMultipleChoice",
                 java_model=None):
        super(BertForMultipleChoice, self).__init__(
            classname=classname,
            java_model=java_model
        )
        self._setDefault(
            batchSize=4,
            maxSentenceLength=512,
            caseSensitive=False,
            choicesDelimiter=","
        )

    @staticmethod
    def loadSavedModel(folder, spark_session):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession

        Returns
        -------
        BertForMultipleChoice
            The restored model
        """
        from sparknlp.internal import _BertMultipleChoiceLoader
        jModel = _BertMultipleChoiceLoader(folder, spark_session._jsparkSession)._java_obj
        return BertForMultipleChoice(java_model=jModel)

    @staticmethod
    def pretrained(name="bert_base_uncased_multiple_choice", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default
            "bert_base_uncased_multiple_choice"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLP's repositories otherwise.

        Returns
        -------
        BertForMultipleChoice
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(BertForMultipleChoice, name, lang, remote_loc)
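
A minimal end-to-end sketch of the new annotator follows, assuming a running Spark NLP session (started via sparknlp.start()) and download access to the default pretrained model; the sample question and comma-separated choices mirror the docstring example, and setChoicesDelimiter is shown with its default comma:

# Minimal sketch (assumes sparknlp.start() works and the default pretrained
# model can be downloaded; the "context" column holds one delimiter-separated
# string of choices, split according to setChoicesDelimiter).
import sparknlp
from pyspark.ml import Pipeline
from sparknlp.base import MultiDocumentAssembler
from sparknlp.annotator import BertForMultipleChoice

spark = sparknlp.start()

document_assembler = MultiDocumentAssembler() \
    .setInputCols(["question", "context"]) \
    .setOutputCols(["document_question", "document_context"])

multiple_choice = BertForMultipleChoice.pretrained() \
    .setInputCols(["document_question", "document_context"]) \
    .setOutputCol("answer") \
    .setChoicesDelimiter(",")

data = spark.createDataFrame(
    [["The Eiffel Tower is located in which country?", "Germany, France, Italy"]]
).toDF("question", "context")

pipeline = Pipeline(stages=[document_assembler, multiple_choice])
pipeline.fit(data).transform(data).select("answer.result").show(truncate=False)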

python/sparknlp/internal/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -113,6 +113,13 @@ def __init__(self, path, jspark):
             jspark,
         )

+class _BertMultipleChoiceLoader(ExtendedJavaWrapper):
+    def __init__(self, path, jspark):
+        super(_BertMultipleChoiceLoader, self).__init__(
+            "com.johnsnowlabs.nlp.annotators.classifier.dl.BertForMultipleChoice.loadSavedModel",
+            path,
+            jspark,
+        )

 class _DeBERTaLoader(ExtendedJavaWrapper):
     def __init__(self, path, jspark):
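
BertForMultipleChoice.loadSavedModel delegates to this loader. A sketch of the import flow it enables, under the assumption that an externally exported model already sits in a local folder (the paths below are placeholders, and the export step itself is outside this diff):

# Sketch only: load an externally exported model from a placeholder folder
# and persist it as a Spark NLP annotator for later reuse.
import sparknlp
from sparknlp.annotator import BertForMultipleChoice

spark = sparknlp.start()

bert_mc = BertForMultipleChoice.loadSavedModel("/tmp/bert_mc_export", spark) \
    .setInputCols(["document_question", "document_context"]) \
    .setOutputCol("answer")

# Standard Spark ML persistence; reload later with BertForMultipleChoice.load(...)
bert_mc.write().overwrite().save("/tmp/bert_for_multiple_choice_spark_nlp")
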
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
# Copyright 2017-2024 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import pytest

from sparknlp.annotator import *
from sparknlp.base import *
from test.util import SparkContextForTest


class BertForMultipleChoiceTestSetup(unittest.TestCase):
    def setUp(self):
        self.spark = SparkContextForTest.spark
        self.question = "The Eiffel Tower is located in which country?"
        self.choices = "Germany, France, Italy"

        empty_df = self.spark.createDataFrame([[""]]).toDF("text")

        document_assembler = MultiDocumentAssembler() \
            .setInputCols(["question", "context"]) \
            .setOutputCols(["document_question", "document_context"])

        bert_for_multiple_choice = BertForMultipleChoice.pretrained() \
            .setInputCols(["document_question", "document_context"]) \
            .setOutputCol("answer")

        pipeline = Pipeline(stages=[document_assembler, bert_for_multiple_choice])

        self.pipeline_model = pipeline.fit(empty_df)


@pytest.mark.slow
class BertForMultipleChoiceTest(BertForMultipleChoiceTestSetup, unittest.TestCase):

    def setUp(self):
        super().setUp()
        self.data = self.spark.createDataFrame([[self.question, self.choices]]).toDF("question", "context")
        self.data.show(truncate=False)

    def test_run(self):
        result_df = self.pipeline_model.transform(self.data)
        result_df.show(truncate=False)
        for row in result_df.collect():
            self.assertTrue(row["answer"][0].result != "")


@pytest.mark.slow
class LightBertForMultipleChoiceTest(BertForMultipleChoiceTestSetup, unittest.TestCase):

    def setUp(self):
        super().setUp()

    def runTest(self):
        light_pipeline = LightPipeline(self.pipeline_model)
        annotations_result = light_pipeline.fullAnnotate(self.question, self.choices)
        print(annotations_result)
        for result in annotations_result:
            self.assertTrue(result["answer"][0].result != "")

        result = light_pipeline.annotate(self.question, self.choices)
        print(result)
        self.assertTrue(result["answer"] != "")

src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala

Lines changed: 91 additions & 1 deletion
@@ -130,7 +130,7 @@ private[johnsnowlabs] class BertClassification(

     // we need the original form of the token
     // let's lowercase if needed right before the encoding
-    val basicTokenizer = new BasicTokenizer(caseSensitive = true, hasBeginEnd = false)
+    val basicTokenizer = new BasicTokenizer(caseSensitive = caseSensitive, hasBeginEnd = false)
     val encoder = new WordpieceEncoder(vocabulary)
     val sentences = docs.map { s => Sentence(s.result, s.begin, s.end, 0) }

@@ -546,6 +546,15 @@
     (startScores, endScores)
   }

+  override def tagSpanMultipleChoice(batch: Seq[Array[Int]]): Array[Float] = {
+    val logits = detectedEngine match {
+      case ONNX.name => computeLogitsMultipleChoiceWithOnnx(batch)
+      case Openvino.name => computeLogitsMultipleChoiceWithOv(batch)
+    }
+
+    calculateSoftmax(logits)
+  }
+
   private def computeLogitsWithTF(
       batch: Seq[Array[Int]],
       maxSentenceLength: Int): (Array[Float], Array[Float]) = {
@@ -732,6 +741,87 @@
       }
     }

+  private def computeLogitsMultipleChoiceWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {
+    val sequenceLength = batch.head.length
+    val inputIds = Array(batch.map(x => x.map(_.toLong)).toArray)
+    val attentionMask = Array(
+      batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)
+    val tokenTypeIds = Array(batch.map(_ => Array.fill(sequenceLength)(0L)).toArray)
+
+    val (ortSession, ortEnv) = onnxWrapper.get.getSession(onnxSessionOptions)
+    val tokenTensors = OnnxTensor.createTensor(ortEnv, inputIds)
+    val maskTensors = OnnxTensor.createTensor(ortEnv, attentionMask)
+    val segmentTensors = OnnxTensor.createTensor(ortEnv, tokenTypeIds)
+
+    val inputs =
+      Map(
+        "input_ids" -> tokenTensors,
+        "attention_mask" -> maskTensors,
+        "token_type_ids" -> segmentTensors).asJava
+
+    try {
+      val output = ortSession.run(inputs)
+      try {
+
+        val logits = output
+          .get("logits")
+          .get()
+          .asInstanceOf[OnnxTensor]
+          .getFloatBuffer
+          .array()
+
+        tokenTensors.close()
+        maskTensors.close()
+        segmentTensors.close()
+
+        logits
+      } finally if (output != null) output.close()
+    } catch {
+      case e: Exception =>
+        // Log the exception as a warning
+        logger.warn("Exception in computeLogitsMultipleChoiceWithOnnx", e)
+        // Rethrow the exception to propagate it further
+        throw e
+    }
+  }
+
+  private def computeLogitsMultipleChoiceWithOv(batch: Seq[Array[Int]]): Array[Float] = {
+    val (numChoices, sequenceLength) = (batch.length, batch.head.length)
+    // batch_size, num_choices, sequence_length
+    val shape = Some(Array(1, numChoices, sequenceLength))
+    val (tokenTensors, maskTensors, segmentTensors) =
+      PrepareEmbeddings.prepareOvLongBatchTensorsWithSegment(
+        batch,
+        sequenceLength,
+        numChoices,
+        sentencePadTokenId,
+        shape)
+
+    val compiledModel = openvinoWrapper.get.getCompiledModel()
+    val inferRequest = compiledModel.create_infer_request()
+    inferRequest.set_tensor("input_ids", tokenTensors)
+    inferRequest.set_tensor("attention_mask", maskTensors)
+    inferRequest.set_tensor("token_type_ids", segmentTensors)
+
+    inferRequest.infer()
+
+    try {
+      try {
+        val logits = inferRequest
+          .get_output_tensor()
+          .data()
+
+        logits
+      }
+    } catch {
+      case e: Exception =>
+        // Log the exception as a warning
+        logger.warn("Exception in computeLogitsMultipleChoiceWithOv", e)
+        // Rethrow the exception to propagate it further
+        throw e
+    }
+  }
+
   def findIndexedToken(
       tokenizedSentences: Seq[TokenizedSentence],
       sentence: (WordpieceTokenizedSentence, Int),
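
As the Scala changes show, tagSpanMultipleChoice only dispatches to the ONNX and OpenVINO backends and returns a softmax over one logit per choice; the highest-probability choice is presumably what surfaces as the CHUNK annotation, though that selection happens outside this diff. A small Python sketch of that scoring step, for illustration only (the logits and choices are made-up values, not model output):

# Illustrative only: softmax over per-choice logits, then pick the best choice,
# mirroring the calculateSoftmax call in the Scala code above.
import math

def score_choices(logits, choices):
    shifted = [l - max(logits) for l in logits]   # subtract max for numerical stability
    exps = [math.exp(s) for s in shifted]
    probs = [e / sum(exps) for e in exps]         # softmax over the choices
    best = max(range(len(choices)), key=lambda i: probs[i])
    return choices[best], probs

answer, probs = score_choices([0.1, 2.3, -0.4], ["Germany", "France", "Italy"])
print(answer)  # "France" has the largest logit, hence the highest probability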
