Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion flytekit/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""
This module provides functionality related to testing
This module provides functionality related to testing.

Provides utilities for mocking tasks and workflows, pytest fixtures for common
test setup patterns, and helpers for local workflow execution.
"""

from flytekit.core.context_manager import SecretsManager
from flytekit.core.testing import patch, task_mock
from flytekit.testing.fixtures import flyte_cache, flyte_context, flyte_tmp_dir, workflow_dry_run
10 changes: 10 additions & 0 deletions flytekit/testing/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Pytest plugin that auto-registers flytekit testing fixtures.

When ``flytekit`` is installed, these fixtures are automatically available in any
pytest session without needing to import them explicitly. This works via the
``pytest11`` entry point registered in ``pyproject.toml``.
"""

from flytekit.testing.fixtures import flyte_cache, flyte_context, flyte_tmp_dir

__all__ = ["flyte_cache", "flyte_context", "flyte_tmp_dir"]
96 changes: 96 additions & 0 deletions flytekit/testing/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import tempfile
import typing
from contextlib import contextmanager
from pathlib import Path

import pytest

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.local_cache import LocalTaskCache


@pytest.fixture
def flyte_context() -> FlyteContext:
"""Provide the current FlyteContext for testing.

This eliminates the need to manually call ``FlyteContextManager.current_context()``
in every test that needs a context for type transformations, file access, or other
context-dependent operations.

Usage::

def test_type_transform(flyte_context):
from flytekit.core.type_engine import TypeEngine
lt = TypeEngine.to_literal_type(int)
lv = TypeEngine.to_literal(flyte_context, 42, int, lt)
assert lv.scalar.primitive.integer == 42
"""
return FlyteContextManager.current_context()


@pytest.fixture
def flyte_cache():
"""Initialize and clear the local task cache before and after each test.

Prevents stale cached results from prior test runs from leaking into the current test.
This addresses a common pain point where ``cache=True`` on tasks causes flaky tests
because the on-disk cache (``~/.flyte/local-cache``) persists between test runs.

See https://github.com/flyteorg/flyte/issues/5657

Usage::

def test_cached_task(flyte_cache):
@task(cache=True, cache_version="v1")
def add(a: int, b: int) -> int:
return a + b

assert add(a=1, b=2) == 3
# Cache is automatically cleared after the test
"""
LocalTaskCache.initialize()
LocalTaskCache.clear()
yield
LocalTaskCache.clear()


@pytest.fixture
def flyte_tmp_dir() -> typing.Generator[Path, None, None]:
"""Provide a temporary directory that is cleaned up after the test.

Useful for tests involving ``FlyteFile``, ``FlyteDirectory``, or any operation
that needs to write files to disk.

Usage::

def test_file_output(flyte_tmp_dir):
output_path = flyte_tmp_dir / "result.txt"
output_path.write_text("hello")
assert output_path.read_text() == "hello"
"""
with tempfile.TemporaryDirectory() as td:
yield Path(td)


@contextmanager
def workflow_dry_run() -> typing.Generator[None, None, None]:
"""Context manager that sets up a clean local execution environment.

Initializes and clears the local cache, then cleans up after the block completes.
Useful for running a workflow locally in tests without worrying about cached state.

Usage::

from flytekit.testing.fixtures import workflow_dry_run

def test_my_workflow():
with workflow_dry_run():
result = my_workflow(x=1, y=2)
assert result == 3
"""
LocalTaskCache.initialize()
LocalTaskCache.clear()
try:
yield
finally:
LocalTaskCache.clear()
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ pyflyte-map-execute = "flytekit.bin.entrypoint:map_execute_task_cmd"
pyflyte = "flytekit.clis.sdk_in_container.pyflyte:main"
flyte-cli = "flytekit.clis.flyte_cli.main:_flyte_cli"

[project.entry-points.pytest11]
flytekit = "flytekit.testing.conftest"

[tool.setuptools_scm]
write_to = "flytekit/_version.py"

Expand Down
Empty file.
113 changes: 113 additions & 0 deletions tests/flytekit/unit/testing/test_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import pytest

from flytekit import task, workflow
from flytekit.core.context_manager import FlyteContext
from flytekit.core.local_cache import LocalTaskCache
from flytekit.core.type_engine import TypeEngine
from flytekit.testing.fixtures import flyte_cache, flyte_context, flyte_tmp_dir, workflow_dry_run


class TestFlyteContextFixture:
def test_returns_flyte_context(self, flyte_context):
assert isinstance(flyte_context, FlyteContext)

def test_context_has_file_access(self, flyte_context):
assert flyte_context.file_access is not None

def test_type_transform_with_context(self, flyte_context):
lt = TypeEngine.to_literal_type(int)
lv = TypeEngine.to_literal(flyte_context, 42, int, lt)
assert lv.scalar.primitive.integer == 42


class TestFlyteCacheFixture:
def test_cache_is_cleared(self, flyte_cache):
assert LocalTaskCache._initialized is True

def test_cached_task_works(self, flyte_cache):
call_count = 0

@task(cache=True, cache_version="test-v1")
def add(a: int, b: int) -> int:
nonlocal call_count
call_count += 1
return a + b

result1 = add(a=1, b=2)
result2 = add(a=1, b=2)
assert result1 == 3
assert result2 == 3
assert call_count == 1 # second call should hit cache

def test_cache_isolated_between_tests_a(self, flyte_cache):
"""First test in a pair that verifies cache isolation."""

@task(cache=True, cache_version="isolation-v1")
def multiply(a: int, b: int) -> int:
return a * b

assert multiply(a=3, b=4) == 12

def test_cache_isolated_between_tests_b(self, flyte_cache):
"""Second test verifying the cache was cleared between tests."""
call_count = 0

@task(cache=True, cache_version="isolation-v1")
def multiply(a: int, b: int) -> int:
nonlocal call_count
call_count += 1
return a * b

multiply(a=3, b=4)
assert call_count == 1 # should NOT hit cache from previous test


class TestFlyteTmpDirFixture:
def test_provides_path(self, flyte_tmp_dir):
from pathlib import Path

assert isinstance(flyte_tmp_dir, Path)
assert flyte_tmp_dir.exists()
assert flyte_tmp_dir.is_dir()

def test_can_write_files(self, flyte_tmp_dir):
test_file = flyte_tmp_dir / "test.txt"
test_file.write_text("hello flytekit")
assert test_file.read_text() == "hello flytekit"

def test_can_create_subdirectories(self, flyte_tmp_dir):
sub = flyte_tmp_dir / "subdir"
sub.mkdir()
assert sub.exists()


class TestWorkflowDryRun:
def test_basic_workflow(self):
@task
def add_one(x: int) -> int:
return x + 1

@workflow
def simple_wf(x: int) -> int:
return add_one(x=x)

with workflow_dry_run():
result = simple_wf(x=5)
assert result == 6

def test_cached_workflow(self):
call_count = 0

@task(cache=True, cache_version="dry-run-v1")
def square(x: int) -> int:
nonlocal call_count
call_count += 1
return x * x

@workflow
def square_wf(x: int) -> int:
return square(x=x)

with workflow_dry_run():
result = square_wf(x=4)
assert result == 16