diff --git a/build.sbt b/build.sbt
index c1762b95c70e8d..8ddde39b15a3fd 100644
--- a/build.sbt
+++ b/build.sbt
@@ -73,7 +73,9 @@ lazy val utilDependencies = Seq(
scratchpad
exclude ("org.apache.logging.log4j", "log4j-api"),
pdfBox,
- flexmark)
+ flexmark,
+ tagSoup
+)
lazy val typedDependencyParserDependencies = Seq(junit)
diff --git a/project/Dependencies.scala b/project/Dependencies.scala
index 524225b7e1074b..da0039c6d52f03 100644
--- a/project/Dependencies.scala
+++ b/project/Dependencies.scala
@@ -153,5 +153,8 @@ object Dependencies {
val flexmarkVersion = "0.61.34"
val flexmark = "com.vladsch.flexmark" % "flexmark-all" % flexmarkVersion
+ // TagSoup: provides the SAXFactoryImpl that XMLReader.parseXml uses to build its SAX parser.
+ val tagSoupVersion = "1.2.1"
+ val tagSoup = "org.ccil.cowan.tagsoup" % "tagsoup" % tagSoupVersion
/** ------- Dependencies end ------- */
}
diff --git a/src/main/scala/com/johnsnowlabs/reader/XMLReader.scala b/src/main/scala/com/johnsnowlabs/reader/XMLReader.scala
index 3665cdec2b6651..86a474cd6f9a48 100644
--- a/src/main/scala/com/johnsnowlabs/reader/XMLReader.scala
+++ b/src/main/scala/com/johnsnowlabs/reader/XMLReader.scala
@@ -20,9 +20,13 @@ import com.johnsnowlabs.nlp.util.io.ResourceHelper.validFile
import com.johnsnowlabs.partition.util.PartitionHelper.datasetWithTextFile
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
+import org.ccil.cowan.tagsoup.jaxp.SAXFactoryImpl
+import org.xml.sax.InputSource
+import java.io.StringReader
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
+import scala.xml.parsing.NoBindingFactoryAdapter
import scala.xml.{Elem, Node, XML}
/** Class to parse and read XML files.
@@ -102,7 +106,9 @@ class XMLReader(
})
def parseXml(xmlString: String): List[HTMLElement] = {
- val xml = XML.loadString(xmlString)
+ val parser = new SAXFactoryImpl().newSAXParser()
+ val adapter = new NoBindingFactoryAdapter
+ val xml = adapter.loadXML(new InputSource(new StringReader(xmlString)), parser)
val elements = ListBuffer[HTMLElement]()
def traverse(node: Node, parentId: Option[String]): Unit = {
@@ -128,10 +134,19 @@ class XMLReader(
elements += HTMLElement(elementType, content, metadata)
}
+ elem.attributes.asAttrMap.foreach { case (attrName, attrValue) =>
+ val attrId = hash(tagName + attrName + attrValue)
+ val metadata =
+ mutable.Map("elementId" -> attrId, "parentId" -> elementId, "attribute" -> attrName)
+ if (xmlKeepTags) metadata += ("tag" -> tagName)
+
+ elements += HTMLElement(ElementType.NARRATIVE_TEXT, attrValue, metadata)
+ }
+
// Traverse children
elem.child.foreach(traverse(_, Some(elementId)))
- case _ => // Ignore other types
+ case _ => // ignore
}
}
diff --git a/src/test/scala/com/johnsnowlabs/reader/Reader2DocTest.scala b/src/test/scala/com/johnsnowlabs/reader/Reader2DocTest.scala
index 866177c52aaba1..7bd1ef053f163d 100644
--- a/src/test/scala/com/johnsnowlabs/reader/Reader2DocTest.scala
+++ b/src/test/scala/com/johnsnowlabs/reader/Reader2DocTest.scala
@@ -318,7 +318,6 @@ class Reader2DocTest extends AnyFlatSpec with SparkSessionTest {
assert(annotations.head.metadata("elementType") != ElementType.IMAGE)
}
}
-
it should "validate invalid paths" taggedAs SlowTest in {
val reader2Doc = new Reader2Doc()
@@ -349,4 +348,23 @@ class Reader2DocTest extends AnyFlatSpec with SparkSessionTest {
assert(resultDf.filter(col("exception").isNotNull).count() >= 1)
}
+ it should "parse attributes inside XML files" taggedAs FastTest in {
+ val reader2Doc = new Reader2Doc()
+ .setContentType("application/xml")
+ .setContentPath(s"$xmlDirectory/test.xml")
+ .setOutputCol("document")
+ // Reads test.xml as application/xml through a pipeline whose only stage is the reader.
+ val pipeline = new Pipeline().setStages(Array(reader2Doc))
+ // fit and transform both receive emptyDataSet; the input path is set via setContentPath above.
+ val pipelineModel = pipeline.fit(emptyDataSet)
+ val resultDf = pipelineModel.transform(emptyDataSet)
+ // Gather every annotation in the "document" column whose metadata has an "attribute" key.
+ val annotationsResult = AssertAnnotations.getActualResult(resultDf, "document")
+ val attributeElements = annotationsResult.flatMap { annotations =>
+ annotations.filter(ann => ann.metadata.contains("attribute"))
+ }
+ // The test passes only if at least one such annotation was produced.
+ assert(attributeElements.length > 0, "Expected to find attribute elements in the XML content")
+ }
+
}
diff --git a/src/test/scala/com/johnsnowlabs/reader/XMLReaderTest.scala b/src/test/scala/com/johnsnowlabs/reader/XMLReaderTest.scala
index a75537803e61de..1a0bd3ed3688e1 100644
--- a/src/test/scala/com/johnsnowlabs/reader/XMLReaderTest.scala
+++ b/src/test/scala/com/johnsnowlabs/reader/XMLReaderTest.scala
@@ -40,4 +40,43 @@ class XMLReaderTest extends AnyFlatSpec {
assert(noParentIdCount.count() > 0)
}
+ it should "extract attributes as NARRATIVE_TEXT elements" taggedAs FastTest in {
+ val xml =
+ """
+ |
+ |""".stripMargin
+ // NOTE(review): the XML literal above looks empty, yet the assertions below expect `code` and `statusCode` attributes - confirm the fixture markup was not lost.
+ val reader = new XMLReader(xmlKeepTags = true, onlyLeafNodes = true)
+ val elements = reader.parseXml(xml)
+ // Keep only the elements typed NARRATIVE_TEXT from the parse result.
+ val attrElements = elements.filter(_.elementType == ElementType.NARRATIVE_TEXT)
+ // Fail early with a clear message before looking up individual attributes.
+ assert(attrElements.nonEmpty, "Attributes should be extracted as NARRATIVE_TEXT")
+ // Attribute names are compared case-insensitively (equalsIgnoreCase) in both lookups below.
+ val codeAttrOpt =
+ attrElements.find(_.metadata.get("attribute").exists(_.equalsIgnoreCase("code")))
+ assert(codeAttrOpt.isDefined, "Expected attribute 'code' was not found")
+ assert(codeAttrOpt.get.content == "ASSERTION")
+ // Same lookup for `statusCode`, whose content is expected to be "completed".
+ val statusAttrOpt =
+ attrElements.find(_.metadata.get("attribute").exists(_.equalsIgnoreCase("statusCode")))
+ assert(statusAttrOpt.isDefined, "Expected attribute 'statusCode' was not found")
+ assert(statusAttrOpt.get.content == "completed")
+ }
+
+ it should "link attribute elements to their parentId" taggedAs FastTest in {
+ val xml =
+ """
+ | - Content
+ |""".stripMargin
+ // NOTE(review): the assertions need an `item` element that carries attributes - confirm the fixture markup above is intact.
+ val reader = new XMLReader(xmlKeepTags = true, onlyLeafNodes = true)
+ val elements = reader.parseXml(xml)
+ // `.get` fails the test outright when no element is tagged "item".
+ val itemElem = elements.find(e => e.metadata.get("tag").contains("item")).get
+ val attrElems = elements.filter(_.metadata.contains("attribute"))
+ assert(attrElems.nonEmpty, "Expected at least one attribute element") // forall alone is vacuously true on an empty list
+ assert(attrElems.forall(_.metadata("parentId") == itemElem.metadata("elementId")))
+ }
+
}