145 changes: 144 additions & 1 deletion src/main/scala/com/johnsnowlabs/nlp/ParamsAndFeaturesReadable.scala
@@ -16,8 +16,14 @@

package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.LegacyMetadataSupport.ParamsReflection
import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.{DefaultParamsReadable, MLReader}
import org.apache.spark.sql.SparkSession
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.json4s.{DefaultFormats, JNothing, JNull, JObject, JValue}

import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success, Try}
@@ -29,7 +35,15 @@ class FeaturesReader[T <: HasFeatures](

override def load(path: String): T = {

val instance = baseReader.load(path)
val instance =
try {
// Let Spark's own loader handle modern bundles.
baseReader.load(path)
} catch {
case e: NoSuchElementException if isMissingParamError(e) =>
// Reconstruct legacy models that referenced params removed in newer releases.
loadWithLegacyParams(path)
}

for (feature <- instance.features) {
val value = feature.deserialize(sparkSession, path, feature.name)
@@ -40,6 +54,59 @@

instance
}

private def isMissingParamError(e: NoSuchElementException): Boolean = {
val msg = Option(e.getMessage).getOrElse("")
msg.contains("Param")
}
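  // For reference, the missing-param failure raised during loading typically carries a message
  // such as "Param someRemovedParam does not exist." (exact wording depends on the Spark
  // version), so a substring check on "Param" is enough to tell it apart from unrelated
  // NoSuchElementExceptions, e.g. (illustrative values only):
  //   isMissingParamError(new NoSuchElementException("Param someRemovedParam does not exist."))  // true
  //   isMissingParamError(new NoSuchElementException("next on empty iterator"))                  // false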

private def loadWithLegacyParams(path: String): T = {
val metadata = LegacyMetadataSupport.load(path, sparkSession)
val cls = Class.forName(metadata.className)
val ctor = cls.getConstructor(classOf[String])
val instance = ctor.newInstance(metadata.uid).asInstanceOf[Params]
setParamsIgnoringUnknown(instance, metadata)
instance.asInstanceOf[T]
}
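  // This mirrors how Spark's own DefaultParamsReader rebuilds an instance: the annotator is
  // expected to expose a single-argument `this(uid: String)` constructor, so the model can be
  // reconstructed from the class name recorded in metadata and have its params and features
  // replayed afterwards.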

private def setParamsIgnoringUnknown(
instance: Params,
metadata: LegacyMetadataSupport.Metadata): Unit = {
// Replay active params; skip mismatches so legacy bundles still come back.
assignParams(instance, metadata.params, isDefault = false, metadata)

val hasDefaultSection = metadata.defaultParams != JNothing && metadata.defaultParams != JNull
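    // Metadata written by older Spark releases may not contain a defaultParamMap section at
    // all, so defaults are only replayed when the section is actually present.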
if (hasDefaultSection) {
      // If the metadata carried defaults, restore only those that still exist.
assignParams(instance, metadata.defaultParams, isDefault = true, metadata)
}
}

private def assignParams(
instance: Params,
jsonParams: JValue,
isDefault: Boolean,
metadata: LegacyMetadataSupport.Metadata): Unit = {
jsonParams match {
case JObject(pairs) =>
pairs.foreach { case (paramName, jsonValue) =>
if (instance.hasParam(paramName)) {
val param = instance.getParam(paramName)
val value = param.jsonDecode(compact(render(jsonValue)))
if (isDefault) {
// Spark keeps setDefault protected; call it via reflection to restore legacy defaults.
ParamsReflection.setDefault(instance, param, value)
} else {
instance.set(param, value)
}
}
}
case JNothing | JNull =>
case other =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata when loading legacy params for ${metadata.className}: $other")
}
}
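  // Each stored value is decoded by the param's own codec, matching what Spark does on a
  // normal load; for example (illustrative param and value) an entry "caseSensitive": false
  // in paramMap ends up as instance.set(instance.getParam("caseSensitive"), false).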
}

trait ParamsAndFeaturesReadable[T <: HasFeatures] extends DefaultParamsReadable[T] {
@@ -137,3 +204,79 @@ trait ParamsAndFeaturesFallbackReadable[T <: HasFeatures] extends ParamsAndFeatu

override def read: MLReader[T] = new FeaturesFallbackReader(super.read, onRead, fallbackLoad)
}

// Minimal metadata parser + helper utilities for replaying legacy params.
private[nlp] object LegacyMetadataSupport {

object ParamsReflection {
private val setDefaultMethod = {
val maybeMethod = classOf[Params].getDeclaredMethods.find { method =>
method.getName == "setDefault" && method.getParameterCount == 2
}

maybeMethod match {
case Some(method) =>
method.setAccessible(true)
method
case None =>
throw new NoSuchMethodException("Params.setDefault(Param, value) not found via reflection")
}
}
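    // Note: Params also declares a varargs overload, setDefault(paramPairs: ParamPair[_]*),
    // which compiles down to a single Seq parameter, so filtering on a parameter count of two
    // selects the (Param[T], T) overload needed here.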

def setDefault[T](
params: Params,
param: org.apache.spark.ml.param.Param[T],
value: T): Unit = {
setDefaultMethod.invoke(params, param, toAnyRef(value))
}

// Mirror JVM boxing rules so reflection can call the protected method safely.
private def toAnyRef(value: Any): AnyRef = {
if (value == null) {
null
} else {
value match {
case v: AnyRef => v
case v: Boolean => java.lang.Boolean.valueOf(v)
case v: Byte => java.lang.Byte.valueOf(v)
case v: Short => java.lang.Short.valueOf(v)
case v: Int => java.lang.Integer.valueOf(v)
case v: Long => java.lang.Long.valueOf(v)
case v: Float => java.lang.Float.valueOf(v)
case v: Double => java.lang.Double.valueOf(v)
case v: Char => java.lang.Character.valueOf(v)
case other =>
throw new IllegalArgumentException(
s"Unsupported default value type ${other.getClass}")
}
}
}
}

case class Metadata(
className: String,
uid: String,
sparkVersion: String,
params: JValue,
defaultParams: JValue,
metadataJson: String)

def load(path: String, spark: SparkSession): Metadata = {
val metadataPath = new Path(path, "metadata").toString
val metadataStr = spark.sparkContext.textFile(metadataPath, 1).first()
parseMetadata(metadataStr)
}

private def parseMetadata(metadataStr: String): Metadata = {
val metadata = parse(metadataStr)
implicit val format: DefaultFormats.type = DefaultFormats

val className = (metadata \ "class").extract[String]
val uid = (metadata \ "uid").extract[String]
val sparkVersion = (metadata \ "sparkVersion").extractOpt[String].getOrElse("0.0")
val params = metadata \ "paramMap"
val defaultParams = metadata \ "defaultParamMap"

Metadata(className, uid, sparkVersion, params, defaultParams, metadataStr)
}
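  // For orientation, the metadata file written by Spark is a single JSON line shaped roughly
  // like (illustrative values only):
  //   {"class":"com.johnsnowlabs.nlp.annotators.er.EntityRulerModel","timestamp":1700000000000,
  //    "sparkVersion":"3.2.0","uid":"ENTITY_RULER_abcd","paramMap":{...},"defaultParamMap":{...}}
  // The raw string is kept in Metadata.metadataJson, while paramMap and defaultParamMap are
  // exposed as json4s values for the reader above to replay.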
}
@@ -21,7 +21,7 @@ import com.johnsnowlabs.nlp.annotators.SparkSessionTest
import com.johnsnowlabs.nlp.annotators.er.EntityRulerFixture._
import com.johnsnowlabs.nlp.base.LightPipeline
import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs}
import com.johnsnowlabs.tags.FastTest
import com.johnsnowlabs.tags.{FastTest, SlowTest}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.scalatest.flatspec.AnyFlatSpec

@@ -850,4 +850,25 @@ class EntityRulerTest extends AnyFlatSpec with SparkSessionTest {
entityRulerPipeline
}

it should "serialize EntityRulerModel" taggedAs SlowTest in {
    // Should be run with Java 8 and Scala 2.12
val entityRuler = new EntityRulerApproach()
.setInputCols("document", "token")
.setOutputCol("entities")
.setPatternsResource("src/test/resources/entity-ruler/keywords_only.json", ReadAs.TEXT)
val entityRulerModel = entityRuler.fit(emptyDataSet)

entityRulerModel.write.overwrite().save("./tmp_entity_ruler_model_java8_scala2_12")
}
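
  // The two tests above and below are meant as a cross-build round trip: run the "serialize"
  // test on an older build (Java 8 / Scala 2.12) to produce the bundle under
  // ./tmp_entity_ruler_model_java8_scala2_12, then run the "deserialize" test on the current
  // build so loading exercises the legacy-param fallback, e.g. (hypothetical invocation):
  //   sbt "testOnly *EntityRulerTest -- -z deserialize"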

it should "deserialize EntityRulerModel" taggedAs SlowTest in {
val textDataSet = Seq(text1).toDS.toDF("text")
val loadedEntityRulerModel = EntityRulerModel.load("./tmp_entity_ruler_model_java8_scala2_12")

val pipeline =
new Pipeline().setStages(Array(documentAssembler, tokenizer, loadedEntityRulerModel))
val resultDf = pipeline.fit(emptyDataSet).transform(textDataSet)
resultDf.select("entities").show(truncate = false)
}

}