
Commit 21e651d

[SPARKNLP-1291] Adding support for input string column on readers
1 parent d4e84d5 commit 21e651d

10 files changed: +335 −37 lines


python/sparknlp/partition/partition_properties.py

Lines changed: 27 additions & 10 deletions
@@ -18,13 +18,40 @@
 
 class HasReaderProperties(Params):
 
+    inputCol = Param(
+        Params._dummy(),
+        "inputCol",
+        "input column name",
+        typeConverter=TypeConverters.toString
+    )
+
+    def setInputCol(self, value):
+        """Sets input column name.
+
+        Parameters
+        ----------
+        value : str
+            Name of the Input Column
+        """
+        return self._set(inputCol=value)
+
     outputCol = Param(
         Params._dummy(),
         "outputCol",
         "output column name",
         typeConverter=TypeConverters.toString
     )
 
+    def setOutputCol(self, value):
+        """Sets output column name.
+
+        Parameters
+        ----------
+        value : str
+            Name of the Output Column
+        """
+        return self._set(outputCol=value)
+
     contentPath = Param(
         Params._dummy(),
         "contentPath",
@@ -683,13 +710,3 @@ def setReadAsImage(self, value: bool):
         True to read as images, False otherwise.
         """
         return self._set(readAsImage=value)
-
-    def setOutputCol(self, value):
-        """Sets output column name.
-
-        Parameters
-        ----------
-        value : str
-            Name of the Output Column
-        """
-        return self._set(outputCol=value)
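With inputCol now exposed on HasReaderProperties, readers can take raw content from an existing DataFrame column instead of a contentPath. A minimal sketch of the Python API, assuming a session started via sparknlp.start() and that Reader2Doc is importable from sparknlp.reader.reader2doc (import paths may vary by release):

import sparknlp
from sparknlp.reader.reader2doc import Reader2Doc
from pyspark.ml import Pipeline

spark = sparknlp.start()

# Raw HTML lives in a DataFrame column rather than on disk.
html_df = spark.createDataFrame(
    [(1, "<html><body><p>Hello from a string column</p></body></html>")],
    ["id", "html"])

reader2doc = Reader2Doc() \
    .setInputCol("html") \
    .setOutputCol("document")

model = Pipeline(stages=[reader2doc]).fit(html_df)
result_df = model.transform(html_df)
result_df.select("document").show(truncate=False)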

python/test/reader/reader2doc_test.py

Lines changed: 20 additions & 0 deletions
@@ -111,4 +111,24 @@ def runTest(self):
 
         result_df = model.transform(self.empty_df)
 
+        self.assertTrue(result_df.select("document").count() > 0)
+
+@pytest.mark.fast
+class Reader2DocTestInputColumn(unittest.TestCase):
+
+    def setUp(self):
+        spark = SparkContextForTest.spark
+        content = "<html><head><title>Test<title><body><p>Unclosed tag"
+        self.html_df = spark.createDataFrame([(1, content)], ["id", "html"])
+
+    def runTest(self):
+        reader2doc = Reader2Doc() \
+            .setInputCol("html") \
+            .setOutputCol("document")
+
+        pipeline = Pipeline(stages=[reader2doc])
+        model = pipeline.fit(self.html_df)
+
+        result_df = model.transform(self.html_df)
+
         self.assertTrue(result_df.select("document").count() > 0)
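The pipeline in this test produces a column of Spark NLP document annotations. A hypothetical follow-up, exploding result_df from the sketch above for inspection (field names follow the standard annotation schema):

from pyspark.sql import functions as F

# Each annotation is a struct with annotatorType, begin, end, result, metadata.
result_df.select(F.explode("document").alias("doc")) \
    .select("doc.result", "doc.metadata") \
    .show(truncate=False)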

python/test/reader/reader2table_test.py

Lines changed: 32 additions & 2 deletions
@@ -40,7 +40,6 @@ def runTest(self):
         model = pipeline.fit(self.empty_df)
 
         result_df = model.transform(self.empty_df)
-        result_df.show(truncate=False)
 
         self.assertTrue(result_df.select("document").count() > 0)
 
@@ -60,4 +59,35 @@ def runTest(self):
 
         result_df = model.transform(self.empty_df)
 
-        self.assertTrue(result_df.select("document").count() > 1)
+        self.assertTrue(result_df.select("document").count() > 1)
+
+@pytest.mark.fast
+class Reader2TableInputColTest(unittest.TestCase):
+
+    def setUp(self):
+        content = """
+        <html>
+            <body>
+                <table>
+                    <tr>
+                        <td>Hello World</td>
+                    </tr>
+                </table>
+            </body>
+        </html>
+        """
+        spark = SparkContextForTest.spark
+        self.html_df = spark.createDataFrame([(1, content)], ["id", "html"])
+
+    def runTest(self):
+        reader2table = Reader2Table() \
+            .setInputCol("html") \
+            .setContentType("text/html") \
+            .setOutputCol("document")
+
+        pipeline = Pipeline(stages=[reader2table])
+        model = pipeline.fit(self.html_df)
+
+        result_df = model.transform(self.html_df)
+
+        self.assertTrue(result_df.select("document").count() > 0)

src/main/scala/com/johnsnowlabs/partition/HasReaderProperties.scala

Lines changed: 9 additions & 1 deletion
@@ -19,6 +19,13 @@ import org.apache.spark.ml.param.{BooleanParam, Param}
 
 trait HasReaderProperties extends HasHTMLReaderProperties {
 
+  protected final val inputCol: Param[String] =
+    new Param(this, "inputCol", "input column to process")
+
+  final def setInputCol(value: String): this.type = set(inputCol, value)
+
+  final def getInputCol: String = $(inputCol)
+
   val contentPath = new Param[String](this, "contentPath", "Path to the content source")
 
   def setContentPath(value: String): this.type = set(contentPath, value)
@@ -75,6 +82,7 @@ trait HasReaderProperties extends HasHTMLReaderProperties {
     titleFontSize -> 9,
     inferTableStructure -> false,
     includePageBreaks -> false,
-    ignoreExceptions -> true)
+    ignoreExceptions -> true,
+    inputCol -> "")
 
 }

src/main/scala/com/johnsnowlabs/reader/HasReaderContent.scala

Lines changed: 37 additions & 2 deletions
@@ -104,12 +104,11 @@ trait HasReaderContent extends HasReaderProperties {
     }
   }
 
-  def partitionContent(
+  private def partitionContentFromPath(
       partition: Partition,
       contentPath: String,
       isText: Boolean,
       dataset: Dataset[_]): DataFrame = {
-
     val ext = contentPath.split("\\.").lastOption.getOrElse("").toLowerCase
     if (! $(ignoreExceptions) && !supportedTypes.contains(ext)) {
       return buildErrorDataFrame(dataset, contentPath, ext)
@@ -148,6 +147,38 @@ trait HasReaderContent extends HasReaderProperties {
     } else partitionDf
   }
 
+  def partitionContent(
+      partition: Partition,
+      contentPath: String,
+      isText: Boolean,
+      dataset: Dataset[_]): DataFrame = {
+
+    val partitionDf =
+      if (getInputCol != null && getInputCol.nonEmpty) {
+        partitionContentFromDataFrame(partition, dataset, getInputCol)
+      } else {
+        partitionContentFromPath(partition, contentPath, isText, dataset)
+      }
+
+    if ($(ignoreExceptions)) {
+      partitionDf.filter(col("exception").isNull)
+    } else partitionDf
+  }
+
+  /** Partition content when it is already present in a dataset column. */
+  private def partitionContentFromDataFrame(
+      partition: Partition,
+      dataset: Dataset[_],
+      inputCol: String): DataFrame = {
+    val partitionUDF =
+      udf((text: String) => partition.partitionStringContent(text, $(this.headers).asJava))
+
+    dataset
+      .withColumn(partition.getOutputColumn, partitionUDF(col(inputCol)))
+      .withColumn("fileName", lit(null: String))
+      .withColumn("exception", lit(null: String))
+  }
+
   val getFileName: UserDefinedFunction = udf { path: String =>
     if (path != null) path.split("/").last else ""
   }
@@ -166,4 +197,8 @@ trait HasReaderContent extends HasReaderProperties {
     dataset.sparkSession.createDataFrame(emptyRDD, schema)
   }
 
+  def getContentType: String = {
+    if ($(contentType).trim.isEmpty && getInputCol.nonEmpty) "text/plain" else $(contentType)
+  }
+
 }
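partitionContent now dispatches on inputCol: when the column name is non-empty, each row of that column is partitioned through a UDF (with fileName and exception set to null); otherwise the original path-based reading runs. In both branches, rows with a non-null exception are dropped when ignoreExceptions is set. A sketch of the two modes from the Python side, with /tmp/html-docs as an assumed example path:

# Path mode: inputCol keeps its default (""), so content is read from disk.
path_reader = Reader2Doc() \
    .setContentType("text/html") \
    .setContentPath("/tmp/html-docs") \
    .setOutputCol("document")

# Column mode: inputCol is set, so contentPath is not required and every
# row of the "html" column is partitioned in place; fileName comes back null.
column_reader = Reader2Doc() \
    .setInputCol("html") \
    .setOutputCol("document")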

src/main/scala/com/johnsnowlabs/reader/Reader2Doc.scala

Lines changed: 10 additions & 9 deletions
@@ -127,13 +127,13 @@
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     validateRequiredParameters()
-    val structuredDf = if ($(contentType).trim.isEmpty) {
+    val structuredDf = if ($(contentType).trim.isEmpty && getInputCol.trim.isEmpty) {
       val partitionParams = Map(
         "inferTableStructure" -> $(inferTableStructure).toString,
         "outputFormat" -> $(outputFormat))
       partitionMixedContent(dataset, $(contentPath), partitionParams)
     } else {
-      partitionContent(partitionBuilder, $(contentPath), isStringContent($(contentType)), dataset)
+      partitionContent(partitionBuilder, $(contentPath), isStringContent(getContentType), dataset)
     }
     if (!structuredDf.isEmpty) {
       val annotatedDf = structuredDf
@@ -149,7 +149,7 @@ class Reader2Doc(override val uid: String)
 
   protected def partitionBuilder: Partition = {
     val params = Map(
-      "contentType" -> $(contentType),
+      "contentType" -> getContentType,
       "storeContent" -> $(storeContent).toString,
       "titleFontSize" -> $(titleFontSize).toString,
       "inferTableStructure" -> $(inferTableStructure).toString,
@@ -186,15 +186,16 @@ class Reader2Doc(override val uid: String)
   }
 
   protected def validateRequiredParameters(): Unit = {
-    require(
-      $(contentPath) != null && $(contentPath).trim.nonEmpty,
-      "contentPath must be set and not empty")
+    val hasContentPath = $(contentPath) != null && $(contentPath).trim.nonEmpty
+    if (hasContentPath) {
+      require(
+        ResourceHelper.validFile($(contentPath)),
+        "contentPath must point to a valid file or directory")
+    }
+
     require(
       $(outputFormat) == "plain-text",
       "Only 'plain-text' outputFormat is supported for this operation.")
-    require(
-      ResourceHelper.validFile($(contentPath)),
-      "contentPath must point to a valid file or directory")
   }
 
   protected def partitionToAnnotation: UserDefinedFunction = udf {
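Two consequences of this change: validateRequiredParameters no longer insists on a contentPath (it only validates one that is actually set), and getContentType falls back to text/plain when reading from a column with no explicit contentType. A sketch under those assumptions, reusing spark and Pipeline from the earlier example:

# Plain text in a column: no setContentType call is needed, because
# getContentType defaults to "text/plain" whenever inputCol is set.
text_df = spark.createDataFrame([(1, "Just a plain sentence.")], ["id", "text"])

plain_reader = Reader2Doc() \
    .setInputCol("text") \
    .setOutputCol("document")

Pipeline(stages=[plain_reader]).fit(text_df).transform(text_df).show()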

src/main/scala/com/johnsnowlabs/reader/Reader2Table.scala

Lines changed: 8 additions & 4 deletions
@@ -16,6 +16,7 @@
 package com.johnsnowlabs.reader
 
 import com.johnsnowlabs.nlp.Annotation
+import com.johnsnowlabs.nlp.util.io.ResourceHelper
 import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions.udf
@@ -84,7 +85,7 @@ class Reader2Table(override val uid: String) extends Reader2Doc {
   }
 
   private def getAcceptedTypes(fileName: String): Set[String] = {
-    if (fileName.isEmpty) {
+    if (fileName == null || fileName.isEmpty) {
       val officeDocTypes = Set(
         "application/msword",
         "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
@@ -171,9 +172,12 @@ class Reader2Table(override val uid: String) extends Reader2Doc {
   }
 
   override def validateRequiredParameters(): Unit = {
-    require(
-      $(contentPath) != null && $(contentPath).trim.nonEmpty,
-      "contentPath must be set and not empty")
+    val hasContentPath = $(contentPath) != null && $(contentPath).trim.nonEmpty
+    if (hasContentPath) {
+      require(
+        ResourceHelper.validFile($(contentPath)),
+        "contentPath must point to a valid file or directory")
+    }
     require(
       Set("html-table", "json-table").contains($(outputFormat)),
       "outputFormat must be either 'html-table' or 'json-table'.")

src/main/scala/com/johnsnowlabs/reader/util/HTMLParser.scala

Lines changed: 9 additions & 9 deletions
@@ -87,23 +87,23 @@ object HTMLParser {
   def tableElementToJson(tableElem: Element): String = {
     implicit val formats = Serialization.formats(NoTypeHints)
 
-    val caption = Option(tableElem.selectFirst("caption")).map(_.text.trim).getOrElse("")
+    val caption = Option(tableElem.selectFirst("caption"))
+      .map(_.text.trim)
+      .getOrElse("")
 
-    // Headers: first row with th or td as header
-    val headerRowOpt = tableElem
-      .select("tr")
-      .asScala
-      .find(tr => tr.select("th,td").asScala.nonEmpty && tr.select("th").asScala.nonEmpty)
+    val allRows = tableElem.select("tr").asScala.toList
+
+    val headerRowOpt = allRows.find(tr => tr.select("th").asScala.nonEmpty)
 
     val headers: List[String] = headerRowOpt
       .map(_.select("th,td").asScala.map(_.text.trim).toList)
       .getOrElse(List.empty)
 
-    val allRows = tableElem.select("tr").asScala.toList
-    val headerIndex = headerRowOpt.map(allRows.indexOf).getOrElse(0)
+    val headerIndexOpt = headerRowOpt.map(allRows.indexOf)
+
     val dataRows =
       allRows.zipWithIndex
-        .filter { case (_, idx) => idx != headerIndex } // skip header row
+        .filter { case (_, idx) => !headerIndexOpt.contains(idx) }
         .map(_._1)
         .map(row => row.select("td").asScala.map(_.text.trim).toList)
         .filter(_.nonEmpty)
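This fixes header detection for headerless tables: previously headerIndex defaulted to 0 even when no row contained a th, so the first data row was silently skipped; now the header index is optional and such tables keep every row. A pure-Python mirror of the fixed selection logic (the tuple model and function name are illustrative, not the Scala API):

def split_header_and_data(rows):
    # rows: list of (has_th, cells) tuples standing in for parsed <tr> elements.
    header_idx = next((i for i, (has_th, _) in enumerate(rows) if has_th), None)
    headers = rows[header_idx][1] if header_idx is not None else []
    # "i != header_idx" is True for every index when header_idx is None,
    # matching !headerIndexOpt.contains(idx) in the Scala version.
    data = [cells for i, (_, cells) in enumerate(rows) if i != header_idx]
    return headers, data

# Headerless table: both rows survive as data instead of losing row 0.
print(split_header_and_data([(False, ["a", "b"]), (False, ["c", "d"])]))
# -> ([], [['a', 'b'], ['c', 'd']])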
