14 changes: 7 additions & 7 deletions .github/workflows/build-package.yml
@@ -20,13 +20,6 @@ jobs:
with:
python-version: '3.x' # Specify your Python version

- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3

- name: Build wheel
run: poetry build -f wheel

- name: Extract version
id: get_version
run: |
@@ -39,6 +32,13 @@ jobs:
sed -i "s/COMMIT_HASH = \".*\"/COMMIT_HASH = \"$COMMIT_HASH\"/" src/emd/revision.py
echo "SHORT_SHA=$COMMIT_HASH" >> $GITHUB_ENV

- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3

- name: Build wheel
run: poetry build -f wheel

- name: Upload wheel artifact
uses: actions/upload-artifact@v4
with:
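The net effect of this reordering is that the "Extract version" step, which stamps the commit hash into `src/emd/revision.py`, now runs before Poetry is installed and `poetry build -f wheel` runs, so the stamped revision ends up inside the built wheel. Below is a minimal local sketch of that stamping step, assuming the short SHA comes from `git rev-parse --short HEAD` (the workflow's own SHA computation sits in the elided context above) and using a hypothetical `stamp_revision` helper in place of the workflow's `sed` one-liner:

```python
# Hypothetical local equivalent of the sed stamp in the "Extract version" step:
# rewrite the COMMIT_HASH assignment in src/emd/revision.py before the wheel is built.
import re
import subprocess
from pathlib import Path

def stamp_revision(revision_path: str = "src/emd/revision.py") -> str:
    """Replace the COMMIT_HASH constant with the current short git SHA."""
    short_sha = subprocess.check_output(
        ["git", "rev-parse", "--short", "HEAD"], text=True
    ).strip()
    path = Path(revision_path)
    source = path.read_text()
    # Mirrors: sed -i "s/COMMIT_HASH = \".*\"/COMMIT_HASH = \"$COMMIT_HASH\"/"
    source = re.sub(r'COMMIT_HASH = ".*"', f'COMMIT_HASH = "{short_sha}"', source)
    path.write_text(source)
    return short_sha

if __name__ == "__main__":
    print(f"SHORT_SHA={stamp_revision()}")
```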
24 changes: 15 additions & 9 deletions src/emd/cfn/codepipeline/template.yaml
@@ -1,5 +1,8 @@
AWSTemplateFormatVersion: '2010-09-09'
Description: CodePipeline for model deployment
Description: |
Easy Model Deployer bootstrap environment.
If you delete this stack, you will not be able to deploy any new models.

Parameters:
ArtifactBucketName:
Type: String
@@ -195,7 +198,7 @@ Resources:
phases:
pre_build:
commands:
- echo Build started on `date`
- echo model build pipeline started on `date`
build:
commands:
- |-
@@ -214,21 +217,24 @@ Resources:
pip install --upgrade pip
pip install -r requirements.txt
python pipeline.py --region $region --model_id $model_id --model_tag $ModelTag --framework_type $FrameworkType --service_type $service --backend_type $backend_name --model_s3_bucket $model_s3_bucket --instance_type $instance_type --extra_params "$extra_params" --skip_deploy
cd ..
echo pipeline build completed on `date`
# cd ..
echo model build pipeline completed on `date`

post_build:
commands:
- |-
echo post build started on `date`
SERVICE_TYPE=$(echo "$ServiceType" | tr '[:upper:]' '[:lower:]')
if [ -f ../cfn/$ServiceType/post_build.py ]; then
# copy post_build.py to pipeline so that the post_build.py can use the same module
cp ../cfn/$ServiceType/post_build.py $ServiceType_post_build.py
python $ServiceType_post_build.py --region $region --model_id $model_id --model_tag $ModelTag --framework_type $FrameworkType --service_type $service --backend_type $backend_name --model_s3_bucket $model_s3_bucket --instance_type $instance_type --extra_params "$extra_params"
fi
cd ..
cp cfn/$ServiceType/template.yaml template.yaml
cp pipeline/parameters.json parameters.json
if [ -f cfn/$ServiceType/post_build.py ]; then
cp cfn/$ServiceType/post_build.py post_build.py
python post_build.py --region $region --model_id $model_id --model_tag $ModelTag --framework_type $FrameworkType --service_type $service --backend_type $backend_name --model_s3_bucket $model_s3_bucket --instance_type $instance_type --extra_params "$extra_params"
fi
cat parameters.json
echo Build completed on `date`
echo post build completed on `date`

artifacts:
files:
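The buildspec hands `--extra_params "$extra_params"` to `pipeline.py` and to the service-specific `post_build.py`; as the helpers removed from `src/emd/cfn/ecs/post_build.py` below show, double quotes in that JSON are replaced with the `<!>` marker so the value survives shell quoting inside CodeBuild. A minimal round-trip sketch of that encoding, assuming the `load_extra_params` now imported from `emd.models.utils.serialize_utils` behaves like the removed local version:

```python
# Sketch of the extra_params encoding used across the pipeline (assumption: the
# imported emd.models.utils.serialize_utils.load_extra_params matches the local
# helpers this PR removes from post_build.py).
import argparse
import json

JSON_DOUBLE_QUOTE_REPLACE = "<!>"  # stands in for '"' while the value travels through the shell

def dump_extra_params(d: dict) -> str:
    """Encode a dict so it can be passed as a single shell argument."""
    return json.dumps(d).replace('"', JSON_DOUBLE_QUOTE_REPLACE)

def load_extra_params(string: str) -> dict:
    """Decode the value received on the command line back into a dict."""
    string = string.replace(JSON_DOUBLE_QUOTE_REPLACE, '"')
    try:
        return json.loads(string)
    except json.JSONDecodeError:
        raise argparse.ArgumentTypeError(f"Invalid dictionary format: {string}")

# Round trip, as the buildspec would pass it to post_build.py:
encoded = dump_extra_params({"service_params": {"vpc_id": "vpc-123"}})
assert load_extra_params(encoded) == {"service_params": {"vpc_id": "vpc-123"}}
```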
197 changes: 102 additions & 95 deletions src/emd/cfn/ecs/post_build.py
@@ -3,160 +3,167 @@
import json
import os
import argparse
from emd.models.utils.serialize_utils import load_extra_params

# Post-build script for ECS: deploys the VPC and the ECS cluster.

CFN_ROOT_PATH = 'cfn'
CFN_ROOT_PATH = "../cfn"
WAIT_SECONDS = 10
# CFN_ROOT_PATH = '../../cfn'
JSON_DOUBLE_QUOTE_REPLACE = '<!>'

def load_extra_params(string):
string = string.replace(JSON_DOUBLE_QUOTE_REPLACE,'"')
try:
return json.loads(string)
except json.JSONDecodeError:
raise argparse.ArgumentTypeError(f"Invalid dictionary format: {string}")

def dump_extra_params(d:dict):
return json.dumps(d).replace('"', JSON_DOUBLE_QUOTE_REPLACE)

def wait_for_stack_completion(client, stack_id, stack_name):
def wait_for_stack_completion(client, stack_name):
while True:
stack_status = client.describe_stacks(StackName=stack_id)['Stacks'][0]['StackStatus']
if stack_status in ['CREATE_COMPLETE', 'UPDATE_COMPLETE']:
response = client.describe_stacks(StackName=stack_name)
stack_status = response["Stacks"][0]["StackStatus"]
while stack_status.endswith("IN_PROGRESS"):
print(
f"Stack {stack_name} is currently {stack_status}. Waiting for completion..."
)
time.sleep(WAIT_SECONDS)
response = client.describe_stacks(StackName=stack_name)
stack_status = response["Stacks"][0]["StackStatus"]

if stack_status in ["CREATE_COMPLETE", "UPDATE_COMPLETE"]:
print(f"Stack {stack_name} deployment complete")
break
elif stack_status in ['CREATE_IN_PROGRESS', 'UPDATE_IN_PROGRESS']:
print(f"Stack {stack_name} is still being deployed...")
time.sleep(WAIT_SECONDS)
else:
raise Exception(f"Stack {stack_name} deployment failed with status {stack_status}")
raise Exception(
f"Post build stage failed. The stack {stack_name} is in an unexpected status: {stack_status}. Please visit the AWS CloudFormation Console to delete the stack."
)


def get_stack_outputs(client, stack_name):
response = client.describe_stacks(StackName=stack_name)
return response['Stacks'][0].get('Outputs', [])
return response["Stacks"][0].get("Outputs", [])


def create_or_update_stack(client, stack_name, template_path, parameters=[]):
try:
wait_for_stack_completion(client, stack_name)
response = client.describe_stacks(StackName=stack_name)
stack_status = response['Stacks'][0]['StackStatus']
if stack_status in ['ROLLBACK_COMPLETE', 'ROLLBACK_FAILED', 'DELETE_FAILED']:
print(f"Stack {stack_name} is in {stack_status} state. Deleting the stack to allow for recreation.")
client.delete_stack(StackName=stack_name)
while True:
try:
response = client.describe_stacks(StackName=stack_name)
stack_status = response['Stacks'][0]['StackStatus']
if stack_status == 'DELETE_IN_PROGRESS':
print(f"Stack {stack_name} is being deleted...")
time.sleep(WAIT_SECONDS)
else:
raise Exception(f"Unexpected status {stack_status} while waiting for stack deletion.")
except client.exceptions.ClientError as e:
if 'does not exist' in str(e):
print(f"Stack {stack_name} successfully deleted.")
break
else:
raise
while stack_status not in ['CREATE_COMPLETE', 'UPDATE_COMPLETE']:
if stack_status in ['CREATE_IN_PROGRESS', 'UPDATE_IN_PROGRESS']:
print(f"Stack {stack_name} is currently {stack_status}. Waiting for it to complete...")
time.sleep(WAIT_SECONDS)
response = client.describe_stacks(StackName=stack_name)
stack_status = response['Stacks'][0]['StackStatus']
else:
raise Exception(f"Stack {stack_name} is in an unexpected state: {stack_status}")
print(f"Stack {stack_name} already exists with status {stack_status}")
stack_status = response["Stacks"][0]["StackStatus"]

if stack_status in ["CREATE_COMPLETE", "UPDATE_COMPLETE"]:
print(f"Stack {stack_name} already exists. Proceeding with update.")
with open(template_path, "r") as template_file:
template_body = template_file.read()

response = client.update_stack(
StackName=stack_name,
TemplateBody=template_body,
Capabilities=["CAPABILITY_NAMED_IAM"],
Parameters=parameters
)

print(f"Started update of stack {stack_name}")
wait_for_stack_completion(client, stack_name)

except client.exceptions.ClientError as e:
if 'does not exist' in str(e):
if "does not exist" in str(e):
print(f"Stack {stack_name} does not exist. Proceeding with creation.")
with open(template_path, 'r') as template_file:
with open(template_path, "r") as template_file:
template_body = template_file.read()

response = client.create_stack(
StackName=stack_name,
TemplateBody=template_body,
Capabilities=['CAPABILITY_NAMED_IAM'],
Parameters=parameters
Capabilities=["CAPABILITY_NAMED_IAM"],
Parameters=parameters,
EnableTerminationProtection=True,
)

stack_id = response['StackId']
stack_id = response["StackId"]
print(f"Started deployment of stack {stack_name} with ID {stack_id}")
wait_for_stack_completion(client, stack_id, stack_name)
wait_for_stack_completion(client, stack_name)
else:
raise
raise Exception(
f"Post build stage failed. The stack {stack_name} is in an unexpected status: {stack_status}. Please visit the AWS CloudFormation Console to delete the stack."
)


def update_parameters_file(parameters_path, updates):
with open(parameters_path, 'r') as file:
with open(parameters_path, "r") as file:
data = json.load(file)

data['Parameters'].update(updates)
data["Parameters"].update(updates)

with open(parameters_path, 'w') as file:
with open(parameters_path, "w") as file:
json.dump(data, file, indent=4)


def deploy_vpc_template(region):
client = boto3.client('cloudformation', region_name=region)
stack_name = 'EMD-VPC'
template_path = f'{CFN_ROOT_PATH}/vpc/template.yaml'
client = boto3.client("cloudformation", region_name=region)
stack_name = "EMD-VPC"
template_path = f"{CFN_ROOT_PATH}/vpc/template.yaml"
create_or_update_stack(client, stack_name, template_path)
outputs = get_stack_outputs(client, stack_name)
vpc_id = None
subnets = None
for output in outputs:
if output['OutputKey'] == 'VPCID':
vpc_id = output['OutputValue']
elif output['OutputKey'] == 'Subnets':
subnets = output['OutputValue']
update_parameters_file('parameters.json', {'VPCID': vpc_id, 'Subnets': subnets})
if output["OutputKey"] == "VPCID":
vpc_id = output["OutputValue"]
elif output["OutputKey"] == "Subnets":
subnets = output["OutputValue"]
update_parameters_file("parameters.json", {"VPCID": vpc_id, "Subnets": subnets})
return vpc_id, subnets


def deploy_ecs_cluster_template(region, vpc_id, subnets):
client = boto3.client('cloudformation', region_name=region)
stack_name = 'EMD-ECS-Cluster'
template_path = f'{CFN_ROOT_PATH}/ecs/cluster.yaml'
create_or_update_stack(client, stack_name, template_path, [
{
'ParameterKey': 'VPCID',
'ParameterValue': vpc_id,
},
{
'ParameterKey': 'Subnets',
'ParameterValue': subnets,
},
])
client = boto3.client("cloudformation", region_name=region)
stack_name = "EMD-ECS-Cluster"
template_path = f"{CFN_ROOT_PATH}/ecs/cluster.yaml"
create_or_update_stack(
client,
stack_name,
template_path,
[
{
"ParameterKey": "VPCID",
"ParameterValue": vpc_id,
},
{
"ParameterKey": "Subnets",
"ParameterValue": subnets,
},
],
)

outputs = get_stack_outputs(client, stack_name)
for output in outputs:
update_parameters_file('parameters.json', {output['OutputKey']: output['OutputValue']})
update_parameters_file(
"parameters.json", {output["OutputKey"]: output["OutputValue"]}
)


def post_build():
parser = argparse.ArgumentParser()
parser.add_argument('--region', type=str, required=False)
parser.add_argument('--model_id', type=str, required=False)
parser.add_argument('--model_tag', type=str, required=False)
parser.add_argument('--framework_type', type=str, required=False)
parser.add_argument('--service_type', type=str, required=False)
parser.add_argument('--backend_type', type=str, required=False)
parser.add_argument('--model_s3_bucket', type=str, required=False)
parser.add_argument('--instance_type', type=str, required=False)
parser.add_argument('--extra_params', type=load_extra_params, required=False, default=os.environ.get("extra_params","{}"))
parser.add_argument("--region", type=str, required=False)
parser.add_argument("--model_id", type=str, required=False)
parser.add_argument("--model_tag", type=str, required=False)
parser.add_argument("--framework_type", type=str, required=False)
parser.add_argument("--service_type", type=str, required=False)
parser.add_argument("--backend_type", type=str, required=False)
parser.add_argument("--model_s3_bucket", type=str, required=False)
parser.add_argument("--instance_type", type=str, required=False)
parser.add_argument(
"--extra_params",
type=load_extra_params,
required=False,
default=os.environ.get("extra_params", "{}"),
)

args = parser.parse_args()

service_params = args.extra_params.get('service_params',{})
service_params = args.extra_params.get("service_params", {})

if 'vpc_id' not in service_params:
if "vpc_id" not in service_params:
vpc_id, subnets = deploy_vpc_template(args.region)
else:
vpc_id = service_params.get('vpc_id')
subnets = service_params.get('subnet_ids')
update_parameters_file('parameters.json', {'VPCID': vpc_id, 'Subnets': subnets})
vpc_id = service_params.get("vpc_id")
subnets = service_params.get("subnet_ids")
update_parameters_file("parameters.json", {"VPCID": vpc_id, "Subnets": subnets})

deploy_ecs_cluster_template(args.region, vpc_id, subnets)


if __name__ == "__main__":
post_build()
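For comparison only, not what this PR implements: boto3 ships built-in CloudFormation waiters (`stack_create_complete`, `stack_update_complete`) that poll `describe_stacks` much like the `WAIT_SECONDS` loop in `wait_for_stack_completion`. A hedged sketch with a hypothetical `wait_with_builtin_waiter` helper:

```python
# Alternative sketch (not what the PR does): lean on boto3's built-in CloudFormation
# waiters instead of the hand-rolled describe_stacks polling loop.
import boto3
from botocore.exceptions import WaiterError

def wait_with_builtin_waiter(region: str, stack_name: str, create: bool = True) -> None:
    client = boto3.client("cloudformation", region_name=region)
    waiter_name = "stack_create_complete" if create else "stack_update_complete"
    waiter = client.get_waiter(waiter_name)
    try:
        # Polls describe_stacks on a fixed delay, much like the WAIT_SECONDS loop above.
        waiter.wait(StackName=stack_name, WaiterConfig={"Delay": 10, "MaxAttempts": 120})
        print(f"Stack {stack_name} deployment complete")
    except WaiterError as exc:
        raise Exception(
            f"Post build stage failed. The stack {stack_name} did not reach a "
            f"complete state: {exc}"
        )
```

One reason to keep the hand-rolled loop is that it prints progress on every poll and covers create and update through a single code path, which the stock waiters do not provide out of the box.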
26 changes: 13 additions & 13 deletions src/emd/commands/deploy.py
@@ -521,19 +521,19 @@ def deploy(
raise typer.Exit(0)

# log the deployment parameters
engine_info = model.find_current_engine(engine_type)
framework_info = model.find_current_framework(framework_type)

engine_info_str = json.dumps(engine_info,indent=2,ensure_ascii=False)
framework_info_str = json.dumps(framework_info, indent=2, ensure_ascii=False)
extra_params_info = json.dumps(extra_params, indent=2, ensure_ascii=False)
console.print(f"[bold blue]Deployment parameters:[/bold blue]")
console.print(f"[bold blue]model_id: {model_id},model_tag: {model_tag}[/bold blue]")
console.print(f"[bold blue]instance_type: {instance_type}[/bold blue]")
console.print(f"[bold blue]service_type: {service_type}[/bold blue]")
console.print(f"[bold blue]engine info:\n {engine_info_str}[/bold blue]")
console.print(f"[bold blue]framework info:\n {framework_info_str}[/bold blue]")
console.print(f"[bold blue]extra_params:\n {extra_params_info}[/bold blue]")
# engine_info = model.find_current_engine(engine_type)
# framework_info = model.find_current_framework(framework_type)

# engine_info_str = json.dumps(engine_info,indent=2,ensure_ascii=False)
# framework_info_str = json.dumps(framework_info, indent=2, ensure_ascii=False)
# extra_params_info = json.dumps(extra_params, indent=2, ensure_ascii=False)
# console.print(f"[bold blue]Deployment parameters:[/bold blue]")
# console.print(f"[bold blue]model_id: {model_id},model_tag: {model_tag}[/bold blue]")
# console.print(f"[bold blue]instance_type: {instance_type}[/bold blue]")
# console.print(f"[bold blue]service_type: {service_type}[/bold blue]")
# console.print(f"[bold blue]engine info:\n {engine_info_str}[/bold blue]")
# console.print(f"[bold blue]framework info:\n {framework_info_str}[/bold blue]")
# console.print(f"[bold blue]extra_params:\n {extra_params_info}[/bold blue]")
# Start pipeline execution
if service_type != ServiceType.LOCAL:
response = sdk_deploy(
10 changes: 3 additions & 7 deletions src/emd/commands/destroy.py
@@ -2,7 +2,7 @@
from rich.console import Console
from rich.panel import Panel

from emd.constants import MODEL_DEFAULT_TAG, VERSION_MODIFY
from emd.constants import MODEL_DEFAULT_TAG
from typing_extensions import Annotated
from emd.sdk.destroy import destroy as sdk_destroy
from emd.utils.decorators import catch_aws_credential_errors,check_emd_env_exist,load_aws_profile
@@ -25,11 +25,7 @@ def destroy(
],
model_tag: Annotated[
str, typer.Argument(help="Model tag")
] = MODEL_DEFAULT_TAG,
model_deploy_version: Annotated[
str, typer.Option("-v", "--deploy-version", help="The version of the model deployment to destroy"),
] = VERSION_MODIFY
] = MODEL_DEFAULT_TAG
):
model_deploy_version = convert_version_name_to_stack_name(model_deploy_version)
# console.print("[bold blue]Checking AWS environment...[/bold blue]")
sdk_destroy(model_id,model_tag=model_tag,waiting_until_complete=True, model_deploy_version=model_deploy_version)
sdk_destroy(model_id,model_tag=model_tag,waiting_until_complete=True)