# Source code for flytekit.core.testing

import typing
from contextlib import contextmanager
from typing import Union
from unittest.mock import MagicMock

from flytekit.core.base_task import PythonTask
from flytekit.core.reference_entity import ReferenceEntity
from flytekit.core.workflow import WorkflowBase
from flytekit.loggers import logger


@contextmanager
def task_mock(t: PythonTask) -> typing.Generator[MagicMock, None, None]:
    """
    Use this method to mock a task declaration. It can mock any Task in Flytekit as long as it has a python native
    interface associated with it.

    The returned object is a MagicMock and allows to perform all such methods. This MagicMock mocks the ``execute``
    method on the PythonTask.

    Usage:

    .. code-block:: python

        @task
        def t1(i: int) -> int:
            pass

        with task_mock(t1) as m:
            m.side_effect = lambda x: x
            t1(10)
            # The mock is valid only within this context

    :param t: The entity whose ``execute`` method is replaced for the duration of the ``with`` block. Must be a
        PythonTask, WorkflowBase or ReferenceEntity.
    :raises ValueError: If ``t`` is none of the accepted entity types.
    :returns: A generator yielding the MagicMock that every call to ``t.execute`` is forwarded to. The original
        ``execute`` is restored when the ``with`` block exits, whether normally or via an exception.
    """
    if not isinstance(t, (PythonTask, WorkflowBase, ReferenceEntity)):
        raise ValueError(f"Can only be used for tasks, but got {type(t)}")

    m = MagicMock()

    def _log(*args, **kwargs):
        # Emit a warning on each mocked invocation, then delegate to the MagicMock so that
        # return_value / side_effect / call recording all behave as configured by the caller.
        logger.warning(f"Invoking mock method for task: '{t.name}'")
        return m(*args, **kwargs)

    _captured_fn = t.execute
    t.execute = _log  # type: ignore
    try:
        yield m
    finally:
        # Restore unconditionally: if the with-body raises (e.g. a failing assertion), the mock
        # must not remain installed on the shared task object for subsequent tests.
        t.execute = _captured_fn  # type: ignore
def patch(target: Union[PythonTask, WorkflowBase, ReferenceEntity]):
    """
    This is a decorator used for testing.

    The decorated test function receives a fresh MagicMock as its first positional argument. While the test runs,
    that MagicMock replaces ``target.execute``. The original ``execute`` is restored after the test returns or raises.

    Usage:

    .. code-block:: python

        @patch(t1)
        def test_something(mock_t1):
            mock_t1.return_value = 42
            ...

    :param target: The PythonTask, WorkflowBase or ReferenceEntity whose ``execute`` method should be mocked.
    :raises ValueError: If ``target`` is none of the accepted entity types.
    :returns: A decorator that wraps a test function as described above.
    """
    if not isinstance(target, (PythonTask, WorkflowBase, ReferenceEntity)):
        raise ValueError(f"Can only use mocks on tasks/workflows declared in Python, but got {type(target)}")

    # Trailing spaces on each fragment matter: adjacent string literals are concatenated verbatim.
    logger.info(
        "When using this patch function on Flyte entities, please be aware weird issues may arise if also "
        "using mock.patch on internal Flyte classes like PythonFunctionWorkflow. See "
        "https://github.com/flyteorg/flyte/issues/854 for more information"
    )

    def wrapper(test_fn):
        def new_test(*args, **kwargs):
            logger.warning(f"Invoking mock method for target: '{target.name}'")
            m = MagicMock()
            saved = target.execute
            target.execute = m
            try:
                # The mock is injected as the first positional argument, ahead of any caller-supplied args.
                return test_fn(m, *args, **kwargs)
            finally:
                # Restore even when the test fails/raises so the mock does not leak into other tests.
                target.execute = saved

        return new_test

    return wrapper