|
7 | 7 | import time |
8 | 8 | from .progressbar import ProgressBar |
9 | 9 | import yaml |
10 | | -from .data import Model |
11 | 10 | from .config import * |
12 | 11 | import traceback |
13 | 12 | from pprint import pformat |
14 | 13 |
|
def download_model(arg_tuple):
    """Download the GitHub archive of one ModelDB model into ``MODELS_ZIP_DIR``.

    The zip is fetched from the ``ModelDBRepository/<model_id>`` repository on
    GitHub, using the branch named by the optional ``"github"`` entry of the
    run info (``"master"`` when the entry is absent), and is stored as
    ``<MODELS_ZIP_DIR>/<model_id>.zip``.

    Args:
        arg_tuple: ``(model_id, model_run_info)`` packed into one argument.
            ``model_run_info`` is the mapping of run instructions for the
            model; only its ``"github"`` key is read here.

    Returns:
        ``(model_id, github_url)`` on success, or ``(model_id, exception)``
        if anything failed. Errors are returned instead of raised so the
        caller can collect the failures of all models.
    """
    model_id, model_run_info = arg_tuple
    try:
        model_zip_uri = os.path.join(
            MODELS_ZIP_DIR, "{model_id}.zip".format(model_id=model_id)
        )

        suffix = model_run_info.get("github", "master")
        github_url = "https://github.com/ModelDBRepository/{model_id}/archive/refs/heads/{suffix}.zip".format(
            model_id=model_id, suffix=suffix
        )

        logging.info("Downloading model {} from {}".format(model_id, github_url))
        # A streamed response keeps its connection checked out until it is
        # fully consumed or closed, so close it deterministically via `with`.
        # The (connect, read) timeout keeps a stalled server from hanging the
        # worker forever; the read timeout is per socket read, not per file.
        with requests.get(github_url, stream=True, timeout=(10, 120)) as response:
            if response.status_code != 200:
                raise Exception("Failed to download model: {}".format(response.text))
            try:
                with open(model_zip_uri, "wb") as f:
                    for chunk in response.iter_content(chunk_size=64 * 1024):
                        if chunk:
                            f.write(chunk)
            except Exception:
                # Do not leave a truncated archive behind: a later step could
                # mistake it for a complete download.
                try:
                    if os.path.exists(model_zip_uri):
                        os.remove(model_zip_uri)
                except OSError:
                    pass
                raise
        logging.info("Downloaded model {} to {}".format(model_id, model_zip_uri))
        result = github_url
    except Exception as e:  # noqa
        # Boundary handler: hand the error back to the caller as a value.
        result = e

    return model_id, result
74 | 40 |
|
75 | 41 |
|
76 | 42 | class ModelDB(object): |
@@ -105,11 +71,15 @@ def _download_models(self, model_list=None): |
105 | 71 | [(model_id, self._run_instr.get(model_id, {})) for model_id in models], |
106 | 72 | ) |
107 | 73 | download_err = {} |
108 | | - for model_id, model in ProgressBar.iter(processed_models, len(models)): |
109 | | - if not isinstance(model, Exception): |
110 | | - self._metadata[model_id] = model |
| 74 | + for model_id, model_url in ProgressBar.iter(processed_models, len(models)): |
| 75 | + |
| 76 | + if not isinstance(model_url, Exception): |
| 77 | + model_meta = {} |
| 78 | + model_meta["id"] = model_id |
| 79 | + model_meta["url"] = model_url |
| 80 | + self._metadata[model_id] = model_meta |
111 | 81 | else: |
112 | | - download_err[model_id] = model |
| 82 | + download_err[model_id] = model_url |
113 | 83 |
|
114 | 84 | if download_err: |
115 | 85 | logging.error("Error downloading models:") |
|
0 commit comments