Skip to content

Commit 4baa127

Browse files
committed
Fix checksum flag and tests
1 parent 4b1bdf4 commit 4baa127

File tree

3 files changed

+75
-7
lines changed

3 files changed

+75
-7
lines changed

openml/_api_calls.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ def __read_url(
344344
url=url,
345345
data=data,
346346
md5_checksum=md5_checksum,
347+
check_digest=config.check_digest,
347348
)
348349

349350

@@ -356,12 +357,14 @@ def __is_checksum_equal(downloaded_file_binary: bytes, md5_checksum: str | None
356357
return md5_checksum == md5_checksum_download
357358

358359

359-
def _send_request( # noqa: C901, PLR0912
360+
def _send_request( # noqa: C901, PLR0912, PLR0913
360361
request_method: str,
361362
url: str,
362363
data: DATA_TYPE,
364+
*,
363365
files: FILE_ELEMENTS_TYPE | None = None,
364366
md5_checksum: str | None = None,
367+
check_digest: bool = True,
365368
) -> requests.Response:
366369
n_retries = max(1, config.connection_n_retries)
367370

@@ -386,16 +389,18 @@ def _send_request( # noqa: C901, PLR0912
386389

387390
__check_response(response=response, url=url, file_elements=files)
388391

389-
if request_method == "get" and not __is_checksum_equal(
390-
response.text.encode("utf-8"), md5_checksum
392+
if (
393+
request_method == "get"
394+
and check_digest
395+
and not __is_checksum_equal(response.text.encode("utf-8"), md5_checksum)
391396
):
392397
# -- Check if encoding is not UTF-8 perhaps
393398
if __is_checksum_equal(response.content, md5_checksum):
394399
raise OpenMLHashException(
395-
f"Checksum of downloaded file is unequal to the expected checksum"
400+
"Checksum of downloaded file is unequal to the expected checksum "
396401
f"{md5_checksum} because the text encoding is not UTF-8 when "
397-
f"downloading {url}. There might be a sever-sided issue with the file, "
398-
"see: https://github.com/openml/openml-python/issues/1180.",
402+
f"downloading {url}. There might be a server-sided issue with the "
403+
"file, see issue #1180."
399404
)
400405

401406
raise OpenMLHashException(

openml/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class _Config(TypedDict):
3434
retry_policy: Literal["human", "robot"]
3535
connection_n_retries: int
3636
show_progress: bool
37+
check_digest: bool
3738

3839

3940
def _create_log_handlers(create_file_handler: bool = True) -> None: # noqa: FBT001, FBT002
@@ -154,6 +155,7 @@ def _resolve_default_cache_dir() -> Path:
154155
"retry_policy": "human",
155156
"connection_n_retries": 5,
156157
"show_progress": False,
158+
"check_digest": True, # Whether to check the md5 checksum of downloaded files
157159
}
158160

159161
# Default values are actually added here in the _setup() function which is
@@ -183,6 +185,7 @@ def get_server_base_url() -> str:
183185

184186
retry_policy: Literal["human", "robot"] = _defaults["retry_policy"]
185187
connection_n_retries: int = _defaults["connection_n_retries"]
188+
check_digest: bool = _defaults["check_digest"]
186189

187190

188191
def set_retry_policy(value: Literal["human", "robot"], n_retries: int | None = None) -> None:
@@ -340,6 +343,7 @@ def _setup(config: _Config | None = None) -> None:
340343
global _root_cache_directory # noqa: PLW0603
341344
global avoid_duplicate_runs # noqa: PLW0603
342345
global show_progress # noqa: PLW0603
346+
global check_digest # noqa: PLW0603
343347

344348
config_file = determine_config_file_path()
345349
config_dir = config_file.parent
@@ -361,6 +365,7 @@ def _setup(config: _Config | None = None) -> None:
361365
apikey = config["apikey"]
362366
server = config["server"]
363367
show_progress = config["show_progress"]
368+
check_digest = config["check_digest"]
364369
n_retries = int(config["connection_n_retries"])
365370

366371
set_retry_policy(config["retry_policy"], n_retries)
@@ -427,7 +432,7 @@ def _parse_config(config_file: str | Path) -> _Config:
427432
config_file_.seek(0)
428433
config.read_file(config_file_)
429434
configuration = dict(config.items("FAKE_SECTION"))
430-
for boolean_field in ["avoid_duplicate_runs", "show_progress"]:
435+
for boolean_field in ["avoid_duplicate_runs", "show_progress", "check_digest"]:
431436
if isinstance(config["FAKE_SECTION"][boolean_field], str):
432437
configuration[boolean_field] = config["FAKE_SECTION"].getboolean(boolean_field) # type: ignore
433438
return configuration # type: ignore
@@ -442,6 +447,7 @@ def get_config_as_dict() -> _Config:
442447
"connection_n_retries": connection_n_retries,
443448
"retry_policy": retry_policy,
444449
"show_progress": show_progress,
450+
"check_digest": check_digest,
445451
}
446452

447453

tests/test_utils/test_checksum.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pytest
2+
import openml
3+
from unittest.mock import patch, Mock
4+
from openml.exceptions import OpenMLHashException
5+
6+
7+
def _mock_response():
8+
mock_response = Mock()
9+
mock_response.text = "hello"
10+
mock_response.content = b"hello"
11+
mock_response.status_code = 200
12+
mock_response.headers = {"Content-Encoding": "gzip"} # Required by headers check
13+
return mock_response
14+
15+
16+
@patch("requests.Session")
17+
def test_checksum_match(Session_class_mock):
18+
Session_class_mock.return_value.__enter__.return_value.get.return_value = _mock_response()
19+
20+
with openml.config.overwrite_config_context({"connection_n_retries": 1}): # to avoid retry delays
21+
openml._api_calls._send_request(
22+
request_method="get",
23+
url="/dummy",
24+
data={},
25+
md5_checksum="5d41402abc4b2a76b9719d911017c592",
26+
)
27+
28+
29+
@patch("requests.Session")
30+
def test_checksum_mismatch(Session_class_mock):
31+
Session_class_mock.return_value.__enter__.return_value.get.return_value = _mock_response()
32+
33+
with openml.config.overwrite_config_context({"connection_n_retries": 1}): # to avoid retry delays
34+
with pytest.raises(OpenMLHashException):
35+
openml._api_calls._send_request(
36+
request_method="get",
37+
url="/dummy",
38+
data={},
39+
md5_checksum="00000000000000000000000000000000",
40+
)
41+
42+
43+
@patch("requests.Session")
44+
def test_checksum_skipped_when_flag_off(Session_class_mock):
45+
Session_class_mock.return_value.__enter__.return_value.get.return_value = _mock_response()
46+
47+
with openml.config.overwrite_config_context({
48+
"check_digest": False,
49+
"connection_n_retries": 1, # to avoid retry delays
50+
}):
51+
# should NOT raise even though checksum mismatches
52+
openml._api_calls._send_request(
53+
request_method="get",
54+
url="/dummy",
55+
data={},
56+
md5_checksum="not-a-real-sum",
57+
)

0 commit comments

Comments
 (0)