Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ lazy val utilDependencies = Seq(
scratchpad
exclude ("org.apache.logging.log4j", "log4j-api"),
pdfBox,
flexmark)
flexmark,
tagSoup
)

lazy val typedDependencyParserDependencies = Seq(junit)

Expand Down
3 changes: 3 additions & 0 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -153,5 +153,8 @@ object Dependencies {

val flexmarkVersion = "0.61.34"
val flexmark = "com.vladsch.flexmark" % "flexmark-all" % flexmarkVersion

val tagSoupVersion = "1.2.1"
val tagSoup = "org.ccil.cowan.tagsoup" % "tagsoup" % tagSoupVersion
/** ------- Dependencies end ------- */
}
19 changes: 17 additions & 2 deletions src/main/scala/com/johnsnowlabs/reader/XMLReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@ import com.johnsnowlabs.nlp.util.io.ResourceHelper.validFile
import com.johnsnowlabs.partition.util.PartitionHelper.datasetWithTextFile
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
import org.ccil.cowan.tagsoup.jaxp.SAXFactoryImpl
import org.xml.sax.InputSource

import java.io.StringReader
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.xml.parsing.NoBindingFactoryAdapter
import scala.xml.{Elem, Node, XML}

/** Class to parse and read XML files.
Expand Down Expand Up @@ -102,7 +106,9 @@ class XMLReader(
})

def parseXml(xmlString: String): List[HTMLElement] = {
val xml = XML.loadString(xmlString)
val parser = new SAXFactoryImpl().newSAXParser()
val adapter = new NoBindingFactoryAdapter
val xml = adapter.loadXML(new InputSource(new StringReader(xmlString)), parser)
val elements = ListBuffer[HTMLElement]()

def traverse(node: Node, parentId: Option[String]): Unit = {
Expand All @@ -128,10 +134,19 @@ class XMLReader(
elements += HTMLElement(elementType, content, metadata)
}

elem.attributes.asAttrMap.foreach { case (attrName, attrValue) =>
val attrId = hash(tagName + attrName + attrValue)
val metadata =
mutable.Map("elementId" -> attrId, "parentId" -> elementId, "attribute" -> attrName)
if (xmlKeepTags) metadata += ("tag" -> tagName)

elements += HTMLElement(ElementType.NARRATIVE_TEXT, attrValue, metadata)
}

// Traverse children
elem.child.foreach(traverse(_, Some(elementId)))

case _ => // Ignore other types
case _ => // ignore
}
}

Expand Down
20 changes: 19 additions & 1 deletion src/test/scala/com/johnsnowlabs/reader/Reader2DocTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ class Reader2DocTest extends AnyFlatSpec with SparkSessionTest {
assert(annotations.head.metadata("elementType") != ElementType.IMAGE)
}
}

it should "validate invalid paths" taggedAs SlowTest in {

val reader2Doc = new Reader2Doc()
Expand Down Expand Up @@ -349,4 +348,23 @@ class Reader2DocTest extends AnyFlatSpec with SparkSessionTest {
assert(resultDf.filter(col("exception").isNotNull).count() >= 1)
}

it should "parse attributes inside XML files" taggedAs FastTest in {
val reader2Doc = new Reader2Doc()
.setContentType("application/xml")
.setContentPath(s"$xmlDirectory/test.xml")
.setOutputCol("document")

val pipeline = new Pipeline().setStages(Array(reader2Doc))

val pipelineModel = pipeline.fit(emptyDataSet)
val resultDf = pipelineModel.transform(emptyDataSet)

val annotationsResult = AssertAnnotations.getActualResult(resultDf, "document")
val attributeElements = annotationsResult.flatMap { annotations =>
annotations.filter(ann => ann.metadata.contains("attribute"))
}

assert(attributeElements.length > 0, "Expected to find attribute elements in the XML content")
}

}
39 changes: 39 additions & 0 deletions src/test/scala/com/johnsnowlabs/reader/XMLReaderTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,43 @@ class XMLReaderTest extends AnyFlatSpec {
assert(noParentIdCount.count() > 0)
}

it should "extract attributes as NARRATIVE_TEXT elements" taggedAs FastTest in {
val xml =
"""<root>
| <observation code="ASSERTION" statusCode="completed"/>
|</root>""".stripMargin

val reader = new XMLReader(xmlKeepTags = true, onlyLeafNodes = true)
val elements = reader.parseXml(xml)

val attrElements = elements.filter(_.elementType == ElementType.NARRATIVE_TEXT)

assert(attrElements.nonEmpty, "Attributes should be extracted as NARRATIVE_TEXT")

val codeAttrOpt =
attrElements.find(_.metadata.get("attribute").exists(_.equalsIgnoreCase("code")))
assert(codeAttrOpt.isDefined, "Expected attribute 'code' was not found")
assert(codeAttrOpt.get.content == "ASSERTION")

val statusAttrOpt =
attrElements.find(_.metadata.get("attribute").exists(_.equalsIgnoreCase("statusCode")))
assert(statusAttrOpt.isDefined, "Expected attribute 'statusCode' was not found")
assert(statusAttrOpt.get.content == "completed")
}

it should "link attribute elements to their parentId" taggedAs FastTest in {
val xml =
"""<root>
| <item id="123" class="test">Content</item>
|</root>""".stripMargin

val reader = new XMLReader(xmlKeepTags = true, onlyLeafNodes = true)
val elements = reader.parseXml(xml)

val itemElem = elements.find(e => e.metadata.get("tag").contains("item")).get
val attrElems = elements.filter(_.metadata.contains("attribute"))

assert(attrElems.forall(_.metadata("parentId") == itemElem.metadata("elementId")))
}

}