diff --git a/tests/utils/kagglehub.py b/tests/utils/kagglehub.py index d7819dde..7a2a8995 100644 --- a/tests/utils/kagglehub.py +++ b/tests/utils/kagglehub.py @@ -1,3 +1,4 @@ +import json import os import threading import re @@ -7,6 +8,8 @@ from test.support.os_helper import EnvironmentVarGuard from http.server import BaseHTTPRequestHandler, HTTPServer +from kagglesdk.kaggle_env import get_endpoint, get_env + class KaggleAPIHandler(BaseHTTPRequestHandler): """ Fake Kaggle API server supporting the download endpoint. @@ -15,15 +18,18 @@ class KaggleAPIHandler(BaseHTTPRequestHandler): def do_HEAD(self): self.send_response(200) - def do_GET(self): - m = re.match("^/api/v1/models/(.+)/download/(.+)$", self.path) - if not m: + def do_POST(self): + content_length = int(self.headers.get('Content-Length', 0)) + body_bytes = self.rfile.read(content_length) + request_body = json.loads(body_bytes.decode('utf-8')) + + if self.path != "/api/v1/models.ModelApiService/DownloadModelInstanceVersion": self.send_response(404) self.wfile.write(bytes(f"Unhandled path: {self.path}", "utf-8")) return - model_handle = m.group(1) - path = m.group(2) + model_handle = f"{request_body["ownerSlug"]}/{request_body["modelSlug"]}/keras/{request_body["instanceSlug"]}/{request_body["versionNumber"]}" + path = request_body["path"] filepath = f"/input/tests/data/kagglehub/models/{model_handle}/{path}" if not os.path.isfile(filepath): self.send_error(404, "Internet is disabled in our tests " @@ -41,14 +47,12 @@ def do_GET(self): @contextmanager def create_test_kagglehub_server(): - endpoint = 'http://localhost:7777' env = EnvironmentVarGuard() - env.set('KAGGLE_API_ENDPOINT', endpoint) - test_server_address = urlparse(endpoint) + env.set('KAGGLE_API_ENVIRONMENT', 'TEST') with env: - if not test_server_address.hostname or not test_server_address.port: - msg = f"Invalid test server address: {endpoint}. You must specify a hostname & port" - raise ValueError(msg) + endpoint = get_endpoint(get_env()) + test_server_address = urlparse(endpoint) + with HTTPServer((test_server_address.hostname, test_server_address.port), KaggleAPIHandler) as httpd: threading.Thread(target=httpd.serve_forever).start()