src/main/scala/com/johnsnowlabs/nlp/embeddings/ChunkEmbeddings.scala
@@ -260,7 +260,7 @@ class ChunkEmbeddings(override val uid: String)
begin = chunk.begin,
end = chunk.end,
result = chunk.result,
-metadata = Map(
+metadata = chunk.metadata ++ Map(
"sentence" -> sentenceIdx.toString,
"chunk" -> chunkIdx.toString,
"token" -> chunk.result,
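This one-line change is the point of the PR: the output annotation's metadata now starts from the source chunk's metadata instead of an empty Map, so upstream entries such as `entity` or `confidence` survive into the chunk embeddings output. Since Scala's `++` lets the right-hand operand win on duplicate keys, the entries ChunkEmbeddings manages itself (`sentence`, `chunk`, `token`) still take precedence. A standalone sketch of that merge semantics (hypothetical values, plain Scala rather than annotator code):

// Plain-Scala sketch of the metadata merge; values are made up.
object MetadataMergeSketch extends App {
  // What an upstream NER/chunking stage might attach to a chunk.
  val chunkMetadata = Map("entity" -> "NAME", "confidence" -> "0.98", "sentence" -> "7")
  // What ChunkEmbeddings always sets for its own output.
  val annotatorEntries = Map("sentence" -> "0", "chunk" -> "0", "token" -> "Bush Blue")

  val merged = chunkMetadata ++ annotatorEntries
  println(merged("entity"))   // NAME -- carried over from the chunk
  println(merged("sentence")) // 0    -- right-hand map wins on collisions
}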
46 changes: 46 additions & 0 deletions src/test/scala/com/johnsnowlabs/nlp/AnnotationUtils.scala
@@ -0,0 +1,46 @@
package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row}

object AnnotationUtils {

private lazy val spark = SparkAccessor.spark

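// Converts an Annotation into a Row that follows the Annotation struct
// field order: annotatorType, begin, end, result, metadata, embeddings.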
implicit class AnnotationRow(annotation: Annotation) {

def toRow(): Row = {
Row(
annotation.annotatorType,
annotation.begin,
annotation.end,
annotation.result,
annotation.metadata,
annotation.embeddings)
}
}

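// Wraps a plain String into a Row holding a single DOCUMENT annotation
// spanning the whole string ("sentence" -> "0" unless overridden).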
implicit class DocumentRow(s: String) {
def toRow(metadata: Map[String, String] = Map("sentence" -> "0")): Row = {
Row(Seq(Annotation(DOCUMENT, 0, s.length, s, metadata).toRow()))
}
}

/** Creates a DataFrame with the given column name, annotator type and annotations row. The
  * output column is compatible with Spark NLP annotators.
*/
def createAnnotatorDataframe(
columnName: String,
annotatorType: String,
annotationsRow: Row): DataFrame = {
val metadataBuilder: MetadataBuilder = new MetadataBuilder()
metadataBuilder.putString("annotatorType", annotatorType)
val documentField =
StructField(columnName, Annotation.arrayType, nullable = false, metadataBuilder.build)
val struct = StructType(Array(documentField))
val rdd = spark.sparkContext.parallelize(Seq(annotationsRow))
spark.createDataFrame(rdd, struct)
}

}
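The `annotatorType` entry written into the column metadata is what Spark NLP annotators check when validating their input columns, so a DataFrame built this way can feed a pipeline directly. A usage sketch of the helpers above (hypothetical snippet mirroring the test below; assumes a live Spark session behind `SparkAccessor`):

// Hypothetical usage, not part of the PR:
import com.johnsnowlabs.nlp.AnnotationUtils._
import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT

val docDf = createAnnotatorDataframe(
  columnName = "sentence",
  annotatorType = DOCUMENT,
  annotationsRow = "A short example sentence.".toRow())

// The "sentence" column now carries {"annotatorType": "document"} metadata,
// so annotators expecting a DOCUMENT input will accept it.
docDf.printSchema()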
src/test/scala/com/johnsnowlabs/nlp/embeddings/ChunkEmbeddingsTestSpec.scala
@@ -16,14 +16,16 @@

package com.johnsnowlabs.nlp.embeddings

+import com.johnsnowlabs.nlp.AnnotatorType.{CHUNK, DOCUMENT}
import com.johnsnowlabs.nlp.annotator.{Chunker, PerceptronModel}
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
import com.johnsnowlabs.nlp.annotators.{NGramGenerator, StopWordsCleaner, Tokenizer}
import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.util.io.ResourceHelper
-import com.johnsnowlabs.nlp.{AnnotatorBuilder, EmbeddingsFinisher, Finisher}
+import com.johnsnowlabs.nlp.{Annotation, AnnotatorBuilder, EmbeddingsFinisher, Finisher}
import com.johnsnowlabs.tags.FastTest
import org.apache.spark.ml.Pipeline
+import org.apache.spark.sql.Row
import org.scalatest.flatspec.AnyFlatSpec

class ChunkEmbeddingsTestSpec extends AnyFlatSpec {
@@ -266,4 +268,53 @@ class ChunkEmbeddingsTestSpec extends AnyFlatSpec {

}

"ChunkEmbeddings" should "return chunk metadata at output" taggedAs FastTest in {
import com.johnsnowlabs.nlp.AnnotationUtils._
val document = "Record: Bush Blue, ZIPCODE: XYZ84556222, phone: (911) 45 88".toRow()

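// Chunk offsets are inclusive and match the document string above:
// "Bush Blue" spans 8..16 and "(911) 45 88" spans 48..58.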
val chunks = Row(
Seq(
Annotation(
CHUNK,
8,
16,
"Bush Blue",
Map("entity" -> "NAME", "sentence" -> "0", "chunk" -> "0", "confidence" -> "0.98"))
.toRow(),
Annotation(
CHUNK,
48,
58,
"(911) 45 88",
Map("entity" -> "PHONE", "sentence" -> "0", "chunk" -> "1", "confidence" -> "1.0"))
.toRow()))

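// Both helper DataFrames hold exactly one row, so the cross join yields a
// single row carrying the "sentence" and "chunk" columns side by side.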
val df = createAnnotatorDataframe("sentence", DOCUMENT, document)
.crossJoin(createAnnotatorDataframe("chunk", CHUNK, chunks))

val token = new Tokenizer()
.setInputCols("sentence")
.setOutputCol("token")

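// pretrained() with no arguments resolves the default English word embeddings
// (glove_100d at the time of writing) and downloads them on first use.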
val wordEmbeddings = WordEmbeddingsModel
.pretrained()
.setInputCols("sentence", "token")
.setOutputCol("embeddings")

val chunkEmbeddings = new ChunkEmbeddings()
.setInputCols("chunk", "embeddings")
.setOutputCol("chunk_embeddings")
.setPoolingStrategy("AVERAGE")

val pipeline = new Pipeline().setStages(Array(token, wordEmbeddings, chunkEmbeddings))
val result_df = pipeline.fit(df).transform(df)
// result_df.selectExpr("explode(chunk_embeddings) as embeddings").show(false)
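// Annotation.collect returns one Array[Annotation] per row of the output
// column; flatten concatenates them (a single row here, after the cross join).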
val annotations = Annotation.collect(result_df, "chunk_embeddings").flatten
assert(annotations.length == 2)
assert(annotations(0).metadata("entity") == "NAME")
assert(annotations(1).metadata("entity") == "PHONE")
val expectedMetadataKeys = Set("entity", "sentence", "chunk", "confidence")
assert(annotations.forall(anno => expectedMetadataKeys.forall(anno.metadata.contains)))
}

}
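The test pins the pooling strategy to `AVERAGE`, which collapses the word embeddings of a chunk's tokens into a single vector by element-wise mean (`SUM` being the other strategy the annotator accepts). A standalone sketch of that semantics, assuming all token vectors share one dimensionality:

// Sketch of AVERAGE pooling (assumed semantics, not the ChunkEmbeddings source).
def averagePool(tokenVectors: Seq[Array[Float]]): Array[Float] = {
  require(tokenVectors.nonEmpty, "a chunk needs at least one embedded token")
  val dim = tokenVectors.head.length
  val sums = new Array[Float](dim)
  for (vec <- tokenVectors; i <- 0 until dim) sums(i) += vec(i)
  sums.map(_ / tokenVectors.size)
}

// averagePool(Seq(Array(1f, 2f), Array(3f, 4f))) == Array(2.0f, 3.0f)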