|
16 | 16 |
|
17 | 17 | package com.johnsnowlabs.nlp.embeddings |
18 | 18 |
|
| 19 | +import com.johnsnowlabs.nlp.AnnotatorType.{CHUNK, DOCUMENT} |
19 | 20 | import com.johnsnowlabs.nlp.annotator.{Chunker, PerceptronModel} |
20 | 21 | import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector |
21 | 22 | import com.johnsnowlabs.nlp.annotators.{NGramGenerator, StopWordsCleaner, Tokenizer} |
22 | 23 | import com.johnsnowlabs.nlp.base.DocumentAssembler |
23 | 24 | import com.johnsnowlabs.nlp.util.io.ResourceHelper |
24 | | -import com.johnsnowlabs.nlp.{AnnotatorBuilder, EmbeddingsFinisher, Finisher} |
| 25 | +import com.johnsnowlabs.nlp.{Annotation, AnnotatorBuilder, EmbeddingsFinisher, Finisher} |
25 | 26 | import com.johnsnowlabs.tags.FastTest |
26 | 27 | import org.apache.spark.ml.Pipeline |
| 28 | +import org.apache.spark.sql.Row |
27 | 29 | import org.scalatest.flatspec.AnyFlatSpec |
28 | 30 |
|
29 | 31 | class ChunkEmbeddingsTestSpec extends AnyFlatSpec { |
@@ -266,4 +268,53 @@ class ChunkEmbeddingsTestSpec extends AnyFlatSpec { |
266 | 268 |
|
267 | 269 | } |
268 | 270 |
|
| 271 | + "ChunkEmbeddings" should "return chunk metadata at output" taggedAs FastTest in { |
| 272 | + import com.johnsnowlabs.nlp.AnnotationUtils._ |
| 273 | + val document = "Record: Bush Blue, ZIPCODE: XYZ84556222, phone: (911) 45 88".toRow() |
| 274 | + |
| 275 | + val chunks = Row( |
| 276 | + Seq( |
| 277 | + Annotation( |
| 278 | + CHUNK, |
| 279 | + 8, |
| 280 | + 16, |
| 281 | + "Bush Blue", |
| 282 | + Map("entity" -> "NAME", "sentence" -> "0", "chunk" -> "0", "confidence" -> "0.98")) |
| 283 | + .toRow(), |
| 284 | + Annotation( |
| 285 | + CHUNK, |
| 286 | + 48, |
| 287 | + 58, |
| 288 | + "(911) 45 88", |
| 289 | + Map("entity" -> "PHONE", "sentence" -> "0", "chunk" -> "1", "confidence" -> "1.0")) |
| 290 | + .toRow())) |
| 291 | + |
| 292 | + val df = createAnnotatorDataframe("sentence", DOCUMENT, document) |
| 293 | + .crossJoin(createAnnotatorDataframe("chunk", CHUNK, chunks)) |
| 294 | + |
| 295 | + val token = new Tokenizer() |
| 296 | + .setInputCols("sentence") |
| 297 | + .setOutputCol("token") |
| 298 | + |
| 299 | + val wordEmbeddings = WordEmbeddingsModel |
| 300 | + .pretrained() |
| 301 | + .setInputCols("sentence", "token") |
| 302 | + .setOutputCol("embeddings") |
| 303 | + |
| 304 | + val chunkEmbeddings = new ChunkEmbeddings() |
| 305 | + .setInputCols("chunk", "embeddings") |
| 306 | + .setOutputCol("chunk_embeddings") |
| 307 | + .setPoolingStrategy("AVERAGE") |
| 308 | + |
| 309 | + val pipeline = new Pipeline().setStages(Array(token, wordEmbeddings, chunkEmbeddings)) |
| 310 | + val result_df = pipeline.fit(df).transform(df) |
| 311 | + // result_df.selectExpr("explode(chunk_embeddings) as embeddings").show(false) |
| 312 | + val annotations = Annotation.collect(result_df, "chunk_embeddings").flatten |
| 313 | + assert(annotations.length == 2) |
| 314 | + assert(annotations(0).metadata("entity") == "NAME") |
| 315 | + assert(annotations(1).metadata("entity") == "PHONE") |
| 316 | + val expectedMetadataKeys = Set("entity", "sentence", "chunk", "confidence") |
| 317 | + assert(annotations.forall(anno => expectedMetadataKeys.forall(anno.metadata.contains))) |
| 318 | + } |
| 319 | + |
269 | 320 | } |
0 commit comments