diff --git a/build.sbt b/build.sbt
index c1762b95c70e8d..8ddde39b15a3fd 100644
--- a/build.sbt
+++ b/build.sbt
@@ -73,7 +73,9 @@ lazy val utilDependencies = Seq(
   scratchpad
     exclude ("org.apache.logging.log4j", "log4j-api"),
   pdfBox,
-  flexmark)
+  flexmark,
+  tagSoup
+)
 
 lazy val typedDependencyParserDependencies = Seq(junit)
 
diff --git a/project/Dependencies.scala b/project/Dependencies.scala
index 524225b7e1074b..da0039c6d52f03 100644
--- a/project/Dependencies.scala
+++ b/project/Dependencies.scala
@@ -153,5 +153,8 @@ object Dependencies {
 
   val flexmarkVersion = "0.61.34"
   val flexmark = "com.vladsch.flexmark" % "flexmark-all" % flexmarkVersion
+
+  val tagSoupVersion = "1.2.1"
+  val tagSoup = "org.ccil.cowan.tagsoup" % "tagsoup" % tagSoupVersion
   /** ------- Dependencies end ------- */
 }
diff --git a/src/main/scala/com/johnsnowlabs/reader/XMLReader.scala b/src/main/scala/com/johnsnowlabs/reader/XMLReader.scala
index 3665cdec2b6651..86a474cd6f9a48 100644
--- a/src/main/scala/com/johnsnowlabs/reader/XMLReader.scala
+++ b/src/main/scala/com/johnsnowlabs/reader/XMLReader.scala
@@ -20,9 +20,13 @@ import com.johnsnowlabs.nlp.util.io.ResourceHelper.validFile
 import com.johnsnowlabs.partition.util.PartitionHelper.datasetWithTextFile
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions.{col, udf}
+import org.ccil.cowan.tagsoup.jaxp.SAXFactoryImpl
+import org.xml.sax.InputSource
 
+import java.io.StringReader
 import scala.collection.mutable
 import scala.collection.mutable.ListBuffer
+import scala.xml.parsing.NoBindingFactoryAdapter
 import scala.xml.{Elem, Node, XML}
 
 /** Class to parse and read XML files.
@@ -102,7 +106,13 @@ class XMLReader(
     })
 
   def parseXml(xmlString: String): List[HTMLElement] = {
-    val xml = XML.loadString(xmlString)
+    // The document is read through TagSoup's SAX parser instead of XML.loadString,
+    // presumably so that markup which is not well-formed can still be loaded.
+    // NOTE(review): TagSoup targets HTML; confirm that any name normalisation it applies
+    // (the tests match attribute names case-insensitively) is acceptable for XML input.
+    val parser = new SAXFactoryImpl().newSAXParser()
+    val adapter = new NoBindingFactoryAdapter
+    val xml = adapter.loadXML(new InputSource(new StringReader(xmlString)), parser)
     val elements = ListBuffer[HTMLElement]()
 
     def traverse(node: Node, parentId: Option[String]): Unit = {
@@ -128,6 +138,18 @@ class XMLReader(
             elements += HTMLElement(elementType, content, metadata)
           }
 
+          // Emit every attribute of this element as its own NARRATIVE_TEXT element, linked to
+          // the owning element through parentId. The owner's elementId is hashed in so the same
+          // attribute/value pair on two same-named elements does not automatically share an id.
+          elem.attributes.asAttrMap.foreach { case (attrName, attrValue) =>
+            val attrId = hash(elementId + attrName + attrValue)
+            val metadata =
+              mutable.Map("elementId" -> attrId, "parentId" -> elementId, "attribute" -> attrName)
+            if (xmlKeepTags) metadata += ("tag" -> tagName)
+
+            elements += HTMLElement(ElementType.NARRATIVE_TEXT, attrValue, metadata)
+          }
+
           // Traverse children
           elem.child.foreach(traverse(_, Some(elementId)))
 
diff --git a/src/test/scala/com/johnsnowlabs/reader/Reader2DocTest.scala b/src/test/scala/com/johnsnowlabs/reader/Reader2DocTest.scala
index 866177c52aaba1..7bd1ef053f163d 100644
--- a/src/test/scala/com/johnsnowlabs/reader/Reader2DocTest.scala
+++ b/src/test/scala/com/johnsnowlabs/reader/Reader2DocTest.scala
@@ -349,4 +349,23 @@ class Reader2DocTest extends AnyFlatSpec with SparkSessionTest {
     assert(resultDf.filter(col("exception").isNotNull).count() >= 1)
   }
 
+  it should "parse attributes inside XML files" taggedAs FastTest in {
+    val reader2Doc = new Reader2Doc()
+      .setContentType("application/xml")
+      .setContentPath(s"$xmlDirectory/test.xml")
+      .setOutputCol("document")
+
+    val pipeline = new Pipeline().setStages(Array(reader2Doc))
+
+    val pipelineModel = pipeline.fit(emptyDataSet)
+    val resultDf = pipelineModel.transform(emptyDataSet)
+
+    val annotationsResult = AssertAnnotations.getActualResult(resultDf, "document")
+    val attributeElements = annotationsResult.flatMap { annotations =>
+      annotations.filter(ann => ann.metadata.contains("attribute"))
+    }
+
+    assert(attributeElements.nonEmpty, "Expected to find attribute elements in the XML content")
+  }
+
 }
diff --git a/src/test/scala/com/johnsnowlabs/reader/XMLReaderTest.scala b/src/test/scala/com/johnsnowlabs/reader/XMLReaderTest.scala
index a75537803e61de..1a0bd3ed3688e1 100644
--- a/src/test/scala/com/johnsnowlabs/reader/XMLReaderTest.scala
+++ b/src/test/scala/com/johnsnowlabs/reader/XMLReaderTest.scala
@@ -40,4 +40,46 @@ class XMLReaderTest extends AnyFlatSpec {
     assert(noParentIdCount.count() > 0)
   }
 
+  it should "extract attributes as NARRATIVE_TEXT elements" taggedAs FastTest in {
+    val xml =
+      """<root>
+        |  <observation code="ASSERTION" statusCode="completed"/>
+        |</root>""".stripMargin
+
+    val reader = new XMLReader(xmlKeepTags = true, onlyLeafNodes = true)
+    val elements = reader.parseXml(xml)
+
+    val attrElements = elements.filter(_.elementType == ElementType.NARRATIVE_TEXT)
+
+    assert(attrElements.nonEmpty, "Attributes should be extracted as NARRATIVE_TEXT")
+
+    val codeAttrOpt =
+      attrElements.find(_.metadata.get("attribute").exists(_.equalsIgnoreCase("code")))
+    assert(codeAttrOpt.isDefined, "Expected attribute 'code' was not found")
+    assert(codeAttrOpt.map(_.content).contains("ASSERTION"))
+
+    val statusAttrOpt =
+      attrElements.find(_.metadata.get("attribute").exists(_.equalsIgnoreCase("statusCode")))
+    assert(statusAttrOpt.isDefined, "Expected attribute 'statusCode' was not found")
+    assert(statusAttrOpt.map(_.content).contains("completed"))
+  }
+
+  it should "link attribute elements to their parentId" taggedAs FastTest in {
+    val xml =
+      """<root>
+        |  <item id="123" type="sample">Content</item>
+        |</root>""".stripMargin
+
+    val reader = new XMLReader(xmlKeepTags = true, onlyLeafNodes = true)
+    val elements = reader.parseXml(xml)
+
+    val itemElem = elements
+      .find(_.metadata.get("tag").contains("item"))
+      .getOrElse(fail("Expected an element tagged 'item'"))
+    val attrElems = elements.filter(_.metadata.contains("attribute"))
+
+    assert(attrElems.nonEmpty, "Expected attribute elements to be extracted from <item>")
+    assert(attrElems.forall(_.metadata("parentId") == itemElem.metadata("elementId")))
+  }
+
 }