# aidial_adapter_vertexai/app.py (49 lines of code) (raw):
from contextlib import asynccontextmanager
import vertexai
from aidial_sdk import DIALApp
from aidial_sdk.telemetry.types import TelemetryConfig
from google.genai.client import Client as GenAIClient
from aidial_adapter_vertexai.chat_completion import VertexAIChatCompletion
from aidial_adapter_vertexai.deployments import (
ChatCompletionDeployment,
EmbeddingsDeployment,
)
from aidial_adapter_vertexai.dial_api.exceptions import dial_exception_decorator
from aidial_adapter_vertexai.dial_api.response import (
ModelObject,
ModelsResponse,
)
from aidial_adapter_vertexai.embeddings import VertexAIEmbeddings
from aidial_adapter_vertexai.utils.env import get_env
from aidial_adapter_vertexai.utils.log_config import configure_loggers
# Deployment settings, read once via get_env when this module is imported.
# DEFAULT_REGION is passed as the `location` and GCP_PROJECT_ID as the
# `project` to both vertexai.init and the GenAI client below.
DEFAULT_REGION = get_env("DEFAULT_REGION")
GCP_PROJECT_ID = get_env("GCP_PROJECT_ID")
@asynccontextmanager
async def lifespan(app: DIALApp):
    """Lifespan context for the application.

    Initializes the Vertex AI SDK with the configured project and region
    when the context is entered, then yields. There is no teardown step
    after the ``yield``. The ``app`` argument is accepted but not used.
    """
    init_options = {"project": GCP_PROJECT_ID, "location": DEFAULT_REGION}
    vertexai.init(**init_options)
    yield
# The adapter's application object: a DIALApp with default telemetry
# settings, the health check enabled, and the lifespan hook defined above.
app = DIALApp(
    lifespan=lifespan,
    add_healthcheck=True,
    telemetry_config=TelemetryConfig(),
    description="Google VertexAI adapter for DIAL API",
)

# NOTE: loggers are configured only after the app (and with it the DIAL
# telemetry) has been set up, since telemetry initialization may have
# configured the root logger on its own.
configure_loggers()
@app.get("/openai/models")
@dial_exception_decorator
async def models():
    """Return the chat-completion deployments as a models listing.

    Builds one ``ModelObject`` per ``ChatCompletionDeployment`` member, using
    the member's ``value`` as the id, and wraps them in a ``ModelsResponse``.
    Only chat-completion deployments are listed here.
    """
    listed = []
    for deployment in ChatCompletionDeployment:
        listed.append(ModelObject(id=deployment.value, object="model"))
    return ModelsResponse(data=listed)
# A single GenAI client in Vertex AI mode, created with the configured
# project and region; every chat-completion handler below receives it.
genai_client = GenAIClient(
    location=DEFAULT_REGION,
    project=GCP_PROJECT_ID,
    vertexai=True,
)

# Register a chat-completion handler under each deployment's model id.
for deployment in ChatCompletionDeployment:
    app.add_chat_completion(
        deployment.get_model_id(), VertexAIChatCompletion(client=genai_client)
    )

# Register an embeddings handler under each embeddings deployment's model id.
for deployment in EmbeddingsDeployment:
    app.add_embeddings(
        deployment.get_model_id(),
        VertexAIEmbeddings(),
    )