Merged: changes from all commits
4 changes: 3 additions & 1 deletion examples/hello_world.py
@@ -13,7 +13,9 @@

async def ask_and_print(question: str, llm: LLM, system_prompt) -> str:
logger.info(f"Q: {question}")
rsp = await llm.aask(question, system_msgs=[system_prompt])
rsp = await llm.aask(question, system_msgs=[system_prompt], stream=True)
if llm.reasoning_content:
logger.info(f"A reasoning: {llm.reasoning_content}")
logger.info(f"A: {rsp}")
return rsp

5 changes: 5 additions & 0 deletions metagpt/configs/llm_config.py
@@ -40,6 +40,7 @@ class LLMType(Enum):
DEEPSEEK = "deepseek"
SILICONFLOW = "siliconflow"
OPENROUTER = "openrouter"
OPENROUTER_REASONING = "openrouter_reasoning"
BEDROCK = "bedrock"
ARK = "ark" # https://www.volcengine.com/docs/82379/1263482#python-sdk

@@ -107,6 +108,10 @@ class LLMConfig(YamlModel):
# For Messages Control
use_system_prompt: bool = True

# reasoning / thinking switch
reasoning: bool = False
reasoning_max_token: int = 4000 # token budget for the reasoning phase; usually smaller than max_token

@field_validator("api_key")
@classmethod
def check_llm_key(cls, v):
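For reference, a minimal sketch of enabling the new switch when building an LLMConfig in Python. The `reasoning` and `reasoning_max_token` fields come from this diff; the remaining fields (api_type, model, api_key) are assumed from the existing config schema.

```python
from metagpt.configs.llm_config import LLMConfig, LLMType

# Illustrative only: enable extended thinking with a 4k-token budget.
config = LLMConfig(
    api_type=LLMType.ANTHROPIC,          # assumed existing enum member
    model="claude-3-7-sonnet-20250219",  # placeholder model name
    api_key="sk-...",
    reasoning=True,
    reasoning_max_token=4000,            # thinking budget, kept below max_token
)
```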
2 changes: 2 additions & 0 deletions metagpt/provider/__init__.py
@@ -19,6 +19,7 @@
from metagpt.provider.anthropic_api import AnthropicLLM
from metagpt.provider.bedrock_api import BedrockLLM
from metagpt.provider.ark_api import ArkLLM
from metagpt.provider.openrouter_reasoning import OpenrouterReasoningLLM

__all__ = [
"GeminiLLM",
@@ -34,4 +35,5 @@
"AnthropicLLM",
"BedrockLLM",
"ArkLLM",
"OpenrouterReasoningLLM",
]
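With the export in place, the new provider can be imported from metagpt.provider. A hedged sketch only: the constructor is assumed to take an LLMConfig like the other providers listed here, and the model id is illustrative.

```python
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.provider import OpenrouterReasoningLLM

llm = OpenrouterReasoningLLM(
    LLMConfig(
        api_type=LLMType.OPENROUTER_REASONING,
        model="deepseek/deepseek-r1",  # hypothetical OpenRouter model id
        api_key="sk-or-...",
        reasoning=True,
    )
)
```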
22 changes: 18 additions & 4 deletions metagpt/provider/anthropic_api.py
@@ -33,14 +33,21 @@ def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
if messages[0]["role"] == "system":
kwargs["messages"] = messages[1:]
kwargs["system"] = messages[0]["content"] # set system prompt here
if self.config.reasoning:
kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.config.reasoning_max_token}
return kwargs

def _update_costs(self, usage: Usage, model: str = None, local_calc_usage: bool = True):
usage = {"prompt_tokens": usage.input_tokens, "completion_tokens": usage.output_tokens}
super()._update_costs(usage, model)

def get_choice_text(self, resp: Message) -> str:
return resp.content[0].text
if len(resp.content) > 1:
self.reasoning_content = resp.content[0].thinking
text = resp.content[1].text
else:
text = resp.content[0].text
return text

async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
resp: Message = await self.aclient.messages.create(**self._const_kwargs(messages))
@@ -53,20 +60,27 @@ async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIME
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
collected_content = []
collected_reasoning_content = []
usage = Usage(input_tokens=0, output_tokens=0)
async for event in stream:
event_type = event.type
if event_type == "message_start":
usage.input_tokens = event.message.usage.input_tokens
usage.output_tokens = event.message.usage.output_tokens
elif event_type == "content_block_delta":
content = event.delta.text
log_llm_stream(content)
collected_content.append(content)
delta_type = event.delta.type
if delta_type == "thinking_delta":
collected_reasoning_content.append(event.delta.thinking)
elif delta_type == "text_delta":
content = event.delta.text
log_llm_stream(content)
collected_content.append(content)
elif event_type == "message_delta":
usage.output_tokens = event.usage.output_tokens # update final output_tokens

log_llm_stream("\n")
self._update_costs(usage)
full_content = "".join(collected_content)
if collected_reasoning_content:
self.reasoning_content = "".join(collected_reasoning_content)
return full_content
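The reasoning switch maps onto the Anthropic Messages API's extended-thinking parameter. A hedged sketch of the kwargs that _const_kwargs ends up passing to messages.create when config.reasoning is true; the model name, max_tokens, and prompt are placeholders.

```python
kwargs = {
    "model": "claude-3-7-sonnet-20250219",
    "max_tokens": 8192,
    "messages": [{"role": "user", "content": "Why is the sky blue?"}],
    "system": "You are a helpful assistant.",  # lifted from the leading system message
    "thinking": {"type": "enabled", "budget_tokens": 4000},
}
```

In the non-streaming response the first content block is then a thinking block and the second the final text block, which is why get_choice_text checks len(resp.content).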
15 changes: 14 additions & 1 deletion metagpt/provider/base_llm.py
@@ -45,6 +45,16 @@ class BaseLLM(ABC):
model: Optional[str] = None # deprecated
pricing_plan: Optional[str] = None

_reasoning_content: Optional[str] = None # content from reasoning mode

@property
def reasoning_content(self):
return self._reasoning_content

@reasoning_content.setter
def reasoning_content(self, value: str):
self._reasoning_content = value

@abstractmethod
def __init__(self, config: LLMConfig):
pass
@@ -216,7 +226,10 @@ async def acompletion_text(

def get_choice_text(self, rsp: dict) -> str:
"""Required to provide the first text of choice"""
return rsp.get("choices")[0]["message"]["content"]
message = rsp.get("choices")[0]["message"]
if "reasoning_content" in message:
self.reasoning_content = message["reasoning_content"]
return message["content"]

def get_choice_delta_text(self, rsp: dict) -> str:
"""Required to provide the first text of stream choice"""
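The updated get_choice_text targets OpenAI-compatible responses that return a reasoning_content field next to content, as DeepSeek-R1-style endpoints do. An illustrative response shape:

```python
rsp = {
    "choices": [
        {
            "message": {
                "role": "assistant",
                "reasoning_content": "The user wants a short answer, so ...",
                "content": "The answer is 42.",
            }
        }
    ]
}
# get_choice_text(rsp) returns "The answer is 42." and keeps the trace on
# llm.reasoning_content, which hello_world.py above reads after the call.
```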
11 changes: 8 additions & 3 deletions metagpt/provider/bedrock/base_provider.py
@@ -1,11 +1,16 @@
import json
from abc import ABC, abstractmethod
from typing import Union


class BaseBedrockProvider(ABC):
# to handle different generation kwargs
max_tokens_field_name = "max_tokens"

def __init__(self, reasoning: bool = False, reasoning_max_token: int = 4000):
self.reasoning = reasoning
self.reasoning_max_token = reasoning_max_token

@abstractmethod
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
...
@@ -14,14 +19,14 @@ def get_request_body(self, messages: list[dict], const_kwargs, *args, **kwargs)
body = json.dumps({"prompt": self.messages_to_prompt(messages), **const_kwargs})
return body

def get_choice_text(self, response_body: dict) -> str:
def get_choice_text(self, response_body: dict) -> Union[str, dict[str, str]]:
completions = self._get_completion_from_dict(response_body)
return completions

def get_choice_text_from_stream(self, event) -> str:
def get_choice_text_from_stream(self, event) -> Union[bool, str]:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = self._get_completion_from_dict(rsp_dict)
return completions
return False, completions

def messages_to_prompt(self, messages: list[dict]) -> str:
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
52 changes: 37 additions & 15 deletions metagpt/provider/bedrock/bedrock_provider.py
@@ -1,5 +1,5 @@
import json
from typing import Literal, Tuple
from typing import Literal, Tuple, Union

from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
from metagpt.provider.bedrock.utils import (
@@ -20,6 +20,8 @@ def _get_completion_from_dict(self, rsp_dict: dict) -> str:

class AnthropicProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-37.html
# https://docs.aws.amazon.com/code-library/latest/ug/python_3_bedrock-runtime_code_examples.html#anthropic_claude

def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[dict]]:
system_messages = []
@@ -32,6 +34,10 @@ def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[d
return self.messages_to_prompt(system_messages), user_messages

def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
if self.reasoning:
generate_kwargs["temperature"] = 1 # Anthropic requires temperature to be 1 when thinking is enabled
generate_kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.reasoning_max_token}

system_message, user_messages = self._split_system_user_messages(messages)
body = json.dumps(
{
@@ -43,17 +49,27 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
)
return body

def _get_completion_from_dict(self, rsp_dict: dict) -> str:
def _get_completion_from_dict(self, rsp_dict: dict) -> dict[str, Tuple[str, str]]:
if self.reasoning:
return {"reasoning_content": rsp_dict["content"][0]["thinking"], "content": rsp_dict["content"][1]["text"]}
return rsp_dict["content"][0]["text"]

def get_choice_text_from_stream(self, event) -> str:
def get_choice_text_from_stream(self, event) -> Union[bool, str]:
# https://docs.anthropic.com/claude/reference/messages-streaming
rsp_dict = json.loads(event["chunk"]["bytes"])
if rsp_dict["type"] == "content_block_delta":
completions = rsp_dict["delta"]["text"]
return completions
reasoning = False
delta_type = rsp_dict["delta"]["type"]
if delta_type == "text_delta":
completions = rsp_dict["delta"]["text"]
elif delta_type == "thinking_delta":
completions = rsp_dict["delta"]["thinking"]
reasoning = True
elif delta_type == "signature_delta":
completions = ""
return reasoning, completions
else:
return ""
return False, ""


class CohereProvider(BaseBedrockProvider):
@@ -87,10 +103,10 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
body = json.dumps({"prompt": prompt, "stream": kwargs.get("stream", False), **generate_kwargs})
return body

def get_choice_text_from_stream(self, event) -> str:
def get_choice_text_from_stream(self, event) -> Union[bool, str]:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = rsp_dict.get("text", "")
return completions
return False, completions


class MetaProvider(BaseBedrockProvider):
@@ -133,10 +149,10 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
)
return body

def get_choice_text_from_stream(self, event) -> str:
def get_choice_text_from_stream(self, event) -> Union[bool, str]:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = rsp_dict.get("choices", [{}])[0].get("delta", {}).get("content", "")
return completions
return False, completions

def _get_completion_from_dict(self, rsp_dict: dict) -> str:
if self.model_type == "j2":
Expand All @@ -159,10 +175,10 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["results"][0]["outputText"]

def get_choice_text_from_stream(self, event) -> str:
def get_choice_text_from_stream(self, event) -> Union[bool, str]:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = rsp_dict["outputText"]
return completions
return False, completions


PROVIDERS = {
Expand All @@ -175,8 +191,14 @@ def get_choice_text_from_stream(self, event) -> str:
}


def get_provider(model_id: str):
provider, model_name = model_id.split(".")[0:2] # meta、mistral……
def get_provider(model_id: str, reasoning: bool = False, reasoning_max_token: int = 4000):
arr = model_id.split(".")
if len(arr) == 2:
provider, model_name = arr # e.g. meta, mistral
elif len(arr) == 3:
# some model_ids carry a region prefix, e.g. us.anthropic.xxx
_, provider, model_name = arr

if provider not in PROVIDERS:
raise KeyError(f"{provider} is not supported!")
if provider == "meta":
Expand All @@ -188,4 +210,4 @@ def get_provider(model_id: str):
elif provider == "cohere":
# distinguish between R/R+ and older models
return PROVIDERS[provider](model_name)
return PROVIDERS[provider]()
return PROVIDERS[provider](reasoning=reasoning, reasoning_max_token=reasoning_max_token)
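get_provider now tolerates region-prefixed model ids and threads the reasoning settings through to the provider. A usage sketch against the code above:

```python
# Both forms resolve to AnthropicProvider; the "us." prefix is a region qualifier.
p1 = get_provider("anthropic.claude-3-7-sonnet-20250219-v1:0",
                  reasoning=True, reasoning_max_token=4000)
p2 = get_provider("us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                  reasoning=True, reasoning_max_token=4000)

# An unrecognised prefix still fails fast:
# get_provider("foo.some-model")  ->  KeyError: 'foo is not supported!'
```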
17 changes: 3 additions & 14 deletions metagpt/provider/bedrock/utils.py
@@ -48,6 +48,9 @@
"anthropic.claude-3-opus-20240229-v1:0": 4096,
# Claude 3.5 Sonnet
"anthropic.claude-3-5-sonnet-20240620-v1:0": 8192,
# Claude 3.7 Sonnet
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": 131072,
"anthropic.claude-3-7-sonnet-20250219-v1:0": 131072,
# Command Text
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
"cohere.command-text-v14": 4096,
@@ -135,20 +138,6 @@ def messages_to_prompt_llama3(messages: list[dict]) -> str:
return prompt


def messages_to_prompt_claude2(messages: list[dict]) -> str:
GENERAL_TEMPLATE = "\n\n{role}: {content}"
prompt = ""
for message in messages:
role = message.get("role", "")
content = message.get("content", "")
prompt += GENERAL_TEMPLATE.format(role=role, content=content)

if role != "assistant":
prompt += "\n\nAssistant:"

return prompt


def get_max_tokens(model_id: str) -> int:
try:
max_tokens = (NOT_SUPPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
22 changes: 17 additions & 5 deletions metagpt/provider/bedrock_api.py
@@ -23,7 +23,9 @@ class BedrockLLM(BaseLLM):
def __init__(self, config: LLMConfig):
self.config = config
self.__client = self.__init_client("bedrock-runtime")
self.__provider = get_provider(self.config.model)
self.__provider = get_provider(
self.config.model, reasoning=self.config.reasoning, reasoning_max_token=self.config.reasoning_max_token
)
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
if self.config.model in NOT_SUPPORT_STREAM_MODELS:
logger.warning(f"model {self.config.model} doesn't support streaming output!")
@@ -102,7 +104,11 @@ def _const_kwargs(self) -> dict:
# However, aioboto3 doesn't support invoke model

def get_choice_text(self, rsp: dict) -> str:
return self.__provider.get_choice_text(rsp)
rsp = self.__provider.get_choice_text(rsp)
if isinstance(rsp, dict):
self.reasoning_content = rsp.get("reasoning_content")
rsp = rsp.get("content")
return rsp

async def acompletion(self, messages: list[dict]) -> dict:
request_body = self.__provider.get_request_body(messages, self._const_kwargs)
@@ -133,10 +139,16 @@ def _get_response_body(self, response) -> dict:
async def _get_stream_response_body(self, stream_response) -> List[str]:
def collect_content() -> str:
collected_content = []
collected_reasoning_content = []
for event in stream_response["body"]:
chunk_text = self.__provider.get_choice_text_from_stream(event)
collected_content.append(chunk_text)
log_llm_stream(chunk_text)
reasoning, chunk_text = self.__provider.get_choice_text_from_stream(event)
if reasoning:
collected_reasoning_content.append(chunk_text)
else:
collected_content.append(chunk_text)
log_llm_stream(chunk_text)
if collected_reasoning_content:
self.reasoning_content = "".join(collected_reasoning_content)
return collected_content

loop = asyncio.get_running_loop()
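Provider streams now yield (reasoning, text) pairs so thinking deltas never leak into the final answer. A small hypothetical helper illustrating the routing that collect_content performs:

```python
def split_stream_chunks(chunks: list[tuple[bool, str]]) -> tuple[str, str]:
    """Separate reasoning deltas from answer deltas, as collect_content does."""
    reasoning = "".join(text for is_reasoning, text in chunks if is_reasoning)
    answer = "".join(text for is_reasoning, text in chunks if not is_reasoning)
    return reasoning, answer

chunks = [
    (True, "The user wants the capital of France. "),  # thinking_delta events
    (False, "Paris"),                                   # text_delta events
    (False, " is the capital of France."),
]
print(split_stream_chunks(chunks))
# ('The user wants the capital of France. ', 'Paris is the capital of France.')
```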
15 changes: 8 additions & 7 deletions metagpt/provider/general_api_base.py
@@ -150,6 +150,14 @@ def response_ms(self) -> Optional[int]:
h = self._headers.get("Openai-Processing-Ms")
return None if h is None else round(float(h))

def decode_asjson(self) -> Optional[dict]:
bstr = self.data.strip()
if bstr.startswith(b"{") and bstr.endswith(b"}"):
bstr = bstr.decode("utf-8")
else:
bstr = parse_stream_helper(bstr)
return json.loads(bstr) if bstr else None


def _build_api_url(url, query):
scheme, netloc, path, base_query, fragment = urlsplit(url)
@@ -547,13 +555,6 @@ async def arequest_raw(
}
try:
result = await session.request(**request_kwargs)
# log_info(
# "LLM API response",
# path=abs_url,
# response_code=result.status,
# processing_ms=result.headers.get("LLM-Processing-Ms"),
# request_id=result.headers.get("X-Request-Id"),
# )
return result
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
raise openai.APITimeoutError("Request timed out") from e
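decode_asjson accepts either a raw JSON body or a single SSE-framed line, deferring to parse_stream_helper for the latter. A hypothetical usage sketch, assuming OpenAIResponse is built from (data, headers) as elsewhere in this module:

```python
plain = OpenAIResponse(b'{"choices": [{"message": {"content": "hi"}}]}', {})
print(plain.decode_asjson())  # {'choices': [{'message': {'content': 'hi'}}]}

sse = OpenAIResponse(b'data: {"choices": [{"delta": {"content": "hi"}}]}', {})
print(sse.decode_asjson())    # the "data: " prefix is stripped before json.loads

done = OpenAIResponse(b"data: [DONE]", {})
print(done.decode_asjson())   # None
```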