Skip to content

Commit 93f383e

Browse files
DevinTDH authored and prabod committed
Fix pretrained models not being found on dbfs systems (#14438)
1 parent 12c1014 commit 93f383e

File tree

2 files changed

+32
-20
lines changed

2 files changed

+32
-20
lines changed

src/main/scala/com/johnsnowlabs/ml/gguf/GGUFWrapper.scala

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
package com.johnsnowlabs.ml.gguf
1717

1818
import com.johnsnowlabs.nlp.llama.{LlamaModel, ModelParameters}
19+
import com.johnsnowlabs.nlp.util.io.ResourceHelper
20+
import org.apache.hadoop.fs.{FileSystem, Path}
1921
import org.apache.spark.SparkFiles
2022
import org.apache.spark.sql.SparkSession
2123
import org.slf4j.{Logger, LoggerFactory}
@@ -72,7 +74,7 @@ object GGUFWrapper {
7274
// TODO: make sure this.synchronized is needed or it's not a bottleneck
7375
private def withSafeGGUFModelLoader(modelParameters: ModelParameters): LlamaModel =
7476
this.synchronized {
75-
new LlamaModel(modelParameters) // TODO: Model parameters
77+
new LlamaModel(modelParameters)
7678
}
7779

7880
def read(sparkSession: SparkSession, modelPath: String): GGUFWrapper = {
@@ -89,4 +91,31 @@ object GGUFWrapper {
8991

9092
new GGUFWrapper(modelFile.getName, modelFile.getParent)
9193
}
94+
95+
/** Loads a GGUF model from a folder that may reside on a non-local file system.
  *
  * The folder is resolved against the Hadoop file system that owns it, copied to the local disk
  * and then searched for a `.gguf` file, which is handed to [[read]].
  *
  * @param modelFolderPath
  *   Folder containing the GGUF model file. May omit the scheme, in which case the default
  *   Hadoop file system of the Spark session is used (e.g. dbfs)
  * @param spark
  *   Spark session providing the Hadoop configuration
  * @return
  *   Wrapper around the GGUF model found in the folder
  * @throws IllegalArgumentException
  *   if the local copy is not a directory or contains no `.gguf` file
  */
def readModel(modelFolderPath: String, spark: SparkSession): GGUFWrapper = {
  def findGGUFModelInFolder(folderPath: String): String = {
    val folder = new File(folderPath)
    // isDirectory is false for non-existent paths, so a separate exists check is not needed
    if (!folder.isDirectory)
      throw new IllegalArgumentException(s"Path $folderPath is not a directory")

    // listFiles returns null (not an empty array) when the directory can not be read
    Option(folder.listFiles)
      .flatMap(_.find(file => file.isFile && file.getName.endsWith(".gguf"))) // Should only be one file
      .map(_.getAbsolutePath)
      .getOrElse(
        throw new IllegalArgumentException(s"Could not find GGUF model in $folderPath"))
  }

  // Hadoop's Path escapes characters such as spaces, which java.net.URI(String) rejects.
  // The same normalized path is used for the file system lookup and for resolving.
  val folderPath = new Path(modelFolderPath.replace("\\", "/"))
  // In case the path belongs to a different file system but doesn't have the scheme prepended (e.g. dbfs)
  val fileSystem: FileSystem = folderPath.getFileSystem(spark.sparkContext.hadoopConfiguration)
  val actualFolderPath = fileSystem.resolvePath(folderPath).toString
  val localFolder = ResourceHelper.copyToLocal(actualFolderPath)
  val modelFile = findGGUFModelInFolder(localFolder)
  read(spark, modelFile)
}
92121
}

src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFModel.scala

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -235,25 +235,8 @@ trait ReadAutoGGUFModel {
235235
this: ParamsAndFeaturesReadable[AutoGGUFModel] =>
236236

237237
def readModel(instance: AutoGGUFModel, path: String, spark: SparkSession): Unit = {
238-
def findGGUFModelInFolder(): String = {
239-
val folder =
240-
new java.io.File(
241-
path.replace("file:", "")
242-
) // File should be local at this point. TODO: Except if its HDFS?
243-
if (folder.exists && folder.isDirectory) {
244-
folder.listFiles
245-
.filter(_.isFile)
246-
.filter(_.getName.endsWith(".gguf"))
247-
.map(_.getAbsolutePath)
248-
.headOption // Should only be one file
249-
.getOrElse(throw new IllegalArgumentException(s"Could not find GGUF model in $path"))
250-
} else {
251-
throw new IllegalArgumentException(s"Path $path is not a directory")
252-
}
253-
}
254-
255-
val model = AutoGGUFModel.loadSavedModel(findGGUFModelInFolder(), spark)
256-
instance.setModelIfNotSet(spark, model.getModelIfNotSet)
238+
val model: GGUFWrapper = GGUFWrapper.readModel(path, spark)
239+
instance.setModelIfNotSet(spark, model)
257240
}
258241

259242
addReader(readModel)

0 commit comments

Comments
 (0)