Source code for flytekitplugins.snowflake.agent

import json
from dataclasses import asdict, dataclass
from typing import Optional

from flyteidl.admin.agent_pb2 import (
    CreateTaskResponse,
    DeleteTaskResponse,
    GetTaskResponse,
    Resource,
)
from flyteidl.core.execution_pb2 import TaskExecution

from flytekit import FlyteContextManager, StructuredDataset, lazy_module, logger
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase
from flytekit.models import literals
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.models.types import LiteralType, StructuredDatasetType

snowflake_connector = lazy_module("snowflake.connector")

TASK_TYPE = "snowflake"
SNOWFLAKE_PRIVATE_KEY = "snowflake_private_key"


@dataclass
class Metadata:
    """Serializable handle for one submitted Snowflake query.

    The agent dumps this record to JSON as the task's ``resource_meta`` and
    rebuilds it from that JSON when polling or cancelling, so every field
    must stay JSON-serializable.

    Attributes:
        user: Snowflake user passed to the connector.
        account: Snowflake account passed to the connector.
        database: Database passed to the connector.
        schema: Schema passed to the connector.
        warehouse: Warehouse passed to the connector.
        table: Table name; used when building the output dataset URI.
        query_id: Identifier of the submitted query, stored as a string.
    """

    # Connection parameters.
    user: str
    account: str
    database: str
    schema: str
    warehouse: str

    # Output location and query handle.
    table: str
    query_id: str


def get_private_key():
    """Return the Snowflake private key as unencrypted PKCS#8 DER bytes.

    The key is read from the Flyte secret named by ``SNOWFLAKE_PRIVATE_KEY``
    (fetched in ``rb`` mode) and parsed as a PEM private key without a
    password before being re-serialized.
    """
    # Imported lazily so the dependencies are only needed when a key is loaded.
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization

    import flytekit

    secrets = flytekit.current_context().secrets
    pem_data = secrets.get(SNOWFLAKE_PRIVATE_KEY, encode_mode="rb")
    private_key = serialization.load_pem_private_key(
        pem_data,
        password=None,
        backend=default_backend(),
    )

    return private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )


def get_connection(metadata: Metadata) -> "snowflake_connector.SnowflakeConnection":
    """Open a Snowflake connection using key-pair authentication.

    Args:
        metadata: Record supplying the user, account, database, schema and
            warehouse for the connection. Its ``table`` and ``query_id``
            fields are not used here.

    Returns:
        The connection object returned by ``snowflake_connector.connect``.
        The caller is responsible for closing it.
    """
    # The return annotation is a string on purpose: it names the connection
    # class (the previous annotation named the module itself) without touching
    # the lazily loaded module at definition time.
    return snowflake_connector.connect(
        user=metadata.user,
        account=metadata.account,
        private_key=get_private_key(),
        database=metadata.database,
        schema=metadata.schema,
        warehouse=metadata.warehouse,
    )


class SnowflakeAgent(AgentBase):
    """Agent that submits, polls and cancels Snowflake queries for ``TASK_TYPE`` tasks."""

    def __init__(self):
        super().__init__(task_type=TASK_TYPE)

    async def create(
        self,
        output_prefix: str,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
        **kwargs,
    ) -> CreateTaskResponse:
        """Submit the task's SQL statement to Snowflake asynchronously.

        Args:
            output_prefix: Not used by this method.
            task_template: Task definition; its ``config`` supplies the
                connection settings and table, and ``sql.statement`` the query.
            inputs: Optional literal inputs; when present they are converted to
                native Python values and passed as query parameters.
            **kwargs: Ignored.

        Returns:
            A ``CreateTaskResponse`` whose ``resource_meta`` is the UTF-8 JSON
            encoding of a ``Metadata`` record including the query id.

        Raises:
            KeyError: If a required key is missing from ``task_template.config``.
        """
        params = None
        if inputs:
            ctx = FlyteContextManager.current_context()
            python_interface_inputs = {
                name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items()
            }
            native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs)
            logger.info(f"Create Snowflake agent params with inputs: {native_inputs}")
            params = native_inputs

        config = task_template.config
        # Read every config key *before* submitting the query, so a missing key
        # fails fast instead of leaving a running query whose id was never
        # returned to the caller.
        metadata = Metadata(
            user=config["user"],
            account=config["account"],
            database=config["database"],
            schema=config["schema"],
            warehouse=config["warehouse"],
            table=config["table"],
            query_id="",
        )

        # NOTE(review): this connection and cursor are not closed here. Closing
        # the submitting connection may affect the in-flight async query on
        # some connector versions -- confirm before adding cleanup.
        conn = get_connection(metadata)
        cs = conn.cursor()
        cs.execute_async(task_template.sql.statement, params=params)
        metadata.query_id = str(cs.sfqid)

        return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8"))

    async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse:
        """Report the current phase of a previously submitted query.

        Args:
            resource_meta: UTF-8 JSON encoding of a ``Metadata`` record, as
                produced by ``create``.
            **kwargs: Ignored.

        Returns:
            A ``GetTaskResponse`` carrying the phase obtained from
            ``convert_to_flyte_phase``. When the phase is ``SUCCEEDED`` the
            outputs contain a ``results`` structured dataset pointing at the
            configured table. If the status lookup raises a
            ``ProgrammingError`` the phase is ``FAILED``.
        """
        metadata = Metadata(**json.loads(resource_meta.decode("utf-8")))
        conn = get_connection(metadata)
        try:
            query_status = conn.get_query_status_throw_if_error(metadata.query_id)
        except snowflake_connector.ProgrammingError as err:
            # The placeholder is required: without it logging cannot format the
            # extra argument and the error text is lost.
            logger.error("Failed to get snowflake job status with error: %s", err.msg)
            # Report through ``phase`` like the success path does; FAILED is a
            # TaskExecution phase value, not a value of the ``state`` field.
            return GetTaskResponse(resource=Resource(phase=TaskExecution.FAILED))
        finally:
            # This connection is only used for polling; release it every time.
            conn.close()

        cur_phase = convert_to_flyte_phase(str(query_status.name))
        res = None

        if cur_phase == TaskExecution.SUCCEEDED:
            ctx = FlyteContextManager.current_context()
            output_metadata = f"snowflake://{metadata.user}:{metadata.account}/{metadata.warehouse}/{metadata.database}/{metadata.schema}/{metadata.table}"
            res = literals.LiteralMap(
                {
                    "results": TypeEngine.to_literal(
                        ctx,
                        StructuredDataset(uri=output_metadata),
                        StructuredDataset,
                        LiteralType(structured_dataset_type=StructuredDatasetType(format="")),
                    )
                }
            ).to_flyte_idl()

        return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res))

    async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse:
        """Ask Snowflake to cancel a previously submitted query.

        Args:
            resource_meta: UTF-8 JSON encoding of a ``Metadata`` record, as
                produced by ``create``.
            **kwargs: Ignored.

        Returns:
            An empty ``DeleteTaskResponse``.
        """
        metadata = Metadata(**json.loads(resource_meta.decode("utf-8")))
        conn = get_connection(metadata)
        # Nested try/finally so the connection is closed even if creating or
        # closing the cursor fails.
        try:
            cs = conn.cursor()
            try:
                cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')")
                cs.fetchall()
            finally:
                cs.close()
        finally:
            conn.close()
        return DeleteTaskResponse()
# Import-time side effect: construct a SnowflakeAgent and register it with
# AgentRegistry as soon as this module is loaded.
AgentRegistry.register(SnowflakeAgent())