diff --git a/src/emd/cli.py b/src/emd/cli.py index 7b6fcbf8..ed7421e2 100644 --- a/src/emd/cli.py +++ b/src/emd/cli.py @@ -83,11 +83,17 @@ @app.command(help="List supported models") @catch_aws_credential_errors -def list_supported_models(model_id: Annotated[ +def list_supported_models( + model_id: Annotated[ str, typer.Argument(help="Model ID") - ] = None): + ] = None, + detail: Annotated[ + Optional[bool], + typer.Option("-a", "--detail", help="output model information in details.") + ] = False +): # console.print("[bold blue]Retrieving models...[/bold blue]") - support_models = Model.get_supported_models() + support_models = Model.get_supported_models(detail=detail) if model_id: support_models = [model for _model_id,model in support_models.items() if _model_id == model_id] r = json.dumps(support_models,indent=2,ensure_ascii=False) diff --git a/src/emd/models/llms/txgemma.py b/src/emd/models/llms/txgemma.py index 0c4e29ff..06aa7e6f 100644 --- a/src/emd/models/llms/txgemma.py +++ b/src/emd/models/llms/txgemma.py @@ -47,6 +47,7 @@ supported_frameworks=[ fastapi_framework ], + allow_china_region=True, huggingface_model_id="google/txgemma-9b-chat", modelscope_model_id="AI-ModelScope/txgemma-9b-chat", model_files_download_source=ModelFilesDownloadSource.MODELSCOPE, @@ -79,6 +80,7 @@ supported_frameworks=[ fastapi_framework ], + allow_china_region=True, huggingface_model_id="google/txgemma-27b-chat", modelscope_model_id="AI-ModelScope/txgemma-27b-chat", model_files_download_source=ModelFilesDownloadSource.MODELSCOPE, diff --git a/src/emd/models/model.py b/src/emd/models/model.py index 1e052ef1..6289e890 100644 --- a/src/emd/models/model.py +++ b/src/emd/models/model.py @@ -210,8 +210,10 @@ def get_model(cls ,model_id:str,update:dict = None) -> T: return model @classmethod - def get_supported_models(cls) -> dict: - return {model_id: model.model_type for model_id,model in cls.model_map.items()} + def get_supported_models(cls,detail=False) -> dict: + if not detail: + return {model_id: model.model_type for model_id,model in cls.model_map.items()} + return {model_id: model.model_dump() for model_id,model in cls.model_map.items()} def find_current_engine(self,engine_type:str) -> dict: supported_engines:List[Engine] = self.supported_engines