Skip to content

Commit 7a2698f

Browse files
committed
[SPARKNLP-1096] Adding support to Microsoft Fabric for WordEmbeddings storage index
1 parent 6653546 commit 7a2698f

File tree

6 files changed

+147
-31
lines changed

6 files changed

+147
-31
lines changed

src/main/scala/com/johnsnowlabs/client/util/CloudHelper.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
*/
1616
package com.johnsnowlabs.client.util
1717

18-
import com.johnsnowlabs.nlp.util.io.CloudStorageType
1918
import com.johnsnowlabs.nlp.util.io.CloudStorageType.CloudStorageType
19+
import com.johnsnowlabs.nlp.util.io.{CloudStorageType, ResourceHelper}
2020

2121
import java.net.{URI, URL}
2222

@@ -71,7 +71,8 @@ object CloudHelper {
7171
}
7272

7373
def isCloudPath(uri: String): Boolean = {
74-
isS3Path(uri) || isGCPStoragePath(uri) || isAzureBlobPath(uri)
74+
val intraCloudPath = isIntraCloudPath(uri)
75+
(isS3Path(uri) || isGCPStoragePath(uri) || isAzureBlobPath(uri)) && !intraCloudPath
7576
}
7677

7778
def isS3Path(uri: String): Boolean = {
@@ -81,7 +82,16 @@ object CloudHelper {
8182
private def isGCPStoragePath(uri: String): Boolean = uri.startsWith("gs://")
8283

8384
private def isAzureBlobPath(uri: String): Boolean = {
84-
uri.startsWith("https://") && uri.contains(".blob.core.windows.net/")
85+
(uri.startsWith("https://") && uri.contains(".blob.core.windows.net/")) || uri.startsWith(
86+
"abfss://")
87+
}
88+
89+
private def isIntraCloudPath(uri: String): Boolean = {
90+
uri.startsWith("abfss://") && isMicrosoftFabric
91+
}
92+
93+
def isMicrosoftFabric: Boolean = {
94+
ResourceHelper.spark.conf.getAll.keys.exists(_.startsWith("spark.fabric"))
8595
}
8696

8797
def cloudType(uri: String): CloudStorageType = {

src/main/scala/com/johnsnowlabs/storage/RocksDBConnection.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@ final class RocksDBConnection private (path: String) extends AutoCloseable {
4343

4444
def findLocalIndex: String = {
4545
val tmpIndexStorageLocalPath = RocksDBConnection.getTmpIndexStorageLocalPath(path)
46-
if (new File(tmpIndexStorageLocalPath).exists()) {
46+
val tmpIndexStorageLocalPathExists = new File(tmpIndexStorageLocalPath).exists()
47+
val pathExist = new File(path.stripPrefix("file:")).exists()
48+
if (tmpIndexStorageLocalPathExists) {
4749
tmpIndexStorageLocalPath
48-
} else if (new File(path).exists()) {
50+
} else if (pathExist) {
4951
path
5052
} else {
5153
val localFromClusterPath = SparkFiles.get(path)

src/main/scala/com/johnsnowlabs/storage/StorageHelper.scala

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.johnsnowlabs.storage
1818

1919
import com.johnsnowlabs.client.CloudResources
20+
import com.johnsnowlabs.client.util.CloudHelper
2021
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
2122
import org.apache.spark.sql.SparkSession
2223
import org.apache.spark.{SparkContext, SparkFiles}
@@ -34,7 +35,6 @@ object StorageHelper {
3435
database: String,
3536
storageRef: String,
3637
withinStorage: Boolean): RocksDBConnection = {
37-
3838
val dbFolder = StorageHelper.resolveStorageName(database, storageRef)
3939
val source = StorageLocator.getStorageSerializedPath(
4040
storageSourcePath.replaceAllLiterally("\\", "/"),
@@ -49,7 +49,11 @@ object StorageHelper {
4949
locator.destinationScheme,
5050
spark.sparkContext)
5151

52-
RocksDBConnection.getOrCreate(locator.clusterFileName)
52+
val storagePath = if (locator.clusterFilePath.toString.startsWith("file:")) {
53+
locator.clusterFilePath.toString
54+
} else locator.clusterFileName
55+
56+
RocksDBConnection.getOrCreate(storagePath)
5357
}
5458

5559
def save(
@@ -96,9 +100,19 @@ object StorageHelper {
96100
}
97101
case "s3a" =>
98102
copyIndexToLocal(source, new Path(tmpIndexStorageLocalPath), sparkContext)
99-
case _ => copyIndexToCluster(source, clusterFilePath, sparkContext)
103+
case _ => {
104+
copyIndexToCluster(source, clusterFilePath, sparkContext)
105+
}
100106
}
101107
}
108+
case "abfss" =>
109+
if (clusterFilePath.toString.startsWith("file:")) {
110+
val tmpIndexStorageLocalPath =
111+
RocksDBConnection.getTmpIndexStorageLocalPath(clusterFileName)
112+
copyIndexToCluster(source, new Path("file://" + tmpIndexStorageLocalPath), sparkContext)
113+
} else {
114+
copyIndexToLocal(source, clusterFilePath, sparkContext)
115+
}
102116
case _ => {
103117
copyIndexToCluster(source, clusterFilePath, sparkContext)
104118
}
@@ -120,7 +134,8 @@ object StorageHelper {
120134
sourcePath: Path,
121135
dst: Path,
122136
sparkContext: SparkContext): String = {
123-
if (!new File(SparkFiles.get(dst.getName)).exists()) {
137+
val destinationInSpark = new File(SparkFiles.get(dst.getName)).exists()
138+
if (!destinationInSpark) {
124139
val srcFS = sourcePath.getFileSystem(sparkContext.hadoopConfiguration)
125140
val dstFS = dst.getFileSystem(sparkContext.hadoopConfiguration)
126141

@@ -138,7 +153,9 @@ object StorageHelper {
138153
sparkContext.hadoopConfiguration)
139154
}
140155

141-
sparkContext.addFile(dst.toString, recursive = true)
156+
if (!CloudHelper.isMicrosoftFabric) {
157+
sparkContext.addFile(dst.toString, recursive = true)
158+
}
142159
}
143160
dst.toString
144161
}

src/main/scala/com/johnsnowlabs/storage/StorageLocator.scala

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,29 +29,47 @@ case class StorageLocator(database: String, storageRef: String, sparkSession: Sp
2929
if (tmpLocation.matches("s3[a]?:/.*")) {
3030
tmpLocation
3131
} else {
32-
val tmpLocationPath = new Path(tmpLocation)
33-
fileSystem.mkdirs(tmpLocationPath)
34-
fileSystem.deleteOnExit(tmpLocationPath)
35-
tmpLocation
32+
fileSystem.getScheme match {
33+
case "abfss" =>
34+
if (tmpLocation.startsWith("abfss:")) {
35+
tmpLocation
36+
} else {
37+
"file:///" + tmpLocation
38+
}
39+
case _ =>
40+
val tmpLocationPath = new Path(tmpLocation)
41+
fileSystem.mkdirs(tmpLocationPath)
42+
fileSystem.deleteOnExit(tmpLocationPath)
43+
tmpLocation
44+
}
3645
}
3746
}
3847

39-
val clusterFileName: String = {
40-
StorageHelper.resolveStorageName(database, storageRef)
41-
}
48+
val clusterFileName: String = { StorageHelper.resolveStorageName(database, storageRef) }
4249

4350
val clusterFilePath: Path = {
4451
if (!getTmpLocation.matches("s3[a]?:/.*")) {
4552
val scheme = Option(new Path(clusterTmpLocation).toUri.getScheme).getOrElse("")
4653
scheme match {
47-
case "dbfs" | "hdfs" =>
48-
Path.mergePaths(new Path(clusterTmpLocation), new Path("/" + clusterFileName))
49-
case _ =>
50-
Path.mergePaths(
51-
new Path(fileSystem.getUri.toString + clusterTmpLocation),
52-
new Path("/" + clusterFileName))
54+
case "dbfs" | "hdfs" => mergePaths()
55+
case "file" =>
56+
val uri = fileSystem.getUri.toString
57+
if (uri.startsWith("abfss:")) { mergePaths() }
58+
else { mergePaths(withFileSystem = true) }
59+
case "abfss" => mergePaths()
60+
case _ => mergePaths(withFileSystem = true)
5361
}
54-
} else new Path(clusterTmpLocation + "/" + clusterFileName)
62+
} else {
63+
new Path(clusterTmpLocation + "/" + clusterFileName)
64+
}
65+
}
66+
67+
private def mergePaths(withFileSystem: Boolean = false): Path = {
68+
if (withFileSystem) {
69+
Path.mergePaths(
70+
new Path(fileSystem.getUri.toString + clusterTmpLocation),
71+
new Path("/" + clusterFileName))
72+
} else Path.mergePaths(new Path(clusterTmpLocation), new Path("/" + clusterFileName))
5573
}
5674

5775
val destinationScheme: String = fileSystem.getScheme

src/main/scala/com/johnsnowlabs/util/Version.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ object Version {
4848
def parse(str: String): Version = {
4949
val parts = str
5050
.replaceAll("-rc\\d", "")
51-
.split('.')
51+
.split("[.-]")
5252
.takeWhile(p => isInteger(p))
5353
.map(p => p.toInt)
54+
.take(3)
5455
.toList
5556

5657
Version(parts)

src/test/scala/com/johnsnowlabs/nlp/util/VersionTest.scala

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616

1717
package com.johnsnowlabs.nlp.util
1818

19+
import com.johnsnowlabs.tags.FastTest
1920
import com.johnsnowlabs.util.Version
2021
import org.junit.Assert.{assertFalse, assertTrue}
2122
import org.scalatest.flatspec.AnyFlatSpec
2223

2324
class VersionTest extends AnyFlatSpec {
2425

25-
"Version" should "cast to float version of 1 digit" in {
26+
"Version" should "cast to float version of 1 digit" taggedAs FastTest in {
2627

2728
val actualVersion1 = Version(1).toFloat
2829
val actualVersion15 = Version(15).toFloat
@@ -32,15 +33,15 @@ class VersionTest extends AnyFlatSpec {
3233

3334
}
3435

35-
it should "cast to float version of 2 digits" in {
36+
it should "cast to float version of 2 digits" taggedAs FastTest in {
3637
val actualVersion1_2 = Version(List(1, 2)).toFloat
3738
val actualVersion2_7 = Version(List(2, 7)).toFloat
3839

3940
assert(actualVersion1_2 == 1.2f)
4041
assert(actualVersion2_7 == 2.7f)
4142
}
4243

43-
it should "cast to float version of 3 digits" in {
44+
it should "cast to float version of 3 digits" taggedAs FastTest in {
4445
val actualVersion1_2_5 = Version(List(1, 2, 5)).toFloat
4546
val actualVersion3_2_0 = Version(List(3, 2, 0)).toFloat
4647
val actualVersion2_0_6 = Version(List(2, 0, 6)).toFloat
@@ -50,13 +51,13 @@ class VersionTest extends AnyFlatSpec {
5051
assert(actualVersion2_0_6 == 2.06f)
5152
}
5253

53-
it should "raise error when casting to float version > 3 digits" in {
54+
it should "raise error when casting to float version > 3 digits" taggedAs FastTest in {
5455
assertThrows[UnsupportedOperationException] {
5556
Version(List(3, 0, 2, 5)).toFloat
5657
}
5758
}
5859

59-
it should "be compatible for latest versions" in {
60+
it should "be compatible for latest versions" taggedAs FastTest in {
6061
var currentVersion = Version(List(1, 2, 3))
6162
var modelVersion = Version(List(1, 2))
6263

@@ -80,7 +81,7 @@ class VersionTest extends AnyFlatSpec {
8081

8182
}
8283

83-
it should "be not compatible for latest versions" in {
84+
it should "be not compatible for latest versions" taggedAs FastTest in {
8485
var currentVersion = Version(List(1, 2))
8586
var modelVersion = Version(List(1, 2, 3))
8687

@@ -103,4 +104,71 @@ class VersionTest extends AnyFlatSpec {
103104
assertFalse(isNotCompatible)
104105
}
105106

107+
it should "parse a version with fewer than 3 numbers" taggedAs FastTest in {
108+
val someVersion = "3.2"
109+
val expectedVersion = "3.2"
110+
val expectedFloatVersion = 3.2f
111+
val actualVersion = Version.parse(someVersion)
112+
113+
assert(expectedVersion == actualVersion.toString)
114+
assert(expectedFloatVersion == actualVersion.toFloat)
115+
}
116+
117+
it should "parse a version with 3 numbers" taggedAs FastTest in {
118+
val someVersion = "3.4.2"
119+
val expectedFloatVersion = 3.42f
120+
val actualVersion = Version.parse(someVersion)
121+
122+
assert(someVersion == actualVersion.toString)
123+
assert(expectedFloatVersion == actualVersion.toFloat)
124+
}
125+
126+
it should "truncate a version to 3 digits when it has more than 3 digits" taggedAs FastTest in {
127+
val someVersion = "3.5.1.5.4.20241007.4"
128+
val expectedVersion = "3.5.1"
129+
val expectedFloatVersion = 3.51f
130+
val actualVersion = Version.parse(someVersion)
131+
132+
assert(expectedVersion == actualVersion.toString)
133+
assert(expectedFloatVersion == actualVersion.toFloat)
134+
}
135+
136+
it should "handle a version with missing parts" taggedAs FastTest in {
137+
val someVersion = "3"
138+
val expectedVersion = "3"
139+
val expectedFloatVersion = 3.0f
140+
val actualVersion = Version.parse(someVersion)
141+
142+
assert(expectedVersion == actualVersion.toString)
143+
assert(expectedFloatVersion == actualVersion.toFloat)
144+
}
145+
146+
it should "handle a version with 3 digits and additional suffix" taggedAs FastTest in {
147+
val someVersion = "3.4.2-beta"
148+
val expectedVersion = "3.4.2"
149+
val expectedFloatVersion = 3.42f
150+
val actualVersion = Version.parse(someVersion)
151+
152+
assert(expectedVersion == actualVersion.toString)
153+
assert(expectedFloatVersion == actualVersion.toFloat)
154+
}
155+
156+
it should "raise exception with non-numeric and no valid parts" taggedAs FastTest in {
157+
val someVersion = "alpha.beta.gamma"
158+
159+
assertThrows[UnsupportedOperationException] {
160+
Version.parse(someVersion).toFloat
161+
}
162+
}
163+
164+
it should "handle a version with mixed numeric and non-numeric parts" taggedAs FastTest in {
165+
val someVersion = "3.4-alpha.2"
166+
val expectedVersion = "3.4"
167+
val expectedFloatVersion = 3.4f
168+
val actualVersion = Version.parse(someVersion)
169+
170+
assert(expectedVersion == actualVersion.toString)
171+
assert(expectedFloatVersion == actualVersion.toFloat)
172+
}
173+
106174
}

0 commit comments

Comments
 (0)