
Commit a855e66

Merge pull request #1848 from better629/main
update test_bedrock_api

2 parents ba28689 + 38d6d81

2 files changed: +83 −10 lines changed

tests/metagpt/provider/req_resp_const.py

52 additions & 5 deletions

@@ -191,7 +191,7 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
 BEDROCK_PROVIDER_REQUEST_BODY = {
     "mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0},
     "meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0},
-    "ai21": {
+    "ai21-j2": {
         "prompt": "",
         "temperature": 0.0,
         "topP": 0.0,
@@ -201,6 +201,16 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
         "presencePenalty": {"scale": 0.0},
         "frequencyPenalty": {"scale": 0.0},
     },
+    "ai21-jamba": {
+        "messages": [],
+        "temperature": 0.0,
+        "topP": 0.0,
+        "max_tokens": 0,
+        "stopSequences": [],
+        "countPenalty": {"scale": 0.0},
+        "presencePenalty": {"scale": 0.0},
+        "frequencyPenalty": {"scale": 0.0},
+    },
     "cohere": {
         "prompt": "",
         "temperature": 0.0,
@@ -214,6 +224,20 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
         "logit_bias": {},
         "truncate": "NONE",
     },
+    "cohere-command-r": {
+        "message": [],
+        "chat_history": [],
+        "temperature": 0.0,
+        "p": 0.0,
+        "k": 0.0,
+        "max_tokens": 0,
+        "stop_sequences": [],
+        "return_likelihoods": "NONE",
+        "stream": False,
+        "num_generations": 0,
+        "logit_bias": {},
+        "truncate": "NONE",
+    },
     "anthropic": {
         "anthropic_version": "bedrock-2023-05-31",
         "max_tokens": 0,
@@ -233,12 +257,20 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
 BEDROCK_PROVIDER_RESPONSE_BODY = {
     "mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]},
     "meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""},
-    "ai21": {
+    "ai21-jamba": {
         "id": "",
         "prompt": {"text": "Hello World", "tokens": []},
-        "completions": [
-            {"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}}
-        ],
+        "choices": [{"message": {"content": "Hello World"}}],
+    },
+    "ai21-jamba-stream": {
+        "id": "",
+        "prompt": {"text": "Hello World", "tokens": []},
+        "choices": [{"delta": {"content": "Hello World"}}],
+    },
+    "ai21-j2": {
+        "id": "",
+        "prompt": {"text": "Hello World", "tokens": []},
+        "completions": [{"data": {"text": "Hello World"}, "finishReason": {"reason": "length", "length": 2}}],
     },
     "cohere": {
         "generations": [
@@ -255,6 +287,21 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
         "id": "",
         "prompt": "",
     },
+    "cohere-command-r": {
+        "generations": [
+            {
+                "finish_reason": "",
+                "id": "",
+                "text": "Hello World",
+                "likelihood": 0.0,
+                "token_likelihoods": [{"token": 0.0}],
+                "is_finished": True,
+                "index": 0,
+            }
+        ],
+        "id": "",
+        "prompt": "",
+    },
     "anthropic": {
         "id": "",
         "model": "",

tests/metagpt/provider/test_bedrock_api.py

31 additions & 5 deletions

@@ -22,18 +22,42 @@
 }
 
 
-def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict:
-    provider = self.config.model.split(".")[0]
+def get_provider_name(model: str) -> str:
+    arr = model.split(".")
+    if len(arr) == 2:
+        provider, model_name = arr  # meta、mistral……
+    elif len(arr) == 3:
+        # some model_ids may contain country like us.xx.xxx
+        _, provider, model_name = arr
+    return provider
+
+
+def deal_special_provider(provider: str, model: str, stream: bool = False) -> str:
+    # for ai21
+    if "j2-" in model:
+        provider = f"{provider}-j2"
+    elif "jamba-" in model:
+        provider = f"{provider}-jamba"
+    elif "command-r" in model:
+        provider = f"{provider}-command-r"
+    if stream and "ai21" in model:
+        provider = f"{provider}-stream"
+    return provider
+
+
+async def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict:
+    provider = get_provider_name(self.config.model)
     self._update_costs(usage, self.config.model)
+    provider = deal_special_provider(provider, self.config.model)
     return BEDROCK_PROVIDER_RESPONSE_BODY[provider]
 
 
-def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
+async def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
     # use json object to mock EventStream
     def dict2bytes(x):
         return json.dumps(x).encode("utf-8")
 
-    provider = self.config.model.split(".")[0]
+    provider = get_provider_name(self.config.model)
 
     if provider == "amazon":
         response_body_bytes = dict2bytes({"outputText": "Hello World"})
@@ -44,6 +68,7 @@ def dict2bytes(x):
     elif provider == "cohere":
         response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"})
     else:
+        provider = deal_special_provider(provider, self.config.model, stream=True)
        response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
 
     response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
@@ -52,7 +77,8 @@ def dict2bytes(x):
 
 
 def get_bedrock_request_body(model_id) -> dict:
-    provider = model_id.split(".")[0]
+    provider = get_provider_name(model_id)
+    provider = deal_special_provider(provider, model_id)
     return BEDROCK_PROVIDER_REQUEST_BODY[provider]
 
 