@@ -162,7 +162,7 @@ class NerDLGraphChecker(override val uid: String)
162
162
new Param [String ](this , " graphFolder" , " Folder path that contain external graph files" )
163
163
164
164
/** @group getParam */
165
- private def getGraphFolder : Option [String ] = get(graphFolder)
165
+ protected def getGraphFolder : Option [String ] = get(graphFolder)
166
166
167
167
/** Extracts the graph hyperparameters from the training data (dataset).
168
168
*
@@ -177,7 +177,7 @@ class NerDLGraphChecker(override val uid: String)
177
177
* a tuple containing the number of labels, number of unique characters, and the embedding
178
178
* dim
179
179
*/
180
- private def getGraphParamsDs (
180
+ protected def getGraphParamsDs (
181
181
dataset : Dataset [_],
182
182
inputCols : Array [String ],
183
183
labelsCol : String ): (Int , Int , Int ) = {
@@ -219,14 +219,16 @@ class NerDLGraphChecker(override val uid: String)
219
219
(nLabels, nChars, embeddingsDim)
220
220
}
221
221
222
+ protected def searchForSuitableGraph (nLabels : Int , nChars : Int , embeddingsDim : Int ): String =
223
+ NerDLApproach .searchForSuitableGraph(nLabels, embeddingsDim, nChars + 1 , getGraphFolder)
224
+
222
225
override def fit (dataset : Dataset [_]): NerDLGraphCheckerModel = {
223
226
val (nLabels, nChars, embeddingsDim) =
224
227
getGraphParamsDs(dataset, $(inputCols), $(labelColumn))
225
228
226
229
// Throws exception if no suitable graph found
227
230
Try {
228
- NerDLApproach
229
- .searchForSuitableGraph(nLabels, embeddingsDim, nChars + 1 , getGraphFolder)
231
+ searchForSuitableGraph(nLabels, nChars, embeddingsDim)
230
232
} match {
231
233
case Failure (exception : IllegalArgumentException ) =>
232
234
throw new IllegalArgumentException (" NerDLGraphChecker: " + exception.getMessage)
0 commit comments