diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 5089f0c045..b43f7e0fa0 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -15,6 +15,7 @@ import json import os import re +import urllib import warnings import zipfile from collections.abc import Mapping, Sequence @@ -58,7 +59,7 @@ validate, _ = optional_import("jsonschema", name="validate") ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") -requests_get, has_requests = optional_import("requests", name="get") +requests, has_requests = optional_import("requests") onnx, _ = optional_import("onnx") huggingface_hub, _ = optional_import("huggingface_hub") @@ -206,6 +207,16 @@ def _download_from_monaihosting(download_path: Path, filename: str, version: str extractall(filepath=filepath, output_dir=download_path, has_base=True) +def _download_from_bundle_info(download_path: Path, filename: str, version: str, progress: bool) -> None: + bundle_info = get_bundle_info(bundle_name=filename, version=version) + if not bundle_info: + raise ValueError(f"Bundle info not found for {filename} v{version}.") + url = bundle_info["browser_download_url"] + filepath = download_path / f"{filename}_v{version}.zip" + download_url(url=url, filepath=filepath, hash_val=None, progress=progress) + extractall(filepath=filepath, output_dir=download_path, has_base=True) + + def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str: if name.startswith(prefix): return name @@ -222,7 +233,7 @@ def _get_all_download_files(request_url: str, headers: dict | None = None) -> li if not has_requests: raise ValueError("requests package is required, please install it.") headers = {} if headers is None else headers - response = requests_get(request_url, headers=headers) + response = requests.get(request_url, headers=headers) response.raise_for_status() model_info = json.loads(response.text) @@ -266,7 +277,7 @@ def _download_from_ngc_private( request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo) if has_requests: headers = {} if headers is None else headers - response = requests_get(request_url, headers=headers) + response = requests.get(request_url, headers=headers) response.raise_for_status() else: raise ValueError("NGC API requires requests package. Please install it.") @@ -289,7 +300,7 @@ def _get_ngc_token(api_key, retry=0): url = "https://authn.nvidia.com/token?service=ngc" headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key} if has_requests: - response = requests_get(url, headers=headers) + response = requests.get(url, headers=headers) if not response.ok: # retry 3 times, if failed, raise an error. if retry < 3: @@ -303,14 +314,17 @@ def _get_ngc_token(api_key, retry=0): def _get_latest_bundle_version_monaihosting(name): full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}" - requests_get, has_requests = optional_import("requests", name="get") if has_requests: - resp = requests_get(full_url) - resp.raise_for_status() - else: - raise ValueError("NGC API requires requests package. Please install it.") - model_info = json.loads(resp.text) - return model_info["model"]["latestVersionIdStr"] + resp = requests.get(full_url) + try: + resp.raise_for_status() + model_info = json.loads(resp.text) + return model_info["model"]["latestVersionIdStr"] + except requests.exceptions.HTTPError: + # for monaihosting bundles, if cannot find the version, get from model zoo model_info.json + return get_bundle_versions(name)["latest_version"] + + raise ValueError("NGC API requires requests package. Please install it.") def _examine_monai_version(monai_version: str) -> tuple[bool, str]: @@ -388,14 +402,14 @@ def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements if headers: version_header.update(headers) - resp = requests_get(version_endpoint, headers=version_header) + resp = requests.get(version_endpoint, headers=version_header) resp.raise_for_status() model_info = json.loads(resp.text) latest_versions = _list_latest_versions(model_info) for version in latest_versions: file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json" - resp = requests_get(file_endpoint, headers=headers) + resp = requests.get(file_endpoint, headers=headers) metadata = json.loads(resp.text) resp.raise_for_status() # if the package version is not available or the model is compatible with the package version @@ -585,7 +599,16 @@ def download( name_ver = "_v".join([name_, version_]) if version_ is not None else name_ _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_) elif source_ == "monaihosting": - _download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_) + try: + _download_from_monaihosting( + download_path=bundle_dir_, filename=name_, version=version_, progress=progress_ + ) + except urllib.error.HTTPError: + # for monaihosting bundles, if cannot download from default host, download according to bundle_info + _download_from_bundle_info( + download_path=bundle_dir_, filename=name_, version=version_, progress=progress_ + ) + elif source_ == "ngc": _download_from_ngc( download_path=bundle_dir_, @@ -792,9 +815,9 @@ def _get_all_bundles_info( if auth_token is not None: headers = {"Authorization": f"Bearer {auth_token}"} - resp = requests_get(request_url, headers=headers) + resp = requests.get(request_url, headers=headers) else: - resp = requests_get(request_url) + resp = requests.get(request_url) resp.raise_for_status() else: raise ValueError("requests package is required, please install it.")