1616package com .johnsnowlabs .ml .gguf
1717
1818import com .johnsnowlabs .nlp .llama .{LlamaModel , ModelParameters }
19+ import com .johnsnowlabs .nlp .util .io .ResourceHelper
20+ import org .apache .hadoop .fs .{FileSystem , Path }
1921import org .apache .spark .SparkFiles
2022import org .apache .spark .sql .SparkSession
2123import org .slf4j .{Logger , LoggerFactory }
@@ -72,7 +74,7 @@ object GGUFWrapper {
7274 // TODO: make sure this.synchronized is needed or it's not a bottleneck
7375 private def withSafeGGUFModelLoader (modelParameters : ModelParameters ): LlamaModel =
7476 this .synchronized {
75- new LlamaModel (modelParameters) // TODO: Model parameters
77+ new LlamaModel (modelParameters)
7678 }
7779
7880 def read (sparkSession : SparkSession , modelPath : String ): GGUFWrapper = {
@@ -89,4 +91,31 @@ object GGUFWrapper {
8991
9092 new GGUFWrapper (modelFile.getName, modelFile.getParent)
9193 }
94+
95+ def readModel (modelFolderPath : String , spark : SparkSession ): GGUFWrapper = {
96+ def findGGUFModelInFolder (folderPath : String ): String = {
97+ val folder = new File (folderPath)
98+ if (folder.exists && folder.isDirectory) {
99+ val ggufFile : String = folder.listFiles
100+ .filter(_.isFile)
101+ .filter(_.getName.endsWith(" .gguf" ))
102+ .map(_.getAbsolutePath)
103+ .headOption // Should only be one file
104+ .getOrElse(
105+ throw new IllegalArgumentException (s " Could not find GGUF model in $folderPath" ))
106+
107+ new File (ggufFile).getAbsolutePath
108+ } else {
109+ throw new IllegalArgumentException (s " Path $folderPath is not a directory " )
110+ }
111+ }
112+
113+ val uri = new java.net.URI (modelFolderPath.replaceAllLiterally(" \\ " , " /" ))
114+ // In case the path belongs to a different file system but doesn't have the scheme prepended (e.g. dbfs)
115+ val fileSystem : FileSystem = FileSystem .get(uri, spark.sparkContext.hadoopConfiguration)
116+ val actualFolderPath = fileSystem.resolvePath(new Path (modelFolderPath)).toString
117+ val localFolder = ResourceHelper .copyToLocal(actualFolderPath)
118+ val modelFile = findGGUFModelInFolder(localFolder)
119+ read(spark, modelFile)
120+ }
92121}
0 commit comments