Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ releases are available on [conda-forge](https://anaconda.org/conda-forge/dags).

## 0.5.0

- :gh:`67` Change `dict` annotations to `Mapping`; do not require string annotations
from users (:ghuser:`hmgaudecker`).

- :gh:`66` Improve linting and development setup (:ghuser:`hmgaudecker`).

- :gh:`62` Drop Python 3.10 support, improve typing thanks to requiring current networkx
(:ghuser:`hmgaudecker`).

Expand Down
607 changes: 304 additions & 303 deletions pixi.lock

Large diffs are not rendered by default.

77 changes: 11 additions & 66 deletions src/dags/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from typing import TYPE_CHECKING, Any, Literal, overload

from dags.exceptions import NonStringAnnotationError

if TYPE_CHECKING:
from collections.abc import Callable

Expand Down Expand Up @@ -99,77 +97,24 @@ def get_annotations(
return {arg: annotations.get(arg, default) for arg in ["return", *free_arguments]}


def verify_annotations_are_strings(
annotations: dict[str, str], function_name: str
) -> None:
"""Verify that all type annotations are strings.
def ensure_annotations_are_strings(annotations: dict[str, Any]) -> dict[str, str]:
"""Ensure all type annotations are strings, converting if necessary.

Raises NonStringAnnotationError with a helpful message if any annotation
is not a string, suggesting the use of `from __future__ import annotations`.
In Python 3.14+, annotations may be evaluated at runtime rather than stored
as strings. This function converts any non-string annotations to their string
representation.

Args:
annotations: Dictionary of annotation names to their values.
function_name: Name of the function, used in error messages.

Raises:
------
NonStringAnnotationError: If any annotation value is not a string.
Returns:
-------
Dictionary with all annotation values as strings.

"""
# If all annotations are strings, we are done.
if all(isinstance(v, str) for v in annotations.values()):
return

non_string_annotations = [
k for k, v in annotations.items() if not isinstance(v, str)
]
arg_annotations = {k: v for k, v in annotations.items() if k != "return"}
return_annotation = annotations["return"]

# Create a representation of the signature with string annotations
# ----------------------------------------------------------------------------------
stringified_arg_annotations = []
for k, v in arg_annotations.items():
if k in non_string_annotations:
stringified_arg_annotations.append(f"{k}: '{_get_str_repr(v)}'")
else:
annot = f"{k}: '{v}'"
stringified_arg_annotations.append(annot)

if "return" in non_string_annotations:
stringified_return_annotation = f"'{_get_str_repr(return_annotation)}'"
else:
stringified_return_annotation = f"'{return_annotation}'"

stringified_signature = (
f"{function_name}({', '.join(stringified_arg_annotations)}) -> "
f"{stringified_return_annotation}"
)

# Create message on which argument and/or return annotation is invalid
# ----------------------------------------------------------------------------------
invalid_arg_annotations = [k for k in non_string_annotations if k != "return"]
if invalid_arg_annotations:
s = "s" if len(invalid_arg_annotations) > 1 else ""
invalid_arg_msg = f"argument{s} ({', '.join(invalid_arg_annotations)})"
else:
invalid_arg_msg = ""

invalid_annotations_msg = ""
if invalid_arg_msg and "return" in non_string_annotations:
invalid_annotations_msg = f"{invalid_arg_msg} and the return value"
elif invalid_arg_msg:
invalid_annotations_msg = invalid_arg_msg
elif "return" in non_string_annotations:
invalid_annotations_msg = "return value"

raise NonStringAnnotationError(
f"All function annotations must be strings. The annotations for the "
f"{invalid_annotations_msg} are not strings.\nA simple way for Python to treat "
"type annotations as strings is to add\n\n\tfrom __future__ import annotations"
"\n\nat the top of your file. Alternatively, you can do it manually by "
f"enclosing the annotations in quotes:\n\n\t{stringified_signature}."
)
return {
k: v if isinstance(v, str) else _get_str_repr(v) for k, v in annotations.items()
}


def _get_str_repr(obj: object) -> str:
Expand Down
42 changes: 19 additions & 23 deletions src/dags/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
import functools
import inspect
import warnings
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, cast

import networkx as nx

from dags.annotations import (
ensure_annotations_are_strings,
get_annotations,
get_free_arguments,
verify_annotations_are_strings,
)
from dags.exceptions import (
AnnotationMismatchError,
Expand Down Expand Up @@ -68,15 +69,13 @@ class FunctionExecutionInfo:
func: Callable[..., Any]
verify_annotations: bool = False

def __post_init__(self) -> None:
"""Verify that the annotations are strings."""
if self.verify_annotations:
verify_annotations_are_strings(self.annotations, self.name)

@functools.cached_property
def annotations(self) -> dict[str, str]:
"""The annotations of the function."""
return get_annotations(self.func)
raw_annotations = get_annotations(self.func)
if self.verify_annotations:
return ensure_annotations_are_strings(raw_annotations)
return raw_annotations

@property
def arguments(self) -> list[str]:
Expand All @@ -95,7 +94,7 @@ def return_annotation(self) -> str:


def concatenate_functions( # noqa: PLR0913
functions: dict[str, Callable[..., Any]] | list[Callable[..., Any]],
functions: Mapping[str, Callable[..., Any]] | list[Callable[..., Any]],
targets: str | list[str] | None = None,
*,
dag: nx.DiGraph[str] | None = None,
Expand Down Expand Up @@ -183,7 +182,7 @@ def concatenate_functions( # noqa: PLR0913


def create_dag(
functions: dict[str, Callable[..., Any]] | list[Callable[..., Any]],
functions: Mapping[str, Callable[..., Any]] | list[Callable[..., Any]],
targets: str | list[str] | None,
) -> nx.DiGraph[str]:
"""Build a directed acyclic graph (DAG) from functions.
Expand Down Expand Up @@ -224,7 +223,7 @@ def create_dag(

def _create_combined_function_from_dag( # noqa: PLR0913
dag: nx.DiGraph[str],
functions: dict[str, Callable[..., Any]] | list[Callable[..., Any]],
functions: Mapping[str, Callable[..., Any]] | list[Callable[..., Any]],
targets: str | list[str] | None,
return_type: Literal["tuple", "list", "dict"] = "tuple",
aggregator: Callable[[T, T], T] | None = None,
Expand Down Expand Up @@ -359,7 +358,7 @@ def _create_combined_function_from_dag( # noqa: PLR0913


def get_ancestors(
functions: dict[str, Callable[..., Any]] | list[Callable[..., Any]],
functions: Mapping[str, Callable[..., Any]] | list[Callable[..., Any]],
targets: str | list[str] | None,
*,
include_targets: bool = False,
Expand Down Expand Up @@ -389,14 +388,14 @@ def get_ancestors(

ancestors: set[str] = set()
for target in _targets:
ancestors |= nx.ancestors(dag, target)
ancestors |= nx.ancestors(dag, target) # type: ignore[invalid-argument-type]
if include_targets:
ancestors.add(target)
return ancestors


def harmonize_and_check_functions_and_targets(
functions: dict[str, Callable[..., Any]] | list[Callable[..., Any]],
functions: Mapping[str, Callable[..., Any]] | list[Callable[..., Any]],
targets: str | list[str] | None,
) -> tuple[dict[str, Callable[..., Any]], list[str]]:
"""Harmonize the type of specified functions and targets and do some checks.
Expand All @@ -423,14 +422,11 @@ def harmonize_and_check_functions_and_targets(


def _harmonize_functions(
functions: dict[str, Callable[..., Any]] | list[Callable[..., Any]],
functions: Mapping[str, Callable[..., Any]] | list[Callable[..., Any]],
) -> dict[str, Callable[..., Any]]:
if not isinstance(functions, dict):
functions_dict = {func.__name__: func for func in functions} # ty: ignore[unresolved-attribute]
else:
functions_dict = functions

return functions_dict
if isinstance(functions, Mapping):
return {k: v for k, v in functions.items()} # noqa: C416 # ty: ignore[invalid-return-type]
return {func.__name__: func for func in functions} # ty: ignore[unresolved-attribute]


def _harmonize_targets(
Expand Down Expand Up @@ -466,7 +462,7 @@ def _fail_if_functions_are_missing(

def _fail_if_dag_contains_cycle(dag: nx.DiGraph[str]) -> None:
"""Check for cycles in DAG."""
cycles = list(nx.simple_cycles(dag))
cycles = list(nx.simple_cycles(dag)) # type: ignore[invalid-argument-type]

if len(cycles) > 0:
formatted = format_list_linewise(cycles)
Expand Down Expand Up @@ -513,7 +509,7 @@ def _limit_dag_to_targets_and_their_ancestors(
"""
used_nodes = set(targets)
for target in targets:
used_nodes = used_nodes | set(nx.ancestors(dag, target))
used_nodes = used_nodes | set(nx.ancestors(dag, target)) # type: ignore[invalid-argument-type]

all_nodes = set(dag.nodes)

Expand Down Expand Up @@ -573,7 +569,7 @@ def create_execution_info(

"""
out = {}
for node in nx.lexicographical_topological_sort(dag, key=lexsort_key):
for node in nx.lexicographical_topological_sort(dag, key=lexsort_key): # type: ignore[invalid-argument-type]
if node in functions:
out[node] = FunctionExecutionInfo(
name=node,
Expand Down
29 changes: 16 additions & 13 deletions src/dags/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import functools
import inspect
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast, overload

from dags.annotations import get_annotations
Expand Down Expand Up @@ -76,8 +77,8 @@ def _create_annotations(
def with_signature(
func: Callable[P, R],
*,
args: dict[str, str] | list[str] | None = None,
kwargs: dict[str, str] | list[str] | None = None,
args: Mapping[str, str] | list[str] | None = None,
kwargs: Mapping[str, str] | list[str] | None = None,
enforce: bool = True,
return_annotation: Any = inspect.Parameter.empty,
) -> Callable[P, R]: ...
Expand All @@ -86,8 +87,8 @@ def with_signature(
@overload
def with_signature(
*,
args: dict[str, str] | list[str] | None = None,
kwargs: dict[str, str] | list[str] | None = None,
args: Mapping[str, str] | list[str] | None = None,
kwargs: Mapping[str, str] | list[str] | None = None,
enforce: bool = True,
return_annotation: Any = inspect.Parameter.empty,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
Expand All @@ -96,8 +97,8 @@ def with_signature(
def with_signature(
func: Callable[P, R] | None = None,
*,
args: dict[str, str] | list[str] | None = None,
kwargs: dict[str, str] | list[str] | None = None,
args: Mapping[str, str] | list[str] | None = None,
kwargs: Mapping[str, str] | list[str] | None = None,
enforce: bool = True,
return_annotation: Any = inspect.Parameter.empty,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
Expand Down Expand Up @@ -190,18 +191,18 @@ def _fail_if_invalid_keyword_arguments(
def rename_arguments(
func: Callable[P, R],
*,
mapper: dict[str, str],
mapper: Mapping[str, str],
) -> Callable[..., R]: ...


@overload
def rename_arguments(
*, mapper: dict[str, str]
*, mapper: Mapping[str, str]
) -> Callable[[Callable[P, R]], Callable[..., R]]: ...


def rename_arguments( # noqa: C901
func: Callable[P, R] | None = None, *, mapper: dict[str, str] | None = None
func: Callable[P, R] | None = None, *, mapper: Mapping[str, str] | None = None
) -> Callable[..., R] | Callable[[Callable[P, R]], Callable[..., R]]:
"""Rename positional and keyword arguments of func.

Expand Down Expand Up @@ -275,12 +276,14 @@ def wrapper_rename_arguments(*args: P.args, **kwargs: P.kwargs) -> R:


def _map_names_to_types(
arg: dict[str, str] | list[str] | None,
arg: Mapping[str, str] | list[str] | None,
) -> dict[str, str] | dict[str, type[inspect._empty]]:
if arg is None:
return {}
if isinstance(arg, list):
return dict.fromkeys(arg, inspect.Parameter.empty)
if isinstance(arg, dict):
return arg
raise DagsError(f"Invalid type for arg: {type(arg)}. Expected dict, list, or None.")
if isinstance(arg, Mapping):
return dict(arg)
raise DagsError(
f"Invalid type for arg: {type(arg)}. Expected Mapping, list, or None."
)
Loading