Skip to content

Commit 9e7c5d6

Browse files
committed
deployment params
1 parent ba8fd78 commit 9e7c5d6

File tree

3 files changed

+91
-11
lines changed

3 files changed

+91
-11
lines changed

ads/aqua/modeldeployment/deployment.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
build_params_string,
2929
build_pydantic_error_message,
3030
find_restricted_params,
31-
get_combined_params,
3231
get_container_env_type,
3332
get_container_params_type,
3433
get_ocid_substring,
@@ -918,10 +917,31 @@ def _create(
918917
# The values provided by user will override the ones provided by default config
919918
env_var = {**config_env, **env_var}
920919

921-
# validate user provided params
922-
user_params = env_var.get("PARAMS", UNKNOWN)
920+
# SMM Parameter Resolution Logic
921+
# Check the raw user input from create_deployment_details to determine intent.
922+
# We cannot use the merged 'env_var' here because it may already contain defaults.
923+
user_input_env = create_deployment_details.env_var or {}
924+
user_input_params = user_input_env.get("PARAMS")
925+
926+
deployment_params = ""
927+
928+
if user_input_params is None:
929+
# Case 1: None (CLI default) -> Load full defaults from config
930+
logger.info("No PARAMS provided (None). Loading default SMM parameters.")
931+
deployment_params = config_params
932+
elif str(user_input_params).strip() == "":
933+
# Case 2: Empty String (UI Clear) -> Explicitly use no parameters
934+
logger.info("Empty PARAMS provided. Clearing all parameters.")
935+
deployment_params = ""
936+
else:
937+
# Case 3: Value Provided -> Use exact user value (No merging)
938+
logger.info(
939+
f"User provided PARAMS. Using exact user values: {user_input_params}"
940+
)
941+
deployment_params = user_input_params
923942

924-
if user_params:
943+
# Validate the resolved parameters
944+
if deployment_params:
925945
# todo: remove this check in the future version, logic to be moved to container_index
926946
if (
927947
container_type_key.lower()
@@ -935,16 +955,14 @@ def _create(
935955
)
936956

937957
restricted_params = find_restricted_params(
938-
params, user_params, container_type_key
958+
params, deployment_params, container_type_key
939959
)
940960
if restricted_params:
941961
raise AquaValueError(
942962
f"Parameters {restricted_params} are set by Aqua "
943963
f"and cannot be overridden or are invalid."
944964
)
945965

946-
deployment_params = get_combined_params(config_params, user_params)
947-
948966
params = f"{params} {deployment_params}".strip()
949967

950968
if isinstance(aqua_model, DataScienceModelGroup):
@@ -1212,7 +1230,7 @@ def _create_deployment(
12121230

12131231
# we arbitrarily choose last 8 characters of OCID to identify MD in telemetry
12141232
deployment_short_ocid = get_ocid_substring(deployment_id, key_len=8)
1215-
1233+
12161234
# Prepare telemetry kwargs
12171235
telemetry_kwargs = {"ocid": deployment_short_ocid}
12181236

@@ -2048,9 +2066,11 @@ def recommend_shape(self, **kwargs) -> Union[Table, ShapeRecommendationReport]:
20482066
self.telemetry.record_event_async(
20492067
category="aqua/deployment",
20502068
action="recommend_shape",
2051-
detail=get_ocid_substring(model_id, key_len=8)
2052-
if is_valid_ocid(ocid=model_id)
2053-
else model_id,
2069+
detail=(
2070+
get_ocid_substring(model_id, key_len=8)
2071+
if is_valid_ocid(ocid=model_id)
2072+
else model_id
2073+
),
20542074
**kwargs,
20552075
)
20562076

models.json

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
[
2+
{
3+
"model_id": "ocid1.datasciencemodel.oc1.iad.amaaaaaam3xyxziacjn3gsjl4mmvesis5pjeu43lj2vyzjxluoffuqm734da",
4+
"gpu_count": 1,
5+
"model_name": "llama3-8b-instruct",
6+
"fine_tune_weights": [
7+
{
8+
"model_id": "ocid1.datasciencemodel.oc1.iad.amaaaaaav66vvniabwlacbrsjukmrwk7mmec5ukpumcuefxmclz6suvygywq",
9+
"model_name": "my-llama-v3.1-8b-instruct-ft"
10+
},
11+
{
12+
"model_id": "ocid1.datasciencemodel.oc1.iad.amaaaaaav66vvniabxhq42ft6ujro4vb5mwa5kelegkj6lle3g6hpaleomeq",
13+
"model_name": "llama-oasst-ft"
14+
}
15+
]
16+
}
17+
]
18+

test_inference.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import json
2+
import requests
3+
import ads
4+
5+
# Set up OCI security token authentication
ads.set_auth("security_token")

# Your Model Deployment OCID and endpoint URL.
md_ocid = "ocid1.datasciencemodeldeploymentint.oc1.iad.amaaaaaav66vvniasakhgqe4hk6eqgci7jmj2nvxjzldaqlnb7ji7vjr5p6a"
# Build the endpoint from the OCID so there is a single source of truth
# (previously the same OCID was duplicated verbatim inside the URL literal).
endpoint = f"https://modeldeployment-int.us-ashburn-1.oci.oc-test.com/{md_ocid}/predict"

# OCI request signer used to sign outbound HTTP requests.
auth = ads.common.auth.default_signer()["signer"]
14+
15+
16+
def predict(model_name):
    """Send one test completion request to the model deployment and print the result.

    Posts a completion payload that selects *model_name* (the base model or a
    fine-tuned weight name from models.json), then prints the HTTP status and
    the JSON response body (falling back to raw text if the body is not JSON).

    Parameters
    ----------
    model_name : str
        Name of the model (or fine-tuned weight) the deployment should route to.
    """
    predict_data = {
        "model": model_name,
        "prompt": "[user] Write a SQL query to answer the question based on the table schema.\n\ncontext: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\nquestion: Name the ICAO for lilongwe international airport [/user] [assistant]",
        "max_tokens": 100,
        "temperature": 0,
    }
    # BUG FIX: header key was "cx", which is not a valid HTTP header name.
    # Because the body is sent via `data=` (a pre-serialized string), requests
    # does NOT set Content-Type automatically, so the server could misparse
    # the JSON payload without this header.
    predict_headers = {
        "Content-Type": "application/json",
        "opc-request-id": "test-id",
    }
    response = requests.post(
        endpoint,
        headers=predict_headers,
        data=json.dumps(predict_data),
        auth=auth,
        verify=False,  # Use verify=True in production!
    )
    print("Status:", response.status_code)
    try:
        print(json.dumps(response.json(), indent=2))
    except Exception as e:
        # Body was not JSON (e.g. an HTML error page) — show it raw instead.
        print("Error parsing JSON:", e)
        print("Response.text:", response.text)
37+
38+
39+
if __name__ == "__main__":
    # Smoke-test inference against one of the deployed fine-tuned weights.
    model = "my-llama-v3.1-8b-instruct-ft"
    print(f"Testing FT model: {model}")
    predict(model)

0 commit comments

Comments
 (0)