Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions examples/Advanced/tasks_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
#
# We will start by simply listing only *supervised classification* tasks.
#
# **openml.tasks.list_tasks()** returns a dictionary of dictionaries by default, but we
# request a
# **openml.list_all("task")** (or **openml.tasks.list_tasks()**) returns a dictionary of
# dictionaries by default, but we request a
# [pandas dataframe](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html)
# instead to have better visualization capabilities and easier access:

# %%
tasks = openml.tasks.list_tasks(task_type=TaskType.SUPERVISED_CLASSIFICATION)
tasks = openml.list_all("task", task_type=TaskType.SUPERVISED_CLASSIFICATION)
# Legacy path still works:
# tasks = openml.tasks.list_tasks(task_type=TaskType.SUPERVISED_CLASSIFICATION)
print(tasks.columns)
print(f"First 5 of {len(tasks)} tasks:")
print(tasks.head())
Expand Down Expand Up @@ -66,23 +68,29 @@
# Similar to listing tasks by task type, we can list tasks by tags:

# %%
tasks = openml.tasks.list_tasks(tag="OpenML100")
tasks = openml.list_all("task", tag="OpenML100")
# Legacy path still works:
# tasks = openml.tasks.list_tasks(tag="OpenML100")
print(f"First 5 of {len(tasks)} tasks:")
print(tasks.head())

# %% [markdown]
# Furthermore, we can list tasks based on the dataset id:

# %%
tasks = openml.tasks.list_tasks(data_id=1471)
tasks = openml.list_all("task", data_id=1471)
# Legacy path still works:
# tasks = openml.tasks.list_tasks(data_id=1471)
print(f"First 5 of {len(tasks)} tasks:")
print(tasks.head())

# %% [markdown]
# In addition, a size limit and an offset can be applied both separately and simultaneously:

# %%
tasks = openml.tasks.list_tasks(size=10, offset=50)
tasks = openml.list_all("task", size=10, offset=50)
# Legacy path still works:
# tasks = openml.tasks.list_tasks(size=10, offset=50)
print(tasks)

# %% [markdown]
Expand All @@ -98,7 +106,9 @@
# Finally, it is also possible to list all tasks on OpenML with:

# %%
tasks = openml.tasks.list_tasks()
tasks = openml.list_all("task")
# Legacy path still works:
# tasks = openml.tasks.list_tasks()
print(len(tasks))

# %% [markdown]
Expand All @@ -118,7 +128,9 @@

# %%
task_id = 31
task = openml.tasks.get_task(task_id)
task = openml.get(task_id, object_type="task")
# Legacy path still works:
# task = openml.tasks.get_task(task_id)

# %%
# Properties of the task are stored as member variables:
Expand Down
12 changes: 10 additions & 2 deletions examples/Basics/simple_datasets_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,23 @@
# ## List datasets stored on OpenML

# %%
datasets_df = openml.datasets.list_datasets()
datasets_df = openml.list_all("dataset")
print(datasets_df.head(n=10))

# Legacy path still works:
# datasets_df = openml.datasets.list_datasets()

# %% [markdown]
# ## Download a dataset

# %%
# Iris dataset https://www.openml.org/d/61
dataset = openml.datasets.get_dataset(dataset_id=61)
dataset = openml.get(61)
# You can also fetch by name:
# dataset = openml.get("Fashion-MNIST")

# Legacy path still works:
# dataset = openml.datasets.get_dataset(dataset_id=61)

# Print a summary
print(
Expand Down
15 changes: 14 additions & 1 deletion examples/Basics/simple_flows_and_runs_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,25 @@
# %%
openml.config.start_using_configuration_for_example()

# %% [markdown]
# ## Quick: list flows and runs via unified entrypoints

# %%
flows_df = openml.list_all("flow", size=3)
print(flows_df.head())

runs_df = openml.list_all("run", size=3)
print(runs_df.head())

# %% [markdown]
# ## Train a machine learning model and evaluate it
# NOTE: We are using task 119 from the test server: https://test.openml.org/d/20

# %%
task = openml.tasks.get_task(119)
task = openml.get(119, object_type="task")

# Legacy path still works:
# task = openml.tasks.get_task(119)

# Get the data
dataset = task.get_dataset()
Expand Down
5 changes: 4 additions & 1 deletion examples/Basics/simple_tasks_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
# [supervised classification on credit-g](https://www.openml.org/search?type=task&id=31&source_data.data_id=31):

# %%
task = openml.tasks.get_task(31)
task = openml.get(31, object_type="task")

# Legacy path still works:
# task = openml.tasks.get_task(31)

# %% [markdown]
# Get the dataset and its data from the task.
Expand Down
100 changes: 100 additions & 0 deletions openml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
# License: BSD 3-Clause
from __future__ import annotations

from typing import Any, Callable, Dict

from . import (
_api_calls,
config,
Expand Down Expand Up @@ -49,6 +51,102 @@
OpenMLTask,
)

# Type aliases for the dispatch tables: object-type name -> callable.
ListDispatcher = Dict[str, Callable[..., Any]]
GetDispatcher = Dict[str, Callable[..., Any]]

# NOTE: entries resolve their target function lazily (at call time) via a
# lambda instead of storing a direct reference. Storing e.g.
# ``datasets.functions.list_datasets`` would freeze the binding at import
# time, so later monkeypatching of ``openml.datasets.functions.list_datasets``
# (as test suites do with ``mock.patch``) would never be observed by
# ``list_all`` / ``get``. The table values are still plain callables, so the
# calling code is unchanged.
_LIST_DISPATCH: ListDispatcher = {
    "dataset": lambda **kwargs: datasets.functions.list_datasets(**kwargs),
    "task": lambda **kwargs: tasks.functions.list_tasks(**kwargs),
    "flow": lambda **kwargs: flows.functions.list_flows(**kwargs),
    "run": lambda **kwargs: runs.functions.list_runs(**kwargs),
}

_GET_DISPATCH: GetDispatcher = {
    "dataset": lambda identifier, **kwargs: datasets.functions.get_dataset(identifier, **kwargs),
    "task": lambda identifier, **kwargs: tasks.functions.get_task(identifier, **kwargs),
    "flow": lambda identifier, **kwargs: flows.functions.get_flow(identifier, **kwargs),
    "run": lambda identifier, **kwargs: runs.functions.get_run(identifier, **kwargs),
}


def list_all(object_type: str, /, **kwargs: Any) -> Any:
    """List OpenML objects by type (e.g., datasets, tasks, flows, runs).

    A convenience dispatcher that forwards to the existing type-specific
    ``list_*`` functions. The original per-module imports remain available
    for backward compatibility.

    Parameters
    ----------
    object_type : str
        The type of object to list. Must be one of 'dataset', 'task', 'flow', 'run'.
    **kwargs : Any
        Additional arguments passed to the underlying list function.

    Returns
    -------
    Any
        The result from the type-specific list function (typically a DataFrame).

    Raises
    ------
    TypeError
        If object_type is not a string.
    ValueError
        If object_type is not one of the supported types.
    """
    # Guard: reject non-string types explicitly before any lookup.
    if not isinstance(object_type, str):
        raise TypeError(f"object_type must be a string, got {type(object_type).__name__}")

    # Case-insensitive lookup against the registered object types.
    key = object_type.lower()
    if key not in _LIST_DISPATCH:
        valid_types = ", ".join(repr(k) for k in _LIST_DISPATCH)
        raise ValueError(
            f"Unsupported object_type {object_type!r}; expected one of {valid_types}.",
        )

    list_fn = _LIST_DISPATCH[key]
    return list_fn(**kwargs)


def get(identifier: int | str, *, object_type: str = "dataset", **kwargs: Any) -> Any:
    """Get an OpenML object by identifier.

    Parameters
    ----------
    identifier : int | str
        The ID or name of the object to retrieve.
    object_type : str, default="dataset"
        The type of object to get. Must be one of 'dataset', 'task', 'flow', 'run'.
    **kwargs : Any
        Additional arguments passed to the underlying get function.

    Returns
    -------
    Any
        The requested OpenML object.

    Raises
    ------
    TypeError
        If object_type is not a string.
    ValueError
        If object_type is not one of the supported types.

    Examples
    --------
    >>> openml.get(61)  # Get dataset 61 (default object_type="dataset")
    >>> openml.get("Fashion-MNIST")  # Get dataset by name
    >>> openml.get(31, object_type="task")  # Get task 31
    >>> openml.get(10, object_type="flow")  # Get flow 10
    >>> openml.get(20, object_type="run")  # Get run 20
    """
    # Guard: reject non-string types explicitly before any lookup.
    if not isinstance(object_type, str):
        raise TypeError(f"object_type must be a string, got {type(object_type).__name__}")

    # Case-insensitive lookup against the registered object types.
    key = object_type.lower()
    if key not in _GET_DISPATCH:
        valid_types = ", ".join(repr(k) for k in _GET_DISPATCH)
        raise ValueError(
            f"Unsupported object_type {object_type!r}; expected one of {valid_types}.",
        )

    get_fn = _GET_DISPATCH[key]
    return get_fn(identifier, **kwargs)


def populate_cache(
task_ids: list[int] | None = None,
Expand Down Expand Up @@ -120,4 +218,6 @@ def populate_cache(
"utils",
"_api_calls",
"__version__",
"get",
"list_all",
]
24 changes: 24 additions & 0 deletions tests/test_openml/test_openml.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,27 @@ def test_populate_cache(
assert task_mock.call_count == 2
for argument, fixture in zip(task_mock.call_args_list, [(1,), (2,)]):
assert argument[0] == fixture

def test_list_dispatch(self):
    """``openml.list_all`` forwards kwargs to the registered list function.

    Fixes two defects in the previous version: it called the nonexistent
    ``openml.list`` (the public name is ``list_all``), and it patched
    ``openml.*.functions.list_*`` module attributes, which the dispatch
    table — populated once at import time — never re-reads. Patching the
    dispatch table itself makes the mocks actually observable.
    """
    list_datasets_mock = mock.MagicMock()
    list_tasks_mock = mock.MagicMock()
    with mock.patch.dict(
        openml._LIST_DISPATCH,
        {"dataset": list_datasets_mock, "task": list_tasks_mock},
    ):
        openml.list_all("dataset", output_format="dataframe")
        list_datasets_mock.assert_called_once_with(output_format="dataframe")

        openml.list_all("task", size=5)
        list_tasks_mock.assert_called_once_with(size=5)

def test_get_dispatch(self):
    """``openml.get`` forwards the identifier to the registered getter.

    Fixes two defects in the previous version: it passed the object type as
    the first positional argument (``openml.get("dataset", 61)``), but the
    signature is ``get(identifier, *, object_type="dataset")`` with
    ``object_type`` keyword-only; and it patched module attributes that the
    import-time-bound dispatch table never re-reads — patch the table itself.
    """
    get_dataset_mock = mock.MagicMock()
    get_task_mock = mock.MagicMock()
    with mock.patch.dict(
        openml._GET_DISPATCH,
        {"dataset": get_dataset_mock, "task": get_task_mock},
    ):
        openml.get(61)  # object_type defaults to "dataset"
        get_dataset_mock.assert_called_with(61)

        openml.get("Fashion-MNIST", version=2)
        get_dataset_mock.assert_called_with("Fashion-MNIST", version=2)

        openml.get("Fashion-MNIST")
        get_dataset_mock.assert_called_with("Fashion-MNIST")

        openml.get(31, object_type="task")
        get_task_mock.assert_called_with(31)