Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions openml/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
import unittest
from pathlib import Path
from typing import ClassVar
from typing_extensions import Literal

import requests

import openml
from openml.exceptions import OpenMLServerException
from openml.tasks import TaskType

# Type alias for entity types that can be tracked
EntityType = Literal["run", "data", "flow", "task", "study", "user"]


def _check_dataset(dataset: dict) -> None:
assert isinstance(dataset, dict)
Expand All @@ -37,8 +41,7 @@ class TestBase(unittest.TestCase):
Hopefully soon allows using a test server, not the production server.
"""

# TODO: This could be made more explcit with a TypedDict instead of list[str | int]
publish_tracker: ClassVar[dict[str, list[str | int]]] = {
publish_tracker: ClassVar[dict[EntityType, list[int]]] = {
"run": [],
"data": [],
"flow": [],
Expand Down Expand Up @@ -133,7 +136,7 @@ def tearDown(self) -> None:
@classmethod
def _mark_entity_for_removal(
cls,
entity_type: str,
entity_type: EntityType,
entity_id: int,
entity_name: str | None = None,
) -> None:
Expand All @@ -153,7 +156,7 @@ def _mark_entity_for_removal(
cls.flow_name_tracker.append(entity_name)

@classmethod
def _delete_entity_from_tracker(cls, entity_type: str, entity: int) -> None:
def _delete_entity_from_tracker(cls, entity_type: EntityType, entity: int) -> None:
"""Deletes entity records from the static file_tracker

Given an entity type and corresponding ID, deletes all entries, including
Expand Down