-
Notifications
You must be signed in to change notification settings - Fork 541
Expand file tree
/
Copy pathtest_multi_model.py
More file actions
254 lines (202 loc) · 8.08 KB
/
test_multi_model.py
File metadata and controls
254 lines (202 loc) · 8.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""
Simple test script to verify multi-model functionality in mistral.rs
This script tests the core multi-model operations:
- Listing models
- Getting/setting default model
- Sending requests to specific models
- Model unloading/reloading
- Model removal (commented out for safety)
"""
from mistralrs import (
Runner,
Which,
ChatCompletionRequest,
Architecture,
)
import sys
def test_multi_model_operations():
    """Exercise the core multi-model API of a Runner: listing models,
    reading/changing the default model, targeted requests, and load status.

    Returns True when every check passes, False on any exception.
    """
    print("Testing Multi-Model Operations\n" + "=" * 50)
    try:
        # Spin up a runner backed by a single small model.
        print("1. Creating runner with GPT-2 model...")
        runner = Runner(
            which=Which.Plain(
                model_id="gpt2",
                arch=Architecture.Gpt2,
            )
        )
        print(" ✓ Runner created successfully")

        # Enumerate the currently registered models.
        print("\n2. Testing list_models()...")
        loaded = runner.list_models()
        print(f" ✓ Models found: {loaded}")
        assert isinstance(loaded, list), "list_models should return a list"
        assert len(loaded) > 0, "Should have at least one model"

        # Query which model handles requests when no model_id is given.
        print("\n3. Testing get_default_model_id()...")
        current_default = runner.get_default_model_id()
        print(f" ✓ Default model: {current_default}")

        if len(loaded) > 1:
            # Flip the default to a different model and confirm it stuck.
            print("\n4. Testing set_default_model_id()...")
            target = loaded[1] if loaded[0] == current_default else loaded[0]
            runner.set_default_model_id(target)
            after = runner.get_default_model_id()
            print(f" ✓ Changed default from '{current_default}' to '{after}'")
            assert after == target, "Default model should have changed"
        else:
            print("\n4. Skipping set_default_model_id() test (only one model loaded)")

        # Route a chat request to an explicitly named model.
        print("\n5. Testing send_chat_completion_request with model_id...")
        convo = [
            {"role": "user", "content": "Say 'test successful' and nothing else."}
        ]
        req = ChatCompletionRequest(messages=convo, max_tokens=10)
        if loaded:
            reply = runner.send_chat_completion_request(
                request=req, model_id=loaded[0]
            )
            print(f" ✓ Response received: {reply.choices[0].message.content}")

        # Inspect per-model load status.
        print("\n6. Testing list_models_with_status()...")
        status_list = runner.list_models_with_status()
        print(f" ✓ Models with status: {status_list}")
        assert isinstance(status_list, list), "Should return a list"

        # A freshly created model should report as loaded.
        print("\n7. Testing is_model_loaded()...")
        if loaded:
            flag = runner.is_model_loaded(loaded[0])
            print(f" ✓ Model '{loaded[0]}' loaded: {flag}")
            assert flag, "Model should be loaded initially"

        print("\n✅ All tests passed!")
    except Exception as e:
        print(f"\n❌ Test failed with error: {e}")
        import traceback

        traceback.print_exc()
        return False
    return True
def test_model_id_in_requests():
    """Verify that chat requests work both with an explicit model_id and
    without one (falling back to the runner's default model).

    Returns True on success, False when no model is available or on error.
    """
    print("\n\nTesting Model ID in Requests\n" + "=" * 50)
    try:
        runner = Runner(
            which=Which.Plain(
                model_id="gpt2",
                arch=Architecture.Gpt2,
            )
        )
        available = runner.list_models()
        if not available:
            print("No models available to test")
            return False
        chosen = available[0]
        print(f"Using model: {chosen}")

        prompt = [{"role": "user", "content": "Hi"}]

        # Explicitly target the chosen model.
        print("\n1. Testing chat completion with model_id...")
        req = ChatCompletionRequest(messages=prompt, max_tokens=5)
        reply = runner.send_chat_completion_request(req, model_id=chosen)
        print(f" ✓ Chat response: {reply.choices[0].message.content}")

        # Omit model_id so the runner routes to its default model.
        print("\n2. Testing chat completion without model_id...")
        reply = runner.send_chat_completion_request(req)
        print(f" ✓ Chat response (default): {reply.choices[0].message.content}")

        print("\n✅ Model ID request tests passed!")
        return True
    except Exception as e:
        print(f"\n❌ Test failed with error: {e}")
        import traceback

        traceback.print_exc()
        return False
def test_unload_reload():
    """Test model unloading and reloading.

    Asserts the model is loaded initially, unloaded after unload_model(),
    and loaded again after reload_model().

    Returns True on success, False when no model is available or on error.
    """
    print("\n\nTesting Model Unload/Reload\n" + "=" * 50)
    try:
        runner = Runner(
            which=Which.Plain(
                model_id="gpt2",
                arch=Architecture.Gpt2,
            )
        )
        models = runner.list_models()
        if not models:
            print("No models available to test")
            return False
        model_id = models[0]

        # Check initial status.
        print("1. Checking initial status...")
        assert runner.is_model_loaded(model_id), "Model should be loaded initially"
        print(f" ✓ Model '{model_id}' is loaded")

        # Unload the model.
        print("\n2. Unloading model...")
        runner.unload_model(model_id)
        is_loaded = runner.is_model_loaded(model_id)
        print(f" ✓ Model unloaded. is_model_loaded: {is_loaded}")
        # BUGFIX: previously the unload result was only printed, never
        # checked — a no-op unload would pass silently.
        assert not is_loaded, "Model should not be loaded after unload"

        # Check status after unload.
        print("\n3. Checking status after unload...")
        status = runner.list_models_with_status()
        print(f" ✓ Status: {status}")

        # Reload the model and confirm it is available again.
        print("\n4. Reloading model...")
        runner.reload_model(model_id)
        is_loaded = runner.is_model_loaded(model_id)
        print(f" ✓ Model reloaded. is_model_loaded: {is_loaded}")
        assert is_loaded, "Model should be loaded after reload"

        print("\n✅ Unload/reload tests passed!")
        return True
    except Exception as e:
        print(f"\n❌ Test failed with error: {e}")
        import traceback

        traceback.print_exc()
        return False
def test_error_handling():
    """Test error handling for multi-model operations.

    Confirms that targeting a non-existent model — in a request or as the
    default — raises an error.

    Returns True only when every expected error was actually raised.
    """
    print("\n\nTesting Error Handling\n" + "=" * 50)
    try:
        runner = Runner(
            which=Which.Plain(
                model_id="gpt2",
                arch=Architecture.Gpt2,
            )
        )
        # BUGFIX: previously the "should have raised" branches only
        # printed a ❌ and the function still returned True, so missing
        # errors never failed the suite. Track them explicitly.
        all_raised = True

        # Requesting an unknown model must raise.
        print("1. Testing request to non-existent model...")
        messages = [{"role": "user", "content": "Hi"}]
        request = ChatCompletionRequest(messages=messages)
        try:
            runner.send_chat_completion_request(
                request, model_id="non-existent-model"
            )
            print(" ❌ Should have raised an error for non-existent model")
            all_raised = False
        except Exception as e:
            print(f" ✓ Correctly raised error: {type(e).__name__}")

        # Setting an unknown model as default must raise.
        print("\n2. Testing set_default_model_id with non-existent model...")
        try:
            runner.set_default_model_id("non-existent-model")
            print(" ❌ Should have raised an error")
            all_raised = False
        except Exception as e:
            print(f" ✓ Correctly raised error: {type(e).__name__}")

        if all_raised:
            print("\n✅ Error handling tests passed!")
        return all_raised
    except Exception as e:
        print(f"\n❌ Test setup failed with error: {e}")
        return False
if __name__ == "__main__":
print("mistral.rs Multi-Model Test Suite")
print("=" * 60)
all_passed = True
# Run tests
all_passed &= test_multi_model_operations()
all_passed &= test_model_id_in_requests()
all_passed &= test_unload_reload()
all_passed &= test_error_handling()
print("\n" + "=" * 60)
if all_passed:
print("✅ All tests passed!")
sys.exit(0)
else:
print("❌ Some tests failed!")
sys.exit(1)