From 51fc9fb9e97366d1197bab24bb09f641b6112639 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 24 Feb 2025 14:30:31 +0800 Subject: [PATCH 1/3] update monai hosting download Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 5089f0c045..bd7dc825ef 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -15,6 +15,7 @@ import json import os import re +import urllib import warnings import zipfile from collections.abc import Mapping, Sequence @@ -25,6 +26,7 @@ from textwrap import dedent from typing import Any, Callable +import requests import torch from torch.cuda import is_available @@ -206,6 +208,16 @@ def _download_from_monaihosting(download_path: Path, filename: str, version: str extractall(filepath=filepath, output_dir=download_path, has_base=True) +def _download_from_bundle_info(download_path: Path, filename: str, version: str, progress: bool) -> None: + bundle_info = get_bundle_info(name=filename, version=version) + if not bundle_info: + raise ValueError(f"Bundle info not found for {filename} v{version}.") + url = bundle_info["source"] + filepath = download_path / f"{filename}_v{version}.zip" + download_url(url=url, filepath=filepath, hash_val=None, progress=progress) + extractall(filepath=filepath, output_dir=download_path, has_base=True) + + def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str: if name.startswith(prefix): return name @@ -307,10 +319,10 @@ def _get_latest_bundle_version_monaihosting(name): if has_requests: resp = requests_get(full_url) resp.raise_for_status() - else: - raise ValueError("NGC API requires requests package. Please install it.") - model_info = json.loads(resp.text) - return model_info["model"]["latestVersionIdStr"] + model_info = json.loads(resp.text) + return model_info["model"]["latestVersionIdStr"] + + raise ValueError("NGC API requires requests package. Please install it.") def _examine_monai_version(monai_version: str) -> tuple[bool, str]: @@ -416,7 +428,11 @@ def _get_latest_bundle_version( name = _add_ngc_prefix(name) return _get_latest_bundle_version_ngc(name) elif source == "monaihosting": - return _get_latest_bundle_version_monaihosting(name) + try: + return _get_latest_bundle_version_monaihosting(name) + except requests.exceptions.HTTPError: + # for monaihosting bundles, if cannot find the version, get from model zoo model_info.json + return get_bundle_versions(name)["latest_version"] elif source == "ngc_private": headers = kwargs.pop("headers", {}) name = _add_ngc_prefix(name) @@ -585,7 +601,16 @@ def download( name_ver = "_v".join([name_, version_]) if version_ is not None else name_ _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_) elif source_ == "monaihosting": - _download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_) + try: + _download_from_monaihosting( + download_path=bundle_dir_, filename=name_, version=version_, progress=progress_ + ) + except urllib.error.HTTPError: + # for monaihosting bundles, if cannot download from default host, download according to bundle_info + _download_from_bundle_info( + download_path=bundle_dir_, filename=name_, version=version_, progress=progress_ + ) + elif source_ == "ngc": _download_from_ngc( download_path=bundle_dir_, From 4a4a73846ea4dd2166d9bfb33f2ab1253c7c0566 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 24 Feb 2025 14:43:16 +0800 Subject: [PATCH 2/3] solve errors Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index bd7dc825ef..eede02bb11 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -209,10 +209,10 @@ def _download_from_monaihosting(download_path: Path, filename: str, version: str def _download_from_bundle_info(download_path: Path, filename: str, version: str, progress: bool) -> None: - bundle_info = get_bundle_info(name=filename, version=version) + bundle_info = get_bundle_info(bundle_name=filename, version=version) if not bundle_info: raise ValueError(f"Bundle info not found for {filename} v{version}.") - url = bundle_info["source"] + url = bundle_info["browser_download_url"] filepath = download_path / f"{filename}_v{version}.zip" download_url(url=url, filepath=filepath, hash_val=None, progress=progress) extractall(filepath=filepath, output_dir=download_path, has_base=True) From 1d922535db31998e97020e8d1b7eced96b7204e9 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 24 Feb 2025 14:57:52 +0800 Subject: [PATCH 3/3] fix mypy error Signed-off-by: Yiheng Wang --- monai/bundle/scripts.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index eede02bb11..b43f7e0fa0 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -26,7 +26,6 @@ from textwrap import dedent from typing import Any, Callable -import requests import torch from torch.cuda import is_available @@ -60,7 +59,7 @@ validate, _ = optional_import("jsonschema", name="validate") ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") -requests_get, has_requests = optional_import("requests", name="get") +requests, has_requests = optional_import("requests") onnx, _ = optional_import("onnx") huggingface_hub, _ = optional_import("huggingface_hub") @@ -234,7 +233,7 @@ def _get_all_download_files(request_url: str, headers: dict | None = None) -> li if not has_requests: raise ValueError("requests package is required, please install it.") headers = {} if headers is None else headers - response = requests_get(request_url, headers=headers) + response = requests.get(request_url, headers=headers) response.raise_for_status() model_info = json.loads(response.text) @@ -278,7 +277,7 @@ def _download_from_ngc_private( request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo) if has_requests: headers = {} if headers is None else headers - response = requests_get(request_url, headers=headers) + response = requests.get(request_url, headers=headers) response.raise_for_status() else: raise ValueError("NGC API requires requests package. Please install it.") @@ -301,7 +300,7 @@ def _get_ngc_token(api_key, retry=0): url = "https://authn.nvidia.com/token?service=ngc" headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key} if has_requests: - response = requests_get(url, headers=headers) + response = requests.get(url, headers=headers) if not response.ok: # retry 3 times, if failed, raise an error. if retry < 3: @@ -315,12 +314,15 @@ def _get_ngc_token(api_key, retry=0): def _get_latest_bundle_version_monaihosting(name): full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}" - requests_get, has_requests = optional_import("requests", name="get") if has_requests: - resp = requests_get(full_url) - resp.raise_for_status() - model_info = json.loads(resp.text) - return model_info["model"]["latestVersionIdStr"] + resp = requests.get(full_url) + try: + resp.raise_for_status() + model_info = json.loads(resp.text) + return model_info["model"]["latestVersionIdStr"] + except requests.exceptions.HTTPError: + # for monaihosting bundles, if cannot find the version, get from model zoo model_info.json + return get_bundle_versions(name)["latest_version"] raise ValueError("NGC API requires requests package. Please install it.") @@ -400,14 +402,14 @@ def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements if headers: version_header.update(headers) - resp = requests_get(version_endpoint, headers=version_header) + resp = requests.get(version_endpoint, headers=version_header) resp.raise_for_status() model_info = json.loads(resp.text) latest_versions = _list_latest_versions(model_info) for version in latest_versions: file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json" - resp = requests_get(file_endpoint, headers=headers) + resp = requests.get(file_endpoint, headers=headers) metadata = json.loads(resp.text) resp.raise_for_status() # if the package version is not available or the model is compatible with the package version @@ -428,11 +430,7 @@ def _get_latest_bundle_version( name = _add_ngc_prefix(name) return _get_latest_bundle_version_ngc(name) elif source == "monaihosting": - try: - return _get_latest_bundle_version_monaihosting(name) - except requests.exceptions.HTTPError: - # for monaihosting bundles, if cannot find the version, get from model zoo model_info.json - return get_bundle_versions(name)["latest_version"] + return _get_latest_bundle_version_monaihosting(name) elif source == "ngc_private": headers = kwargs.pop("headers", {}) name = _add_ngc_prefix(name) @@ -817,9 +815,9 @@ def _get_all_bundles_info( if auth_token is not None: headers = {"Authorization": f"Bearer {auth_token}"} - resp = requests_get(request_url, headers=headers) + resp = requests.get(request_url, headers=headers) else: - resp = requests_get(request_url) + resp = requests.get(request_url) resp.raise_for_status() else: raise ValueError("requests package is required, please install it.")