Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions tests/utils/kagglehub.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import threading
import re
Expand All @@ -7,6 +8,8 @@
from test.support.os_helper import EnvironmentVarGuard
from http.server import BaseHTTPRequestHandler, HTTPServer

from kagglesdk.kaggle_env import get_endpoint, get_env

class KaggleAPIHandler(BaseHTTPRequestHandler):
"""
Fake Kaggle API server supporting the download endpoint.
Expand All @@ -15,15 +18,23 @@ class KaggleAPIHandler(BaseHTTPRequestHandler):
def do_HEAD(self):
self.send_response(200)

def do_GET(self):
m = re.match("^/api/v1/models/(.+)/download/(.+)$", self.path)
if not m:
def do_POST(self):
# 1. Get the content length from the headers
content_length = int(self.headers.get('Content-Length', 0))

# 2. Read the specified number of bytes from the input file (rfile)
body_bytes = self.rfile.read(content_length)

# 3. Decode the bytes to a string
request_body = json.loads(body_bytes.decode('utf-8'))

if self.path != "/api/v1/models.ModelApiService/DownloadModelInstanceVersion":
self.send_response(404)
self.wfile.write(bytes(f"Unhandled path: {self.path}", "utf-8"))
return

model_handle = m.group(1)
path = m.group(2)
model_handle = f"{request_body["ownerSlug"]}/{request_body["modelSlug"]}/keras/{request_body["instanceSlug"]}/{request_body["versionNumber"]}"
path = request_body["path"]
filepath = f"/input/tests/data/kagglehub/models/{model_handle}/{path}"
if not os.path.isfile(filepath):
self.send_error(404, "Internet is disabled in our tests "
Expand All @@ -41,14 +52,12 @@ def do_GET(self):

@contextmanager
def create_test_kagglehub_server():
endpoint = 'http://localhost:7777'
env = EnvironmentVarGuard()
env.set('KAGGLE_API_ENDPOINT', endpoint)
test_server_address = urlparse(endpoint)
env.set('KAGGLE_API_ENVIRONMENT', 'TEST')
with env:
if not test_server_address.hostname or not test_server_address.port:
msg = f"Invalid test server address: {endpoint}. You must specify a hostname & port"
raise ValueError(msg)
endpoint = get_endpoint(get_env())
test_server_address = urlparse(endpoint)

with HTTPServer((test_server_address.hostname, test_server_address.port), KaggleAPIHandler) as httpd:
threading.Thread(target=httpd.serve_forever).start()

Expand Down