Skip to content

Commit 9fc975d

Browse files
committed
[SPARKNLP-1292] Adding fault-tolerance support when reading malformed XML files
1 parent d4e84d5 commit 9fc975d

File tree

5 files changed

+81
-4
lines changed

5 files changed

+81
-4
lines changed

build.sbt

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,9 @@ lazy val utilDependencies = Seq(
7373
scratchpad
7474
exclude ("org.apache.logging.log4j", "log4j-api"),
7575
pdfBox,
76-
flexmark)
76+
flexmark,
77+
tagSoup
78+
)
7779

7880
lazy val typedDependencyParserDependencies = Seq(junit)
7981

project/Dependencies.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,5 +153,8 @@ object Dependencies {
153153

154154
val flexmarkVersion = "0.61.34"
155155
val flexmark = "com.vladsch.flexmark" % "flexmark-all" % flexmarkVersion
156+
157+
val tagSoupVersion = "1.2.1"
158+
val tagSoup = "org.ccil.cowan.tagsoup" % "tagsoup" % tagSoupVersion
156159
/** ------- Dependencies end ------- */
157160
}

src/main/scala/com/johnsnowlabs/reader/XMLReader.scala

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,13 @@ import com.johnsnowlabs.nlp.util.io.ResourceHelper.validFile
2020
import com.johnsnowlabs.partition.util.PartitionHelper.datasetWithTextFile
2121
import org.apache.spark.sql.DataFrame
2222
import org.apache.spark.sql.functions.{col, udf}
23+
import org.ccil.cowan.tagsoup.jaxp.SAXFactoryImpl
24+
import org.xml.sax.InputSource
2325

26+
import java.io.StringReader
2427
import scala.collection.mutable
2528
import scala.collection.mutable.ListBuffer
29+
import scala.xml.parsing.NoBindingFactoryAdapter
2630
import scala.xml.{Elem, Node, XML}
2731

2832
/** Class to parse and read XML files.
@@ -102,7 +106,9 @@ class XMLReader(
102106
})
103107

104108
def parseXml(xmlString: String): List[HTMLElement] = {
105-
val xml = XML.loadString(xmlString)
109+
val parser = new SAXFactoryImpl().newSAXParser()
110+
val adapter = new NoBindingFactoryAdapter
111+
val xml = adapter.loadXML(new InputSource(new StringReader(xmlString)), parser)
106112
val elements = ListBuffer[HTMLElement]()
107113

108114
def traverse(node: Node, parentId: Option[String]): Unit = {
@@ -128,10 +134,19 @@ class XMLReader(
128134
elements += HTMLElement(elementType, content, metadata)
129135
}
130136

137+
elem.attributes.asAttrMap.foreach { case (attrName, attrValue) =>
138+
val attrId = hash(tagName + attrName + attrValue)
139+
val metadata =
140+
mutable.Map("elementId" -> attrId, "parentId" -> elementId, "attribute" -> attrName)
141+
if (xmlKeepTags) metadata += ("tag" -> tagName)
142+
143+
elements += HTMLElement(ElementType.NARRATIVE_TEXT, attrValue, metadata)
144+
}
145+
131146
// Traverse children
132147
elem.child.foreach(traverse(_, Some(elementId)))
133148

134-
case _ => // Ignore other types
149+
case _ => // ignore
135150
}
136151
}
137152

src/test/scala/com/johnsnowlabs/reader/Reader2DocTest.scala

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,6 @@ class Reader2DocTest extends AnyFlatSpec with SparkSessionTest {
318318
assert(annotations.head.metadata("elementType") != ElementType.IMAGE)
319319
}
320320
}
321-
322321
it should "validate invalid paths" taggedAs SlowTest in {
323322

324323
val reader2Doc = new Reader2Doc()
@@ -349,4 +348,23 @@ class Reader2DocTest extends AnyFlatSpec with SparkSessionTest {
349348
assert(resultDf.filter(col("exception").isNotNull).count() >= 1)
350349
}
351350

351+
it should "parse attributes inside XML files" taggedAs FastTest in {
352+
val reader2Doc = new Reader2Doc()
353+
.setContentType("application/xml")
354+
.setContentPath(s"$xmlDirectory/test.xml")
355+
.setOutputCol("document")
356+
357+
val pipeline = new Pipeline().setStages(Array(reader2Doc))
358+
359+
val pipelineModel = pipeline.fit(emptyDataSet)
360+
val resultDf = pipelineModel.transform(emptyDataSet)
361+
362+
val annotationsResult = AssertAnnotations.getActualResult(resultDf, "document")
363+
val attributeElements = annotationsResult.flatMap { annotations =>
364+
annotations.filter(ann => ann.metadata.contains("attribute"))
365+
}
366+
367+
assert(attributeElements.length > 0, "Expected to find attribute elements in the XML content")
368+
}
369+
352370
}

src/test/scala/com/johnsnowlabs/reader/XMLReaderTest.scala

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,43 @@ class XMLReaderTest extends AnyFlatSpec {
4040
assert(noParentIdCount.count() > 0)
4141
}
4242

43+
it should "extract attributes as NARRATIVE_TEXT elements" taggedAs FastTest in {
44+
val xml =
45+
"""<root>
46+
| <observation code="ASSERTION" statusCode="completed"/>
47+
|</root>""".stripMargin
48+
49+
val reader = new XMLReader(xmlKeepTags = true, onlyLeafNodes = true)
50+
val elements = reader.parseXml(xml)
51+
52+
val attrElements = elements.filter(_.elementType == ElementType.NARRATIVE_TEXT)
53+
54+
assert(attrElements.nonEmpty, "Attributes should be extracted as NARRATIVE_TEXT")
55+
56+
val codeAttrOpt =
57+
attrElements.find(_.metadata.get("attribute").exists(_.equalsIgnoreCase("code")))
58+
assert(codeAttrOpt.isDefined, "Expected attribute 'code' was not found")
59+
assert(codeAttrOpt.get.content == "ASSERTION")
60+
61+
val statusAttrOpt =
62+
attrElements.find(_.metadata.get("attribute").exists(_.equalsIgnoreCase("statusCode")))
63+
assert(statusAttrOpt.isDefined, "Expected attribute 'statusCode' was not found")
64+
assert(statusAttrOpt.get.content == "completed")
65+
}
66+
67+
it should "link attribute elements to their parentId" taggedAs FastTest in {
68+
val xml =
69+
"""<root>
70+
| <item id="123" class="test">Content</item>
71+
|</root>""".stripMargin
72+
73+
val reader = new XMLReader(xmlKeepTags = true, onlyLeafNodes = true)
74+
val elements = reader.parseXml(xml)
75+
76+
val itemElem = elements.find(e => e.metadata.get("tag").contains("item")).get
77+
val attrElems = elements.filter(_.metadata.contains("attribute"))
78+
79+
assert(attrElems.forall(_.metadata("parentId") == itemElem.metadata("elementId")))
80+
}
81+
4382
}

0 commit comments

Comments
 (0)