Source code for flytekit.core.testing
import typing
from contextlib import contextmanager
from typing import Union
from unittest.mock import MagicMock
from flytekit.core.base_task import PythonTask
from flytekit.core.reference_entity import ReferenceEntity
from flytekit.core.workflow import WorkflowBase
from flytekit.loggers import logger
[docs]
@contextmanager
def task_mock(t: PythonTask) -> typing.Generator[MagicMock, None, None]:
"""
Use this method to mock a task declaration. It can mock any Task in Flytekit as long as it has a python native
interface associated with it.
The returned object is a MagicMock and allows to perform all such methods. This MagicMock, mocks the execute method
on the PythonTask
Usage:
.. code-block:: python
@task
def t1(i: int) -> int:
pass
with task_mock(t1) as m:
m.side_effect = lambda x: x
t1(10)
# The mock is valid only within this context
"""
if not isinstance(t, PythonTask) and not isinstance(t, WorkflowBase) and not isinstance(t, ReferenceEntity):
raise ValueError(f"Can only be used for tasks, but got {type(t)}")
m = MagicMock()
def _log(*args, **kwargs):
logger.warning(f"Invoking mock method for task: '{t.name}'")
return m(*args, **kwargs)
_captured_fn = t.execute
t.execute = _log # type: ignore
yield m
t.execute = _captured_fn # type: ignore
[docs]
def patch(target: Union[PythonTask, WorkflowBase, ReferenceEntity]):
"""
This is a decorator used for testing.
"""
if (
not isinstance(target, PythonTask)
and not isinstance(target, WorkflowBase)
and not isinstance(target, ReferenceEntity)
):
raise ValueError(f"Can only use mocks on tasks/workflows declared in Python, but got {type(target)}")
logger.info(
"When using this patch function on Flyte entities, please be aware weird issues may arise if also"
"using mock.patch on internal Flyte classes like PythonFunctionWorkflow. See"
"https://github.com/flyteorg/flyte/issues/854 for more information"
)
def wrapper(test_fn):
def new_test(*args, **kwargs):
logger.warning(f"Invoking mock method for target: '{target.name}'")
m = MagicMock()
saved = target.execute
target.execute = m
results = test_fn(m, *args, **kwargs)
target.execute = saved
return results
return new_test
return wrapper