Source code for flytekitplugins.pandera.pandas_transformer
import typing
from typing import TYPE_CHECKING, Type, Union
from flytekit import Deck, FlyteContext, lazy_module
from flytekit.extend import TypeEngine, TypeTransformer
from flytekit.loggers import logger
from flytekit.models.literals import Literal
from flytekit.models.types import LiteralType, SchemaType
from flytekit.types.structured import StructuredDataset
from flytekit.types.structured.structured_dataset import StructuredDatasetTransformerEngine, get_supported_types
from .config import ValidationConfig
from .pandas_renderer import PandasReportRenderer
# Static type checkers see real imports of pandas/pandera so the annotations
# below resolve; at runtime the same names are bound via flytekit's
# ``lazy_module`` instead.
# NOTE(review): presumably lazy_module defers the actual import until first
# attribute access -- confirm against flytekit's documentation.
if TYPE_CHECKING:
    import pandas
    import pandera
else:
    pandas = lazy_module("pandas")
    pandera = lazy_module("pandera")

# Generic type variable used by ``assert_type`` to tie the declared type to the value.
T = typing.TypeVar("T")
[docs]
class PanderaPandasTransformer(TypeTransformer[pandera.typing.DataFrame]):
    """Type transformer that validates pandas data against a pandera schema.

    Serialization and deserialization are delegated to a
    ``StructuredDatasetTransformerEngine``; this class adds pandera validation
    and publishes an HTML validation report as a ``Deck`` on both the
    python -> literal and literal -> python paths.
    """

    _SUPPORTED_TYPES: typing.Dict[type, SchemaType.SchemaColumn.SchemaColumnType] = get_supported_types()
    # (structured dataset uri, schema name) pairs recorded by ``to_literal``.
    # Class-level, so the memo is shared by every instance in the process.
    _VALIDATION_MEMO = set()

    def __init__(self):
        super().__init__("Pandera Transformer", pandera.typing.DataFrame)  # type: ignore
        self._sd_transformer = StructuredDatasetTransformerEngine()

    def _get_pandera_schema(self, t: Type[pandera.typing.DataFrame]):
        """Resolve the pandera schema and validation config for type ``t``.

        Args:
            t: A ``pandera.typing.DataFrame`` type, optionally parametrized
                with a schema model and optionally wrapped in
                ``typing.Annotated``.

        Returns:
            A ``(schema, config)`` tuple. ``config`` is the first
            ``ValidationConfig`` found in the ``Annotated`` metadata, or a
            default ``ValidationConfig()`` if there is none. ``schema`` is
            ``schema_model.to_schema()`` for the first type argument, or an
            empty ``pandera.DataFrameSchema()`` when ``t`` is unparametrized.
        """
        config = ValidationConfig()
        if typing.get_origin(t) is typing.Annotated:
            t, *metadata = typing.get_args(t)
            # Only the first ValidationConfig in the metadata is honoured.
            for item in metadata:
                if isinstance(item, ValidationConfig):
                    config = item
                    break

        type_args = typing.get_args(t)
        if type_args:
            schema_model, *_ = type_args
            schema = schema_model.to_schema()
        else:
            schema = pandera.DataFrameSchema()  # type: ignore
        return schema, config

    @staticmethod
    def _get_pandas_type(pandera_dtype: pandera.dtypes.DataType):
        """Return ``pandera_dtype.type.type``.

        NOTE(review): presumably the concrete scalar type wrapped by the
        pandera dtype -- confirm against pandera's dtype API.
        """
        return pandera_dtype.type.type

    def _get_col_dtypes(self, t: Type[pandera.typing.DataFrame]):
        """Map each column name in the schema of ``t`` to its ``_get_pandas_type`` result."""
        schema, _ = self._get_pandera_schema(t)
        return {name: self._get_pandas_type(column.dtype) for name, column in schema.columns.items()}

    def _validate_and_render(self, df: pandas.DataFrame, schema, config: ValidationConfig):
        """Validate ``df`` against ``schema`` and publish the report as a Deck.

        Shared by ``to_literal`` and ``to_python_value`` so both directions
        apply the same ``on_error`` policy.

        Args:
            df: The dataframe to validate.
            schema: The pandera schema; ``schema.validate(df, lazy=True)`` is called.
            config: Validation config whose ``on_error`` must be ``"raise"`` or ``"warn"``.

        Returns:
            The object returned by ``schema.validate`` on success, or ``df``
            unchanged when validation fails and ``on_error`` is ``"warn"``.

        Raises:
            pandera.errors.SchemaError, pandera.errors.SchemaErrors: when
                validation fails and ``on_error`` is ``"raise"``.
            ValueError: when validation fails and ``on_error`` is neither
                ``"raise"`` nor ``"warn"``.
        """
        renderer = PandasReportRenderer(title=f"Pandera Report: {schema.name}")
        try:
            validated = schema.validate(df, lazy=True)
        except (pandera.errors.SchemaError, pandera.errors.SchemaErrors) as exc:
            # Publish the failure report before acting on the error policy, so
            # the deck exists even when the error is re-raised. Any other
            # exception propagates untouched rather than being masked by an
            # attempt to render a report that was never produced.
            Deck(renderer._title, renderer.to_html(df, schema, exc))
            if config.on_error == "raise":
                raise
            if config.on_error == "warn":
                logger.warning(str(exc))
                return df
            raise ValueError(f"Invalid on_error value: {config.on_error}") from exc

        Deck(renderer._title, renderer.to_html(validated, schema))
        return validated

    def get_literal_type(self, t: Type[pandera.typing.DataFrame]) -> LiteralType:
        """Return the literal type the structured-dataset transformer assigns to ``t``.

        ``typing.Annotated`` wrappers are stripped first; any number of
        metadata arguments is accepted.
        """
        if typing.get_origin(t) is typing.Annotated:
            # Star-unpack: Annotated may carry more than one metadata argument.
            t, *_ = typing.get_args(t)
        return self._sd_transformer.get_literal_type(t)

    def assert_type(self, t: Type[T], v: T):
        """Raise ``TypeError`` if ``v`` is not acceptable for ``t``.

        The check is skipped when ``t`` has an ``__origin__`` attribute
        (e.g. parametrized generics); otherwise ``v`` must be an instance of
        ``t`` or of ``pandas.DataFrame``.
        """
        if hasattr(t, "__origin__"):
            return
        if not isinstance(v, (t, pandas.DataFrame)):
            raise TypeError(f"Type of Val '{v}' is not an instance of {t}")

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: Union[pandas.DataFrame, StructuredDataset],
        python_type: Type[pandera.typing.DataFrame],
        expected: LiteralType,
    ) -> Literal:
        """Validate ``python_val`` and convert it to a structured-dataset literal.

        Args:
            ctx: The current Flyte context.
            python_val: A pandas dataframe or a ``StructuredDataset``.
            python_type: The pandera dataframe type the value is declared as.
            expected: The expected literal type.

        Returns:
            The literal produced by the structured-dataset transformer.

        Raises:
            AssertionError: if ``python_val`` is neither a
                ``pandas.DataFrame`` nor a ``StructuredDataset``.
            pandera.errors.SchemaError, pandera.errors.SchemaErrors,
            ValueError: see ``_validate_and_render``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not isinstance(python_val, (pandas.DataFrame, StructuredDataset)):
            raise AssertionError(
                f"Only Pandas Dataframe object can be returned from a task, returned object type {type(python_val)}"
            )

        if isinstance(python_val, StructuredDataset):
            # Round-trip through a literal to obtain a pandas dataframe to validate.
            lv = self._sd_transformer.to_literal(ctx, python_val, pandas.DataFrame, expected)
            python_val = self._sd_transformer.to_python_value(ctx, lv, pandas.DataFrame)

        schema, config = self._get_pandera_schema(python_type)
        val = self._validate_and_render(python_val, schema, config)
        lv = self._sd_transformer.to_literal(ctx, val, pandas.DataFrame, expected)

        # In cases where a task is being called locally, this method will be invoked to convert the python input value
        # to a Flyte literal, which is then deserialized back to a python value. In such cases, we can cache the
        # structured dataset uri and schema name to avoid repeating the validation process in the subsequent
        # to_python_value call.
        self._VALIDATION_MEMO.add((lv.scalar.structured_dataset.uri, schema.name))
        return lv

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[pandas.DataFrame]
    ) -> pandera.typing.DataFrame:
        """Load a structured-dataset literal as a dataframe and validate it.

        Validation is skipped when the literal's uri and the schema name were
        already recorded by ``to_literal``.

        Args:
            ctx: The current Flyte context.
            lv: The literal to convert; must carry a scalar structured dataset.
            expected_python_type: The pandera dataframe type to validate against.

        Returns:
            The loaded dataframe (memo hit), or the result of
            ``_validate_and_render``.

        Raises:
            AssertionError: if ``lv`` does not hold a structured dataset.
            pandera.errors.SchemaError, pandera.errors.SchemaErrors,
            ValueError: see ``_validate_and_render``.
        """
        if not (lv and lv.scalar and lv.scalar.structured_dataset):
            raise AssertionError("Can only convert a literal structured dataset to a pandera schema")

        df = self._sd_transformer.to_python_value(ctx, lv, pandas.DataFrame)
        schema, config = self._get_pandera_schema(expected_python_type)

        if (lv.scalar.structured_dataset.uri, schema.name) in self._VALIDATION_MEMO:
            return df

        return self._validate_and_render(df, schema, config)
# Import-time side effect: instantiate the transformer and register it with
# flytekit's TypeEngine.
TypeEngine.register(PanderaPandasTransformer())