"""Flyte agent that submits Snowflake SQL tasks asynchronously (``flytekitplugins.snowflake.agent``)."""

from dataclasses import dataclass
from typing import Optional

from flyteidl.core.execution_pb2 import TaskExecution, TaskLog

from flytekit import FlyteContextManager, StructuredDataset, logger
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.models.types import LiteralType, StructuredDatasetType
from snowflake import connector as sc

# Task type name passed to AsyncAgentBase when SnowflakeAgent is constructed.
TASK_TYPE: str = "snowflake"
# Name of the agent secret that holds the PEM-encoded private key used to connect.
SNOWFLAKE_PRIVATE_KEY: str = "snowflake_private_key"


@dataclass
class SnowflakeJobMetadata(ResourceMeta):
    """Everything needed to find a submitted Snowflake query again.

    Built by ``SnowflakeAgent.create``. ``get`` and ``delete`` use the connection
    fields to open a fresh session and ``query_id`` to address the query.
    """

    # Connection settings, copied from the task template's config.
    user: str
    account: str
    database: str
    schema: str
    warehouse: str
    # Query id (cursor ``sfqid``) of the statement submitted with execute_async.
    query_id: str
    # Whether the task interface declares at least one output.
    has_output: bool


def get_private_key():
    """Return the agent's Snowflake private key as unencrypted PKCS#8 DER bytes.

    The key is read as PEM text from the agent secret named by
    ``SNOWFLAKE_PRIVATE_KEY``. It is loaded with ``password=None``, so it must
    not be password-protected.
    """
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization

    # cryptography wants bytes with no surrounding whitespace.
    pem_bytes = get_agent_secret(SNOWFLAKE_PRIVATE_KEY).strip().encode()
    private_key = serialization.load_pem_private_key(
        pem_bytes,
        password=None,
        backend=default_backend(),
    )

    # Re-encode as DER/PKCS#8; this is the value handed to sc.connect(private_key=...).
    return private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )


def get_connection(metadata: SnowflakeJobMetadata) -> sc.SnowflakeConnection:
    """Open a new Snowflake connection for the job described by ``metadata``.

    Authenticates with the private key returned by :func:`get_private_key` and the
    user/account/database/schema/warehouse recorded in ``metadata``.

    :param metadata: Metadata of a previously submitted query.
    :return: An open connection. It is not closed here; the caller must close it.
    """
    return sc.connect(
        user=metadata.user,
        account=metadata.account,
        private_key=get_private_key(),
        database=metadata.database,
        schema=metadata.schema,
        warehouse=metadata.warehouse,
    )


class SnowflakeAgent(AsyncAgentBase):
    """Agent that runs Snowflake SQL tasks as asynchronous queries.

    ``create`` submits the task's SQL with ``execute_async`` and records the query
    id. ``get`` polls the query status and, when the task declares outputs, exposes
    the result as a ``StructuredDataset``. ``delete`` cancels the query.
    """

    name = "Snowflake Agent"

    def __init__(self):
        super().__init__(task_type_name=TASK_TYPE, metadata_type=SnowflakeJobMetadata)

    async def create(
        self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs
    ) -> SnowflakeJobMetadata:
        """Submit the task's SQL statement to Snowflake without waiting for it to finish.

        :param task_template: Template whose ``config`` holds ``user``, ``account``,
            ``database``, ``schema`` and ``warehouse``, and whose ``sql.statement``
            is the query to run.
        :param inputs: Task inputs, bound to the statement as query parameters.
            ``None`` or an empty literal map means the statement runs without
            parameters.
        :return: Metadata identifying the submitted query.
        """
        ctx = FlyteContextManager.current_context()
        literal_types = task_template.interface.inputs
        # ``inputs`` defaults to None, so check it before touching ``.literals``.
        params = (
            TypeEngine.literal_map_to_kwargs(ctx, inputs, literal_types=literal_types)
            if inputs is not None and inputs.literals
            else None
        )

        config = task_template.config
        conn = sc.connect(
            user=config["user"],
            account=config["account"],
            private_key=get_private_key(),
            database=config["database"],
            schema=config["schema"],
            warehouse=config["warehouse"],
        )
        cs = conn.cursor()
        # NOTE(review): the connection is intentionally left open here. Closing it
        # right after execute_async may affect the still-running query depending on
        # session settings; confirm before adding cleanup.
        cs.execute_async(task_template.sql.statement, params)

        return SnowflakeJobMetadata(
            user=config["user"],
            account=config["account"],
            database=config["database"],
            schema=config["schema"],
            warehouse=config["warehouse"],
            query_id=cs.sfqid,
            has_output=task_template.interface.outputs is not None
            and len(task_template.interface.outputs) > 0,
        )

    async def get(self, resource_meta: SnowflakeJobMetadata, **kwargs) -> Resource:
        """Poll the query's status and translate it into a Flyte phase.

        If the status lookup raises ``sc.ProgrammingError``, the error is logged and
        a FAILED resource is returned. On success with declared outputs, the
        result is returned as a ``StructuredDataset`` literal under ``"results"``.
        """
        conn = get_connection(resource_meta)
        try:
            query_status = conn.get_query_status_throw_if_error(resource_meta.query_id)
        except sc.ProgrammingError as err:
            # err.msg must go through a %s placeholder so logging can format the record.
            logger.error("Failed to get snowflake job status with error: %s", err.msg)
            return Resource(phase=TaskExecution.FAILED)
        finally:
            # This connection exists only for the status lookup above; don't leak it.
            conn.close()

        log_link = TaskLog(
            uri=construct_query_link(resource_meta=resource_meta),
            name="Snowflake Query Details",
        )
        # The snowflake job's state is determined by query status.
        # https://github.com/snowflakedb/snowflake-connector-python/blob/main/src/snowflake/connector/constants.py#L373
        cur_phase = convert_to_flyte_phase(str(query_status.name))

        res = None
        if cur_phase == TaskExecution.SUCCEEDED and resource_meta.has_output:
            ctx = FlyteContextManager.current_context()
            uri = f"snowflake://{resource_meta.user}:{resource_meta.account}/{resource_meta.warehouse}/{resource_meta.database}/{resource_meta.schema}/{resource_meta.query_id}"
            res = {
                "results": TypeEngine.to_literal(
                    ctx,
                    StructuredDataset(uri=uri),
                    StructuredDataset,
                    LiteralType(structured_dataset_type=StructuredDatasetType(format="")),
                )
            }

        return Resource(phase=cur_phase, outputs=res, log_links=[log_link])

    async def delete(self, resource_meta: SnowflakeJobMetadata, **kwargs):
        """Cancel the query identified by ``resource_meta.query_id``.

        The cursor and connection are closed even if the cancel statement fails.
        """
        conn = get_connection(resource_meta)
        try:
            cs = conn.cursor()
            try:
                cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{resource_meta.query_id}')")
                cs.fetchall()
            finally:
                cs.close()
        finally:
            # Separate finally: the connection is released even if cursor()/close() raise.
            conn.close()
def construct_query_link(resource_meta: SnowflakeJobMetadata) -> str:
    """Build the Snowsight URL of the query-detail page for ``resource_meta.query_id``.

    ``resource_meta.account`` is assumed to have the form ``<account>-<region>``.
    Everything after the first ``-`` is treated as the region, so a region that
    itself contains dashes (e.g. ``us-east-1``) is kept intact. When the account
    has no region part, the region path segment is omitted instead of producing
    an empty ``//`` segment.

    :param resource_meta: Metadata of the submitted query.
    :return: The URL string.
    """
    base_url = "https://app.snowflake.com"

    # NOTE(review): the account-region layout is an assumption; adjust if your
    # account identifiers use a different format.
    account, _, region = resource_meta.account.partition("-")
    path = f"{region}/{account}" if region else account

    return f"{base_url}/{path}/#/compute/history/queries/{resource_meta.query_id}/detail"


AgentRegistry.register(SnowflakeAgent())