Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 105 additions & 25 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
MIME_TO_FORMAT = {
# Image formats
"image/png": "png",
"image/jpeg": "jpeg",
"image/jpeg": "jpeg",
"image/gif": "gif",
"image/webp": "webp",
# File formats
Expand Down Expand Up @@ -465,9 +465,9 @@ class Joke(BaseModel):
additionalModelResponseFieldPaths.
"""

supports_tool_choice_values: Optional[
Sequence[Literal["auto", "any", "tool"]]
] = None
supports_tool_choice_values: Optional[Sequence[Literal["auto", "any", "tool"]]] = (
None
)
"""Which types of tool_choice values the model supports.

Inferred if not specified. Inferred as ('auto', 'any', 'tool') if a 'claude-3'
Expand Down Expand Up @@ -512,6 +512,73 @@ def create_cache_point(cls, cache_type: str = "default") -> Dict[str, Any]:
"""
return {"cachePoint": {"type": cache_type}}

@classmethod
def create_document(
cls,
name: str,
source: dict[str, Any],
context: Optional[str] = None,
enable_citations: Optional[bool] = False,
format: Optional[
Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
] = None,
) -> Dict[str, Any]:
"""Create a document configuration for Bedrock.
Args:
name: The name of the document.
source: The source of the document.
context: Info for the model to understand the document for citations.
format: The format of the document, or its extension.
Comment on lines 528 to 531
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an enable_citations docstring for completeness

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in 4812a05

Returns:
Dictionary containing a properly formatted to add to message content."""
if re.match(r"[^\w\[\]\(\)-]|[\s]{2,}", name):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

re.match won't work here for all cases here, as this will only matches for specific invalid chars/sequences starting from the first character (for example, No Cite won't be caught).

You should use re.search instead (+ simplify a bit):

Suggested change
if re.match(r"[^\w\[\]\(\)-]|[\s]{2,}", name):
if not re.search(r"[^A-Za-z0-9 \[\]()\-]|\s{2,}", name):

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think not was accidentally added to the front. Resolved in d757851

raise ValueError(
"Name must be only alphanumeric characters,"
" whitespace characters (no more than one in a row),"
" hyphens, parantheses, or square brackets."
)

valid_source_types = ["bytes", "content", "s3Location", "text"]
if (
len(source.keys()) > 1
or list(source.keys())[0] not in valid_source_types
):
raise ValueError(
f"The key for source can only be one of the following: {valid_source_types}"
)

if source.get("bytes") and not isinstance(source.get("bytes"), bytes):
raise ValueError(f"Document source with type bytes must be bytes type.")

if source.get("text") and not isinstance(source.get("text"), str):
raise ValueError("Document source with type text must be str type.")

if source.get("s3Location") and not isinstance(
source.get("s3Location").get("uri"), str
):
raise ValueError(
"Document source with type s3Location"
" must have a dictionary with a valid s3 uri as a dict."
)

if source.get("content") and not isinstance(source.get("content", list)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing content source currently fails because isinstance is missing the second argument for type here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops good catch thanks!

raise ValueError(
"Document source with type content must have a list of document content blocks."
)

document = {"name": name, "source": source}

if context:
document["context"] = context

if format:
document["format"] = format

if enable_citations:
document["citations"] = {"enabled": True}

return {"document": document}

@model_validator(mode="before")
@classmethod
def build_extra(cls, values: dict[str, Any]) -> Any:
Expand All @@ -533,9 +600,11 @@ def build_extra(cls, values: dict[str, Any]) -> Any:
return values

@classmethod
def _get_streaming_support(cls, provider: str, model_id_lower: str) -> Union[bool, str]:
def _get_streaming_support(
cls, provider: str, model_id_lower: str
) -> Union[bool, str]:
"""Determine streaming support for a given provider and model.

Returns:
True: Full streaming support
"no_tools": Streaming supported but not with tools
Expand Down Expand Up @@ -612,7 +681,7 @@ def _get_streaming_support(cls, provider: str, model_id_lower: str) -> Union[boo
@classmethod
def set_disable_streaming(cls, values: Dict) -> Any:
model_id = values.get("model_id", values.get("model"))

# Extract provider from the model_id
# (e.g., "amazon", "anthropic", "ai21", "meta", "mistral")
if "provider" not in values or values["provider"] == "":
Expand Down Expand Up @@ -652,8 +721,8 @@ def set_disable_streaming(cls, values: Dict) -> Any:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that AWS credentials to and python package exists in environment."""
# Create bedrock client for control plane API call

# Create bedrock client for control plane API call
if self.bedrock_client is None:
self.bedrock_client = create_aws_client(
region_name=self.region_name,
Expand All @@ -665,7 +734,7 @@ def validate_environment(self) -> Self:
config=self.config,
service_name="bedrock",
)

# Handle streaming configuration for application inference profiles
if "application-inference-profile" in self.model_id:
self._configure_streaming_for_resolved_model()
Expand Down Expand Up @@ -712,27 +781,30 @@ def validate_environment(self) -> Self:
"Provide a guardrail via `guardrail_config` or "
"disable `guard_last_turn_only`."
)

return self

def _get_base_model(self) -> str:
# identify the base model id used in the application inference profile (AIP)
# Format: arn:aws:bedrock:us-east-1:<accountId>:application-inference-profile/<id>
if self.base_model_id is None and 'application-inference-profile' in self.model_id:
if (
self.base_model_id is None
and "application-inference-profile" in self.model_id
):
response = self.bedrock_client.get_inference_profile(
inferenceProfileIdentifier=self.model_id
)
if 'models' in response and len(response['models']) > 0:
model_arn = response['models'][0]['modelArn']
if "models" in response and len(response["models"]) > 0:
model_arn = response["models"][0]["modelArn"]
# Format: arn:aws:bedrock:region::foundation-model/provider.model-name
self.base_model_id = model_arn.split('/')[-1]
self.base_model_id = model_arn.split("/")[-1]
return self.base_model_id if self.base_model_id else self.model_id

def _configure_streaming_for_resolved_model(self) -> None:
"""Configure streaming support after resolving the base model for application inference profiles."""
base_model = self._get_base_model()
model_id_lower = base_model.lower()

streaming_support = self._get_streaming_support(self.provider, model_id_lower)

# Set the disable_streaming flag accordingly
Expand Down Expand Up @@ -1194,7 +1266,7 @@ def _parse_stream_event(event: Dict[str, Any]) -> Optional[BaseMessageChunk]:
)
# always keep block inside a list to preserve merging compatibility
content = [block]

return AIMessageChunk(content=content, tool_call_chunks=tool_call_chunks)
elif "contentBlockDelta" in event:
block = {
Expand All @@ -1213,7 +1285,7 @@ def _parse_stream_event(event: Dict[str, Any]) -> Optional[BaseMessageChunk]:
)
# always keep block inside a list to preserve merging compatibility
content = [block]

return AIMessageChunk(content=content, tool_call_chunks=tool_call_chunks)
elif "contentBlockStop" in event:
# TODO: needed?
Expand Down Expand Up @@ -1244,13 +1316,13 @@ def _mime_type_to_format(mime_type: str) -> str:

if mime_type in MIME_TO_FORMAT:
return MIME_TO_FORMAT[mime_type]

# Fallback to original method of splitting on "/" for simple cases
all_formats = set(MIME_TO_FORMAT.values())
format_part = mime_type.split("/")[1]
if format_part in all_formats:
return format_part

raise ValueError(
f"Unsupported MIME type: {mime_type}. Please refer to the Bedrock Converse API documentation for supported formats."
)
Expand Down Expand Up @@ -1327,7 +1399,9 @@ def _lc_content_to_bedrock(
):
bedrock_content.append(_format_data_content_block(block))
elif block["type"] == "text":
if not block["text"] or (isinstance(block["text"], str) and block["text"].isspace()):
if not block["text"] or (
isinstance(block["text"], str) and block["text"].isspace()
):
bedrock_content.append({"text": "."})
else:
bedrock_content.append({"text": block["text"]})
Expand All @@ -1339,7 +1413,9 @@ def _lc_content_to_bedrock(
bedrock_content.append(
{
"image": {
"format": _mime_type_to_format(block["source"]["mediaType"]),
"format": _mime_type_to_format(
block["source"]["mediaType"]
),
"source": {
"bytes": _b64str_to_bytes(block["source"]["data"])
},
Expand All @@ -1360,7 +1436,9 @@ def _lc_content_to_bedrock(
bedrock_content.append(
{
"video": {
"format": _mime_type_to_format(block["source"]["mediaType"]),
"format": _mime_type_to_format(
block["source"]["mediaType"]
),
"source": {
"bytes": _b64str_to_bytes(block["source"]["data"])
},
Expand All @@ -1371,7 +1449,9 @@ def _lc_content_to_bedrock(
bedrock_content.append(
{
"video": {
"format": _mime_type_to_format(block["source"]["mediaType"]),
"format": _mime_type_to_format(
block["source"]["mediaType"]
),
"source": {"s3Location": block["source"]["data"]},
}
}
Expand Down
Loading