Skip to content

Commit aefe88f

Browse files
authored
Add chunk metadata to ChunkEmbeddings output (#14462)
* create AnnotationUtils helper class for testing
1 parent 3302205 commit aefe88f

File tree

3 files changed

+99
-2
lines changed

3 files changed

+99
-2
lines changed

src/main/scala/com/johnsnowlabs/nlp/embeddings/ChunkEmbeddings.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ class ChunkEmbeddings(override val uid: String)
260260
begin = chunk.begin,
261261
end = chunk.end,
262262
result = chunk.result,
263-
metadata = Map(
263+
metadata = chunk.metadata ++ Map(
264264
"sentence" -> sentenceIdx.toString,
265265
"chunk" -> chunkIdx.toString,
266266
"token" -> chunk.result,
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row}

/** Test-only helpers for building Spark NLP annotation [[Row]]s and
  * annotator-compatible DataFrames without running a full pipeline.
  */
object AnnotationUtils {

  // Shared test SparkSession; lazy so it is only created on first use.
  private lazy val spark = SparkAccessor.spark

  /** Converts an [[Annotation]] into a Spark SQL [[Row]] whose field order
    * matches the Annotation struct schema (annotatorType, begin, end,
    * result, metadata, embeddings).
    */
  implicit class AnnotationRow(annotation: Annotation) {

    def toRow(): Row = {
      Row(
        annotation.annotatorType,
        annotation.begin,
        annotation.end,
        annotation.result,
        annotation.metadata,
        annotation.embeddings)
    }
  }

  /** Wraps a raw string as a single-element DOCUMENT annotation row.
    * `metadata` defaults to marking the text as sentence 0.
    */
  implicit class DocumentRow(s: String) {
    def toRow(metadata: Map[String, String] = Map("sentence" -> "0")): Row = {
      // NOTE(review): end is set to s.length here; DocumentAssembler conventionally
      // uses text.length - 1 — confirm this off-by-one is intended for tests.
      Row(Seq(Annotation(DOCUMENT, 0, s.length, s, metadata).toRow()))
    }
  }

  /** Creates a single-row DataFrame with the given column name, annotator type
    * and annotations row. The output column carries the `annotatorType` column
    * metadata, so it is accepted by Spark NLP annotators' input validation.
    */
  def createAnnotatorDataframe(
      columnName: String,
      annotatorType: String,
      annotationsRow: Row): DataFrame = {
    val metadataBuilder: MetadataBuilder = new MetadataBuilder()
    metadataBuilder.putString("annotatorType", annotatorType)
    val documentField =
      StructField(columnName, Annotation.arrayType, nullable = false, metadataBuilder.build)
    val struct = StructType(Array(documentField))
    // Parallelize the single row so createDataFrame can pair it with the schema.
    val rdd = spark.sparkContext.parallelize(Seq(annotationsRow))
    spark.createDataFrame(rdd, struct)
  }

}

src/test/scala/com/johnsnowlabs/nlp/embeddings/ChunkEmbeddingsTestSpec.scala

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@
1616

1717
package com.johnsnowlabs.nlp.embeddings
1818

19+
import com.johnsnowlabs.nlp.AnnotatorType.{CHUNK, DOCUMENT}
1920
import com.johnsnowlabs.nlp.annotator.{Chunker, PerceptronModel}
2021
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
2122
import com.johnsnowlabs.nlp.annotators.{NGramGenerator, StopWordsCleaner, Tokenizer}
2223
import com.johnsnowlabs.nlp.base.DocumentAssembler
2324
import com.johnsnowlabs.nlp.util.io.ResourceHelper
24-
import com.johnsnowlabs.nlp.{AnnotatorBuilder, EmbeddingsFinisher, Finisher}
25+
import com.johnsnowlabs.nlp.{Annotation, AnnotatorBuilder, EmbeddingsFinisher, Finisher}
2526
import com.johnsnowlabs.tags.FastTest
2627
import org.apache.spark.ml.Pipeline
28+
import org.apache.spark.sql.Row
2729
import org.scalatest.flatspec.AnyFlatSpec
2830

2931
class ChunkEmbeddingsTestSpec extends AnyFlatSpec {
@@ -266,4 +268,53 @@ class ChunkEmbeddingsTestSpec extends AnyFlatSpec {
266268

267269
}
268270

271+
"ChunkEmbeddings" should "return chunk metadata at output" taggedAs FastTest in {
272+
import com.johnsnowlabs.nlp.AnnotationUtils._
273+
val document = "Record: Bush Blue, ZIPCODE: XYZ84556222, phone: (911) 45 88".toRow()
274+
275+
val chunks = Row(
276+
Seq(
277+
Annotation(
278+
CHUNK,
279+
8,
280+
16,
281+
"Bush Blue",
282+
Map("entity" -> "NAME", "sentence" -> "0", "chunk" -> "0", "confidence" -> "0.98"))
283+
.toRow(),
284+
Annotation(
285+
CHUNK,
286+
48,
287+
58,
288+
"(911) 45 88",
289+
Map("entity" -> "PHONE", "sentence" -> "0", "chunk" -> "1", "confidence" -> "1.0"))
290+
.toRow()))
291+
292+
val df = createAnnotatorDataframe("sentence", DOCUMENT, document)
293+
.crossJoin(createAnnotatorDataframe("chunk", CHUNK, chunks))
294+
295+
val token = new Tokenizer()
296+
.setInputCols("sentence")
297+
.setOutputCol("token")
298+
299+
val wordEmbeddings = WordEmbeddingsModel
300+
.pretrained()
301+
.setInputCols("sentence", "token")
302+
.setOutputCol("embeddings")
303+
304+
val chunkEmbeddings = new ChunkEmbeddings()
305+
.setInputCols("chunk", "embeddings")
306+
.setOutputCol("chunk_embeddings")
307+
.setPoolingStrategy("AVERAGE")
308+
309+
val pipeline = new Pipeline().setStages(Array(token, wordEmbeddings, chunkEmbeddings))
310+
val result_df = pipeline.fit(df).transform(df)
311+
// result_df.selectExpr("explode(chunk_embeddings) as embeddings").show(false)
312+
val annotations = Annotation.collect(result_df, "chunk_embeddings").flatten
313+
assert(annotations.length == 2)
314+
assert(annotations(0).metadata("entity") == "NAME")
315+
assert(annotations(1).metadata("entity") == "PHONE")
316+
val expectedMetadataKeys = Set("entity", "sentence", "chunk", "confidence")
317+
assert(annotations.forall(anno => expectedMetadataKeys.forall(anno.metadata.contains)))
318+
}
319+
269320
}

0 commit comments

Comments
 (0)