Skip to content

Commit 3b8cc59

Browse files
committed
searchForSuitableGraph protected function
1 parent ae38d21 commit 3b8cc59

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLGraphChecker.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ class NerDLGraphChecker(override val uid: String)
162162
new Param[String](this, "graphFolder", "Folder path that contain external graph files")
163163

164164
/** @group getParam */
165-
private def getGraphFolder: Option[String] = get(graphFolder)
165+
protected def getGraphFolder: Option[String] = get(graphFolder)
166166

167167
/** Extracts the graph hyperparameters from the training data (dataset).
168168
*
@@ -177,7 +177,7 @@ class NerDLGraphChecker(override val uid: String)
177177
* a tuple containing the number of labels, number of unique characters, and the embedding
178178
* dim
179179
*/
180-
private def getGraphParamsDs(
180+
protected def getGraphParamsDs(
181181
dataset: Dataset[_],
182182
inputCols: Array[String],
183183
labelsCol: String): (Int, Int, Int) = {
@@ -219,14 +219,16 @@ class NerDLGraphChecker(override val uid: String)
219219
(nLabels, nChars, embeddingsDim)
220220
}
221221

222+
protected def searchForSuitableGraph(nLabels: Int, nChars: Int, embeddingsDim: Int): String =
223+
NerDLApproach.searchForSuitableGraph(nLabels, embeddingsDim, nChars + 1, getGraphFolder)
224+
222225
override def fit(dataset: Dataset[_]): NerDLGraphCheckerModel = {
223226
val (nLabels, nChars, embeddingsDim) =
224227
getGraphParamsDs(dataset, $(inputCols), $(labelColumn))
225228

226229
// Throws exception if no suitable graph found
227230
Try {
228-
NerDLApproach
229-
.searchForSuitableGraph(nLabels, embeddingsDim, nChars + 1, getGraphFolder)
231+
searchForSuitableGraph(nLabels, nChars, embeddingsDim)
230232
} match {
231233
case Failure(exception: IllegalArgumentException) =>
232234
throw new IllegalArgumentException("NerDLGraphChecker: " + exception.getMessage)

0 commit comments

Comments
 (0)