@@ -38,7 +38,13 @@
     Usage,
     UserMessage,
 )
-from api.setting import AWS_REGION, DEBUG, DEFAULT_MODEL, ENABLE_CROSS_REGION_INFERENCE
+from api.setting import (
+    AWS_REGION,
+    DEBUG,
+    DEFAULT_MODEL,
+    ENABLE_CROSS_REGION_INFERENCE,
+    ENABLE_APPLICATION_INFERENCE_PROFILES,
+)

 logger = logging.getLogger(__name__)

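The new ENABLE_APPLICATION_INFERENCE_PROFILES flag is imported from api.setting, whose definition is outside this diff. A minimal sketch, assuming it follows the same environment-variable pattern as the existing ENABLE_CROSS_REGION_INFERENCE toggle (the variable name is real; the parsing and default shown here are assumptions):

    import os

    # Assumed definition in api/setting.py; the actual file is not in this commit.
    ENABLE_APPLICATION_INFERENCE_PROFILES = (
        os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "false").lower() == "true"
    )
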
@@ -83,15 +89,40 @@ def list_bedrock_models() -> dict:
     Returns a model list that combines:
     - ON_DEMAND models.
     - Cross-Region Inference Profiles (if enabled via Env)
+    - Application Inference Profiles (if enabled via Env)
     """
     model_list = {}
     try:
         profile_list = []
+        app_profile_dict = {}
+
         if ENABLE_CROSS_REGION_INFERENCE:
             # List system-defined inference profile IDs
             response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
             profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]]

+        if ENABLE_APPLICATION_INFERENCE_PROFILES:
+            # List application-defined inference profile IDs and build a model-to-profile mapping
+            response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="APPLICATION")
+
+            for profile in response["inferenceProfileSummaries"]:
+                try:
+                    profile_arn = profile.get("inferenceProfileArn")
+                    if not profile_arn:
+                        continue
+
+                    # Process all models in the profile
+                    models = profile.get("models", [])
+                    for model in models:
+                        model_arn = model.get("modelArn", "")
+                        if model_arn:
+                            model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
+                            if model_id:
+                                app_profile_dict[model_id] = profile_arn
+                except Exception as e:
+                    logger.warning(f"Error processing application profile: {e}")
+                    continue
+
         # List foundation models; we only care about text outputs here.
         response = bedrock_client.list_foundation_models(byOutputModality="TEXT")

@@ -115,6 +146,10 @@ def list_bedrock_models() -> dict:
             if profile_id in profile_list:
                 model_list[profile_id] = {"modalities": input_modalities}

+            # Add application inference profiles
+            if model_id in app_profile_dict:
+                model_list[app_profile_dict[model_id]] = {"modalities": input_modalities}
+
     except Exception as e:
         logger.error(f"Unable to list models: {str(e)}")

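To make the new mapping concrete, here is a self-contained sketch of the flow above with made-up ARNs: the model ID is taken as the last path segment of each model ARN, and the profile ARN then becomes a key in model_list with the same metadata shape as plain model IDs and cross-region profile IDs.

    # Made-up profile summary, shaped like a list_inference_profiles entry.
    profile = {
        "inferenceProfileArn": "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/abc123",
        "models": [
            {"modelArn": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0"},
        ],
    }

    app_profile_dict = {}
    for model in profile.get("models", []):
        model_arn = model.get("modelArn", "")
        if model_arn:
            # The model ID is the last path segment of the model ARN.
            model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
            app_profile_dict[model_id] = profile["inferenceProfileArn"]

    # If the foundation-model listing later yields this model ID with TEXT/IMAGE
    # input modalities, model_list gains an ARN-keyed entry (illustrative):
    # {
    #     "anthropic.claude-3-sonnet-20240229-v1:0": {"modalities": ["TEXT", "IMAGE"]},
    #     "arn:aws:bedrock:...:application-inference-profile/abc123": {"modalities": ["TEXT", "IMAGE"]},
    # }
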
@@ -162,7 +197,9 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
         try:
             if stream:
                 # Run the blocking boto3 call in a thread pool
-                response = await run_in_threadpool(bedrock_runtime.converse_stream, **args)
+                response = await run_in_threadpool(
+                    bedrock_runtime.converse_stream, **args
+                )
             else:
                 # Run the blocking boto3 call in a thread pool
                 response = await run_in_threadpool(bedrock_runtime.converse, **args)
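A note on the pattern being rewrapped here: boto3 clients are synchronous, so calling them directly inside an async handler would block the event loop. run_in_threadpool (from starlette.concurrency, re-exported by FastAPI) runs the call on a worker thread instead. A minimal self-contained sketch:

    import time

    from starlette.concurrency import run_in_threadpool

    def blocking_call(delay: float) -> str:
        # Stand-in for a synchronous boto3 client method.
        time.sleep(delay)
        return "done"

    async def handler() -> str:
        # The event loop stays free while blocking_call runs on a worker thread.
        return await run_in_threadpool(blocking_call, 0.1)
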
@@ -274,7 +311,9 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
                 messages.append(
                     {
                         "role": message.role,
-                        "content": self._parse_content_parts(message, chat_request.model),
+                        "content": self._parse_content_parts(
+                            message, chat_request.model
+                        ),
                     }
                 )
             elif isinstance(message, AssistantMessage):
@@ -283,7 +322,9 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
                     messages.append(
                         {
                             "role": message.role,
-                            "content": self._parse_content_parts(message, chat_request.model),
+                            "content": self._parse_content_parts(
+                                message, chat_request.model
+                            ),
                         }
                     )
                 if message.tool_calls:
@@ -363,7 +404,9 @@ def _reframe_multi_payloard(self, messages: list) -> list:
             # If the next role is different from the previous message, add the previous role's messages to the list
             if next_role != current_role:
                 if current_content:
-                    reformatted_messages.append({"role": current_role, "content": current_content})
+                    reformatted_messages.append(
+                        {"role": current_role, "content": current_content}
+                    )
                 # Switch to the new role
                 current_role = next_role
                 current_content = []
@@ -376,7 +419,9 @@ def _reframe_multi_payloard(self, messages: list) -> list:

         # Add the last role's messages to the list
         if current_content:
-            reformatted_messages.append({"role": current_role, "content": current_content})
+            reformatted_messages.append(
+                {"role": current_role, "content": current_content}
+            )

         return reformatted_messages

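For orientation (this helper is only reformatted here): _reframe_multi_payloard merges consecutive same-role messages into one message whose content is the combined list. An illustrative before/after, with simplified content payloads:

    messages_in = [
        {"role": "user", "content": [{"text": "hi"}]},
        {"role": "user", "content": [{"text": "and a follow-up"}]},
        {"role": "assistant", "content": [{"text": "hello"}]},
    ]
    # After reframing, consecutive user payloads are folded together:
    messages_out = [
        {"role": "user", "content": [{"text": "hi"}, {"text": "and a follow-up"}]},
        {"role": "assistant", "content": [{"text": "hello"}]},
    ]
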
@@ -414,9 +459,13 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
             # Use max_completion_tokens if provided.

             max_tokens = (
-                chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
+                chat_request.max_completion_tokens
+                if chat_request.max_completion_tokens
+                else chat_request.max_tokens
+            )
+            budget_tokens = self._calc_budget_tokens(
+                max_tokens, chat_request.reasoning_effort
             )
-            budget_tokens = self._calc_budget_tokens(max_tokens, chat_request.reasoning_effort)
             inference_config["maxTokens"] = max_tokens
             # unset topP - Not supported
             inference_config.pop("topP")
@@ -428,7 +477,9 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
         if chat_request.tools:
             tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}

-            if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"):
+            if chat_request.tool_choice and not chat_request.model.startswith(
+                "meta.llama3-1-"
+            ):
                 if isinstance(chat_request.tool_choice, str):
                     # auto (default) is mapped to {"auto": {}}
                     # required is mapped to {"any": {}}
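The branch being rewrapped here maps the string form of tool_choice onto Bedrock's toolChoice union, per the two comments above; the dict (named-tool) form is presumably handled separately, outside this hunk. A sketch of the string mapping as the comments describe it:

    def map_string_tool_choice(tool_choice: str) -> dict:
        # "required" -> Bedrock {"any": {}}; "auto" (the default) -> {"auto": {}}.
        if tool_choice == "required":
            return {"any": {}}
        return {"auto": {}}
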
@@ -477,11 +528,15 @@ def _create_response(
             message.content = ""
             for c in content:
                 if "reasoningContent" in c:
-                    message.reasoning_content = c["reasoningContent"]["reasoningText"].get("text", "")
+                    message.reasoning_content = c["reasoningContent"][
+                        "reasoningText"
+                    ].get("text", "")
                 elif "text" in c:
                     message.content = c["text"]
                 else:
-                    logger.warning("Unknown tag in message content " + ",".join(c.keys()))
+                    logger.warning(
+                        "Unknown tag in message content " + ",".join(c.keys())
+                    )

         response = ChatResponse(
             id=message_id,
@@ -505,7 +560,9 @@ def _create_response(
         response.created = int(time.time())
         return response

-    def _create_response_stream(self, model_id: str, message_id: str, chunk: dict) -> ChatStreamResponse | None:
+    def _create_response_stream(
+        self, model_id: str, message_id: str, chunk: dict
+    ) -> ChatStreamResponse | None:
         """Parse a Bedrock stream response chunk.

         Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
@@ -627,7 +684,9 @@ def _parse_image(self, image_url: str) -> tuple[bytes, str]:
             image_content = response.content
             return image_content, content_type
         else:
-            raise HTTPException(status_code=500, detail="Unable to access the image url")
+            raise HTTPException(
+                status_code=500, detail="Unable to access the image url"
+            )

     def _parse_content_parts(
         self,
@@ -687,7 +746,9 @@ def _convert_tool_spec(self, func: Function) -> dict:
             }
         }

-    def _calc_budget_tokens(self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]) -> int:
+    def _calc_budget_tokens(
+        self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]
+    ) -> int:
         # Helper function to calculate budget_tokens based on max_tokens.
         # Ratio per effort: low - 30%, medium - 60%, high - max_tokens minus 1.
         # Note that the minimum budget_tokens is 1,024 tokens so far.
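The function body is not part of this diff, but the three comments pin down the intended behavior. A sketch consistent with them; the 1,024 floor matches the stated minimum:

    from typing import Literal

    def calc_budget_tokens(
        max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]
    ) -> int:
        # low: 30% of max_tokens; medium: 60%; high: all but one token.
        if reasoning_effort == "low":
            budget = int(max_tokens * 0.3)
        elif reasoning_effort == "medium":
            budget = int(max_tokens * 0.6)
        else:
            budget = max_tokens - 1
        # Enforce the stated 1,024-token minimum.
        return max(budget, 1024)

    # e.g. calc_budget_tokens(4096, "medium") == 2457
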
@@ -718,7 +779,9 @@ def _convert_finish_reason(self, finish_reason: str | None) -> str | None:
                 "complete": "stop",
                 "content_filtered": "content_filter",
             }
-            return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
+            return finish_reason_mapping.get(
+                finish_reason.lower(), finish_reason.lower()
+            )
         return None


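Behavior of the rewrapped return: mapped Bedrock stop reasons become their OpenAI-style equivalents, and unmapped values fall through lower-cased. The mapping dict is only partially visible in this hunk; a runnable mirror of what is shown:

    finish_reason_mapping = {
        "complete": "stop",
        "content_filtered": "content_filter",
    }

    def convert(finish_reason: str) -> str:
        return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())

    assert convert("content_filtered") == "content_filter"
    assert convert("SOME_NEW_REASON") == "some_new_reason"  # hypothetical, unmapped
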
@@ -809,7 +872,9 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         return args

     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))
@@ -825,10 +890,15 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
     def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         if isinstance(embeddings_request.input, str):
             input_text = embeddings_request.input
-        elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
+        elif (
+            isinstance(embeddings_request.input, list)
+            and len(embeddings_request.input) == 1
+        ):
             input_text = embeddings_request.input[0]
         else:
-            raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
+            raise ValueError(
+                "Amazon Titan Embeddings models support only single strings as input."
+            )
         args = {
             "inputText": input_text,
             # Note: inputImage is not supported!
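The input branching above accepts a bare string or a single-element list and rejects anything else. A standalone mirror with example calls:

    def normalize_titan_input(raw) -> str:
        # Same branching as _parse_args above.
        if isinstance(raw, str):
            return raw
        if isinstance(raw, list) and len(raw) == 1:
            return raw[0]
        raise ValueError("Amazon Titan Embeddings models support only single strings as input.")

    assert normalize_titan_input("hello") == "hello"
    assert normalize_titan_input(["hello"]) == "hello"
    # normalize_titan_input(["a", "b"]) raises ValueError
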
@@ -842,7 +912,9 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         return args

     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))
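For reference, the decoded response_body for a Titan text embeddings call typically carries the vector plus a token count; an abridged example with invented values (field names per the Titan embeddings response format):

    response_body = {
        "embedding": [0.012, -0.034, 0.051],  # truncated example vector
        "inputTextTokenCount": 3,
    }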