@@ -47,9 +47,10 @@ def runTest(self):

         # Use a local model path for testing - in real scenarios, use pretrained()
         model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf"
-
+
         # Skip test if model file doesn't exist
         import os
+
         if not os.path.exists(model_path):
             self.skipTest(f"Model file not found: {model_path}")

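The skip-if-missing guard above recurs in every test class in this file. unittest's built-in skip decorator can express the same guard once per class; a minimal sketch, with a hypothetical class name, assuming it is acceptable to evaluate the path check at class-definition time:

import os
import unittest

MODEL_PATH = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf"

# skipUnless evaluates its condition once, when the class is defined,
# so this assumes the model file does not appear or vanish mid-run.
@unittest.skipUnless(os.path.exists(MODEL_PATH), f"Model file not found: {MODEL_PATH}")
class LocalModelTestSpec(unittest.TestCase):
    def runTest(self):
        ...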
@@ -104,33 +105,29 @@ def runTest(self):
             DocumentAssembler().setInputCol("text").setOutputCol("document")
         )

-        # Test with pretrained model (may not be available in test environment)
-        try:
-            reranker = (
-                AutoGGUFReranker.pretrained("bge-reranker-v2-m3-Q4_K_M")
-                .setInputCols("document")
-                .setOutputCol("reranked_documents")
-                .setBatchSize(2)
-                .setQuery(self.query)
-            )
+        reranker = (
+            AutoGGUFReranker.pretrained()
+            .setInputCols("document")
+            .setOutputCol("reranked_documents")
+            .setBatchSize(2)
+            .setQuery(self.query)
+        )

-            pipeline = Pipeline().setStages([document_assembler, reranker])
-            results = pipeline.fit(self.data).transform(self.data)
+        pipeline = Pipeline().setStages([document_assembler, reranker])
+        results = pipeline.fit(self.data).transform(self.data)

-            # Verify results contain relevance scores
-            collected_results = results.collect()
-            for row in collected_results:
-                annotations = row["reranked_documents"]
-                for annotation in annotations:
-                    self.assertIn("relevance_score", annotation.metadata)
-                    # Relevance score should be a valid number
-                    score = float(annotation.metadata["relevance_score"])
-                    self.assertIsInstance(score, float)
+        # Verify results contain relevance scores
+        collected_results = results.collect()
+        for row in collected_results:
+            annotations = row["reranked_documents"]
+            for annotation in annotations:
+                self.assertIn("relevance_score", annotation.metadata)
+                # Relevance score should be a valid number
+                score = float(annotation.metadata["relevance_score"])
+                self.assertIsInstance(score, float)
+
+        results.show()

-            results.show()
-        except Exception as e:
-            # Skip if pretrained model is not available
-            self.skipTest(f"Pretrained model not available: {str(e)}")

 @pytest.mark.slow
 class AutoGGUFRerankerMetadataTestSpec(unittest.TestCase):
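One note on the verification loop kept by this hunk: `float(raw)` either raises `ValueError` or returns a `float`, so the `assertIsInstance(score, float)` check can never fail on its own. A stricter variant asserts that the parsed score is a finite number; a sketch, using a hypothetical helper name:

import math

def assert_finite_score(test_case, annotation):
    # float() already guarantees the type (or raises), so check finiteness
    # instead, which also rejects NaN and +/-inf coming back from the model.
    score = float(annotation.metadata["relevance_score"])
    test_case.assertTrue(math.isfinite(score), f"non-finite relevance_score: {score}")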
@@ -139,9 +136,10 @@ def setUp(self):

     def runTest(self):
         model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf"
-
+
         # Skip test if model file doesn't exist
         import os
+
         if not os.path.exists(model_path):
             self.skipTest(f"Model file not found: {model_path}")

@@ -150,10 +148,11 @@ def runTest(self):
         metadata = reranker.getMetadata()
         self.assertIsNotNone(metadata)
         self.assertGreater(len(metadata), 0)
-
+
         print("Model metadata:")
         print(eval(metadata))

+
 #
 # @pytest.mark.slow
 # class AutoGGUFRerankerSerializationTestSpec(unittest.TestCase):
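`print(eval(metadata))` in this hunk executes the metadata string as code. Assuming `getMetadata()` returns a dict-style literal (which the `eval` call implies), `ast.literal_eval` parses the same string without the code-execution risk; a sketch with an illustrative value:

import ast

# Illustrative stand-in for the string returned by reranker.getMetadata().
metadata = '{"general.architecture": "bert", "general.name": "bge-reranker-v2-m3"}'
parsed = ast.literal_eval(metadata)  # raises ValueError on anything but a literal
print(parsed["general.name"])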
@@ -215,7 +214,7 @@ def runTest(self):
 # results.select("reranked_documents").show(truncate=False)


-@pytest.mark.slow
+@pytest.mark.slow
 class AutoGGUFRerankerErrorHandlingTestSpec(unittest.TestCase):
     def setUp(self):
         self.spark = SparkContextForTest.spark
@@ -229,9 +228,10 @@ def runTest(self):
         data = self.spark.createDataFrame([["Test document"]]).toDF("text")

         model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf"
-
+
         # Skip test if model file doesn't exist
         import os
+
         if not os.path.exists(model_path):
             self.skipTest(f"Model file not found: {model_path}")

@@ -244,7 +244,7 @@ def runTest(self):
         )

         pipeline = Pipeline().setStages([document_assembler, reranker])
-
+
         # This should still work with empty query (based on implementation)
         try:
             results = pipeline.fit(data).transform(data)
@@ -279,9 +279,10 @@ def runTest(self):
         )

         model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf"
-
+
         # Skip test if model file doesn't exist
         import os
+
         if not os.path.exists(model_path):
             self.skipTest(f"Model file not found: {model_path}")

@@ -322,11 +323,11 @@ def runTest(self):
                 self.assertIn("rank", annotation.metadata)
                 self.assertIn("query", annotation.metadata)
                 self.assertEqual(annotation.metadata["query"], self.query)
-
+
                 # Check that relevance score is normalized (due to minMaxScaling)
                 score = float(annotation.metadata["relevance_score"])
                 self.assertTrue(0.0 <= score <= 1.0)
-
+
                 # Check that rank is a valid integer
                 rank = int(annotation.metadata["rank"])
                 self.assertIsInstance(rank, int)
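The normalization asserted here follows from min-max scaling, which maps each raw score s to (s - min) / (max - min) over the batch, so every rescaled score necessarily lands in [0.0, 1.0]. A plain-Python sketch of the transformation (the finisher's handling of the all-equal-scores edge case is an assumption):

def min_max_scale(scores):
    lo, hi = min(scores), max(scores)
    if hi == lo:
        # Degenerate batch: every document scored identically (assumed behavior).
        return [1.0] * len(scores)
    return [(s - lo) / (hi - lo) for s in scores]

print(min_max_scale([-2.1, 0.4, 3.7]))  # [0.0, 0.431..., 1.0]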
@@ -338,7 +339,9 @@ def runTest(self):
         ranks = [int(annotation.metadata["rank"]) for annotation in annotations]
         self.assertEqual(ranks, sorted(ranks))

-        print("Pipeline with AutoGGUFReranker and GGUFRankingFinisher completed successfully")
+        print(
+            "Pipeline with AutoGGUFReranker and GGUFRankingFinisher completed successfully"
+        )
         results.select("ranked_documents").show(truncate=False)


@@ -368,9 +371,10 @@ def runTest(self):
         )

         model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf"
-
+
         # Skip test if model file doesn't exist
         import os
+
         if not os.path.exists(model_path):
             self.skipTest(f"Model file not found: {model_path}")

@@ -396,7 +400,7 @@ def runTest(self):
         results = pipeline.fit(self.data).transform(self.data)

         collected_results = results.collect()
-
+
         # Should have at most 2 results due to topK
         self.assertLessEqual(len(collected_results), 2)

@@ -407,7 +411,7 @@ def runTest(self):
                 # Check normalized scores are >= 0.1 threshold
                 score = float(annotation.metadata["relevance_score"])
                 self.assertTrue(0.1 <= score <= 1.0)
-
+
                 # Check rank metadata exists
                 self.assertIn("rank", annotation.metadata)
                 rank = int(annotation.metadata["rank"])
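The last two hunks pin down the contract of topK combined with a minimum-score threshold: results below the threshold are dropped, and at most k of the highest-scoring survivors are kept. A plain-Python sketch of that selection logic (the finisher itself lives on the Scala side; this only mirrors what the assertions check):

def top_k_above_threshold(scored_docs, k=2, threshold=0.1):
    kept = [(doc, score) for doc, score in scored_docs if score >= threshold]
    kept.sort(key=lambda pair: pair[1], reverse=True)
    return kept[:k]  # at most k results, all with score >= threshold

docs = [("a", 0.95), ("b", 0.05), ("c", 0.40), ("d", 0.72)]
print(top_k_above_threshold(docs))  # [('a', 0.95), ('d', 0.72)]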