Skip to content

Commit 22ed751

Browse files
committed
Change pretrained AutoGGUFReranking model
1 parent a4abd0f commit 22ed751

File tree

10 files changed

+122
-97
lines changed

10 files changed

+122
-97
lines changed

docs/en/annotator_entries/AutoGGUFReranker.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ val reranker = AutoGGUFReranker.pretrained()
3333
.setQuery("A man is eating pasta.")
3434
```
3535

36-
The default model is `"bge-reranker-v2-m3-Q4_K_M"`, if no name is provided.
36+
The default model is `"bge_reranker_v2_m3_Q4_K_M"`, if no name is provided.
3737

3838
For available pretrained models please see the [Models Hub](https://sparknlp.org/models).
3939

@@ -105,7 +105,7 @@ val document = new DocumentAssembler()
105105
.setOutputCol("document")
106106

107107
val reranker = AutoGGUFReranker
108-
.pretrained("bge-reranker-v2-m3-Q4_K_M")
108+
.pretrained()
109109
.setInputCols("document")
110110
.setOutputCol("reranked_documents")
111111
.setBatchSize(4)

docs/en/annotator_entries/GGUFRankingFinisher.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ val documentAssembler = new DocumentAssembler()
8585

8686
// Reranker
8787
val reranker = AutoGGUFReranker
88-
.pretrained("bge-reranker-v2-m3-Q4_K_M")
88+
.pretrained()
8989
.setInputCols("document")
9090
.setOutputCol("reranked_documents")
9191
.setQuery("A man is eating pasta.")

examples/python/llama.cpp/GGUFRankingFinisher_for_AutoGGUFReranker.ipynb

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,7 @@
136136
"document_assembler = DocumentAssembler().setInputCol(\"text\").setOutputCol(\"document\")\n",
137137
"\n",
138138
"auto_gguf_model = (\n",
139-
" AutoGGUFReranker.loadSavedModel(\n",
140-
" \"/home/ducha/Workspace/scala/spark-nlp-release/tmp_autogguf_reranker/bge-reranker-v2-m3-q4_k_m.gguf\",\n",
141-
" spark,\n",
142-
" )\n",
139+
" AutoGGUFReranker.pretrained()\n",
143140
" .setInputCols(\"document\")\n",
144141
" .setOutputCol(\"reranked_documents\")\n",
145142
" .setQuery(\"A man is eating pasta.\")\n",

python/sparknlp/annotator/seq2seq/auto_gguf_reranker.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class AutoGGUFReranker(AnnotatorModel, HasBatchedAnnotate, HasLlamaCppProperties
4747
... .setOutputCol("reranked_documents") \\
4848
... .setQuery("A man is eating pasta.")
4949
50-
The default model is ``"bge-reranker-v2-m3-Q4_K_M"``, if no name is provided.
50+
The default model is ``"bge_reranker_v2_m3_Q4_K_M"``, if no name is provided.
5151
5252
For extended examples of usage, see the
5353
`AutoGGUFRerankerTest <https://github.com/JohnSnowLabs/spark-nlp/tree/master/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFRerankerTest.scala>`__
@@ -222,7 +222,7 @@ class AutoGGUFReranker(AnnotatorModel, HasBatchedAnnotate, HasLlamaCppProperties
222222
>>> document = DocumentAssembler() \\
223223
... .setInputCol("text") \\
224224
... .setOutputCol("document")
225-
>>> reranker = AutoGGUFReranker.pretrained("bge-reranker-v2-m3-Q4_K_M") \\
225+
>>> reranker = AutoGGUFReranker.pretrained() \\
226226
... .setInputCols(["document"]) \\
227227
... .setOutputCol("reranked_documents") \\
228228
... .setBatchSize(4) \\
@@ -307,13 +307,13 @@ def loadSavedModel(folder, spark_session):
307307
return AutoGGUFReranker(java_model=jModel)
308308

309309
@staticmethod
310-
def pretrained(name="bge-reranker-v2-m3-Q4_K_M", lang="en", remote_loc=None):
310+
def pretrained(name="bge_reranker_v2_m3_Q4_K_M", lang="en", remote_loc=None):
311311
"""Downloads and loads a pretrained model.
312312
313313
Parameters
314314
----------
315315
name : str, optional
316-
Name of the pretrained model, by default "bge-reranker-v2-m3-Q4_K_M"
316+
Name of the pretrained model, by default "bge_reranker_v2_m3_Q4_K_M"
317317
lang : str, optional
318318
Language of the pretrained model, by default "en"
319319
remote_loc : str, optional

python/sparknlp/base/gguf_ranking_finisher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class GGUFRankingFinisher(AnnotatorTransformer):
6565
>>> documentAssembler = DocumentAssembler() \\
6666
... .setInputCol("text") \\
6767
... .setOutputCol("document")
68-
>>> reranker = AutoGGUFReranker.pretrained("bge-reranker-v2-m3-Q4_K_M") \\
68+
>>> reranker = AutoGGUFReranker.pretrained() \\
6969
... .setInputCols("document") \\
7070
... .setOutputCol("reranked_documents") \\
7171
... .setQuery("A man is eating pasta.")

python/test/annotator/seq2seq/auto_gguf_reranker_test.py

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ def runTest(self):
4747

4848
# Use a local model path for testing - in real scenarios, use pretrained()
4949
model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf"
50-
50+
5151
# Skip test if model file doesn't exist
5252
import os
53+
5354
if not os.path.exists(model_path):
5455
self.skipTest(f"Model file not found: {model_path}")
5556

@@ -104,33 +105,29 @@ def runTest(self):
104105
DocumentAssembler().setInputCol("text").setOutputCol("document")
105106
)
106107

107-
# Test with pretrained model (may not be available in test environment)
108-
try:
109-
reranker = (
110-
AutoGGUFReranker.pretrained("bge-reranker-v2-m3-Q4_K_M")
111-
.setInputCols("document")
112-
.setOutputCol("reranked_documents")
113-
.setBatchSize(2)
114-
.setQuery(self.query)
115-
)
108+
reranker = (
109+
AutoGGUFReranker.pretrained()
110+
.setInputCols("document")
111+
.setOutputCol("reranked_documents")
112+
.setBatchSize(2)
113+
.setQuery(self.query)
114+
)
116115

117-
pipeline = Pipeline().setStages([document_assembler, reranker])
118-
results = pipeline.fit(self.data).transform(self.data)
116+
pipeline = Pipeline().setStages([document_assembler, reranker])
117+
results = pipeline.fit(self.data).transform(self.data)
119118

120-
# Verify results contain relevance scores
121-
collected_results = results.collect()
122-
for row in collected_results:
123-
annotations = row["reranked_documents"]
124-
for annotation in annotations:
125-
self.assertIn("relevance_score", annotation.metadata)
126-
# Relevance score should be a valid number
127-
score = float(annotation.metadata["relevance_score"])
128-
self.assertIsInstance(score, float)
119+
# Verify results contain relevance scores
120+
collected_results = results.collect()
121+
for row in collected_results:
122+
annotations = row["reranked_documents"]
123+
for annotation in annotations:
124+
self.assertIn("relevance_score", annotation.metadata)
125+
# Relevance score should be a valid number
126+
score = float(annotation.metadata["relevance_score"])
127+
self.assertIsInstance(score, float)
128+
129+
results.show()
129130

130-
results.show()
131-
except Exception as e:
132-
# Skip if pretrained model is not available
133-
self.skipTest(f"Pretrained model not available: {str(e)}")
134131

135132
@pytest.mark.slow
136133
class AutoGGUFRerankerMetadataTestSpec(unittest.TestCase):
@@ -139,9 +136,10 @@ def setUp(self):
139136

140137
def runTest(self):
141138
model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf"
142-
139+
143140
# Skip test if model file doesn't exist
144141
import os
142+
145143
if not os.path.exists(model_path):
146144
self.skipTest(f"Model file not found: {model_path}")
147145

@@ -150,10 +148,11 @@ def runTest(self):
150148
metadata = reranker.getMetadata()
151149
self.assertIsNotNone(metadata)
152150
self.assertGreater(len(metadata), 0)
153-
151+
154152
print("Model metadata:")
155153
print(eval(metadata))
156154

155+
157156
#
158157
# @pytest.mark.slow
159158
# class AutoGGUFRerankerSerializationTestSpec(unittest.TestCase):
@@ -215,7 +214,7 @@ def runTest(self):
215214
# results.select("reranked_documents").show(truncate=False)
216215

217216

218-
@pytest.mark.slow
217+
@pytest.mark.slow
219218
class AutoGGUFRerankerErrorHandlingTestSpec(unittest.TestCase):
220219
def setUp(self):
221220
self.spark = SparkContextForTest.spark
@@ -229,9 +228,10 @@ def runTest(self):
229228
data = self.spark.createDataFrame([["Test document"]]).toDF("text")
230229

231230
model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf"
232-
231+
233232
# Skip test if model file doesn't exist
234233
import os
234+
235235
if not os.path.exists(model_path):
236236
self.skipTest(f"Model file not found: {model_path}")
237237

@@ -244,7 +244,7 @@ def runTest(self):
244244
)
245245

246246
pipeline = Pipeline().setStages([document_assembler, reranker])
247-
247+
248248
# This should still work with empty query (based on implementation)
249249
try:
250250
results = pipeline.fit(data).transform(data)
@@ -279,9 +279,10 @@ def runTest(self):
279279
)
280280

281281
model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf"
282-
282+
283283
# Skip test if model file doesn't exist
284284
import os
285+
285286
if not os.path.exists(model_path):
286287
self.skipTest(f"Model file not found: {model_path}")
287288

@@ -322,11 +323,11 @@ def runTest(self):
322323
self.assertIn("rank", annotation.metadata)
323324
self.assertIn("query", annotation.metadata)
324325
self.assertEqual(annotation.metadata["query"], self.query)
325-
326+
326327
# Check that relevance score is normalized (due to minMaxScaling)
327328
score = float(annotation.metadata["relevance_score"])
328329
self.assertTrue(0.0 <= score <= 1.0)
329-
330+
330331
# Check that rank is a valid integer
331332
rank = int(annotation.metadata["rank"])
332333
self.assertIsInstance(rank, int)
@@ -338,7 +339,9 @@ def runTest(self):
338339
ranks = [int(annotation.metadata["rank"]) for annotation in annotations]
339340
self.assertEqual(ranks, sorted(ranks))
340341

341-
print("Pipeline with AutoGGUFReranker and GGUFRankingFinisher completed successfully")
342+
print(
343+
"Pipeline with AutoGGUFReranker and GGUFRankingFinisher completed successfully"
344+
)
342345
results.select("ranked_documents").show(truncate=False)
343346

344347

@@ -368,9 +371,10 @@ def runTest(self):
368371
)
369372

370373
model_path = "/tmp/bge-reranker-v2-m3-Q4_K_M.gguf"
371-
374+
372375
# Skip test if model file doesn't exist
373376
import os
377+
374378
if not os.path.exists(model_path):
375379
self.skipTest(f"Model file not found: {model_path}")
376380

@@ -396,7 +400,7 @@ def runTest(self):
396400
results = pipeline.fit(self.data).transform(self.data)
397401

398402
collected_results = results.collect()
399-
403+
400404
# Should have at most 2 results due to topK
401405
self.assertLessEqual(len(collected_results), 2)
402406

@@ -407,7 +411,7 @@ def runTest(self):
407411
# Check normalized scores are >= 0.1 threshold
408412
score = float(annotation.metadata["relevance_score"])
409413
self.assertTrue(0.1 <= score <= 1.0)
410-
414+
411415
# Check rank metadata exists
412416
self.assertIn("rank", annotation.metadata)
413417
rank = int(annotation.metadata["rank"])

0 commit comments

Comments
 (0)