Source code for flytekitplugins.openai.batch.agent

from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional

import cloudpickle

from flytekit import FlyteContextManager, lazy_module
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_agent import (
    AgentRegistry,
    AsyncAgentBase,
    Resource,
    ResourceMeta,
)
from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

openai = lazy_module("openai")
OPENAI_API_KEY = "FLYTE_OPENAI_API_KEY"


class State(Enum):
    Running = ["in_progress", "finalizing", "validating"]
    Success = ["completed"]
    Failed = ["failed", "cancelled", "cancelling", "expired"]

    @classmethod
    def key_by_value(cls, value) -> str:
        for member in cls:
            if value in member.value:
                return member.name


@dataclass
class BatchEndpointMetadata(ResourceMeta):
    openai_org: str
    batch_id: str

    def encode(self) -> bytes:
        return cloudpickle.dumps(self)

    @classmethod
    def decode(cls, data: bytes) -> "BatchEndpointMetadata":
        return cloudpickle.loads(data)


[docs] class BatchEndpointAgent(AsyncAgentBase): name = "OpenAI Batch Endpoint Agent" def __init__(self): super().__init__(task_type_name="openai-batch", metadata_type=BatchEndpointMetadata)
[docs] async def create( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs, ) -> BatchEndpointMetadata: ctx = FlyteContextManager.current_context() input_values = TypeEngine.literal_map_to_kwargs( ctx, inputs, {"input_file_id": str}, ) custom = task_template.custom async_client = openai.AsyncOpenAI( organization=custom.get("openai_organization"), api_key=get_agent_secret(secret_key=OPENAI_API_KEY), ) custom["config"].setdefault("completion_window", "24h") custom["config"].setdefault("endpoint", "/v1/chat/completions") result = await async_client.batches.create( **custom["config"], input_file_id=input_values["input_file_id"], ) batch_id = result.id return BatchEndpointMetadata(batch_id=batch_id, openai_org=custom["openai_organization"])
[docs] async def get( self, resource_meta: BatchEndpointMetadata, **kwargs, ) -> Resource: async_client = openai.AsyncOpenAI( organization=resource_meta.openai_org, api_key=get_agent_secret(secret_key=OPENAI_API_KEY), ) retrieved_result = await async_client.batches.retrieve(resource_meta.batch_id) current_state = retrieved_result.status flyte_phase = convert_to_flyte_phase(State.key_by_value(current_state)) message = None if current_state in State.Failed.value and retrieved_result.errors: data = retrieved_result.errors.data if data and data[0].message: message = data[0].message result = retrieved_result.to_dict() ctx = FlyteContextManager.current_context() outputs = {"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))} return Resource(phase=flyte_phase, outputs=outputs, message=message)
[docs] async def delete( self, resource_meta: BatchEndpointMetadata, **kwargs, ): async_client = openai.AsyncOpenAI( organization=resource_meta.openai_org, api_key=get_agent_secret(secret_key=OPENAI_API_KEY), ) await async_client.batches.cancel(resource_meta.batch_id)
AgentRegistry.register(BatchEndpointAgent())