aidial_assistant/utils/open_ai_plugin.py (81 lines of code) (raw):
import logging
from typing import Iterable, Mapping
from urllib.parse import urljoin
from aiocache import cached
from aiohttp import hdrs
from fastapi import HTTPException
from langchain.tools import OpenAPISpec
from pydantic import BaseModel, parse_obj_as
from starlette.status import HTTP_401_UNAUTHORIZED
from aidial_assistant.utils.requests import aget
logger = logging.getLogger(__name__)
class AuthConf(BaseModel):
    """Model for the ``auth`` field of ``AIPluginConf``."""

    # Kind of authentication. Presumably the value handed to
    # get_plugin_auth() as ``auth_type`` ("none" / "service_http") --
    # verify against the caller.
    type: str
    # Authorization scheme; defaults to "bearer" when not provided.
    authorization_type: str = "bearer"
class ApiConf(BaseModel):
    """Model for the ``api`` field of ``AIPluginConf``."""

    # Kind of API description (required, not interpreted in this module).
    type: str
    # Location of the API spec. May be relative: get_open_ai_plugin_info()
    # resolves it against the addon url before fetching the spec.
    url: str
    # Both flags default to False and are not read in this module.
    has_user_authentication: bool = False
    is_user_authenticated: bool = False
class AIPluginConf(BaseModel):
    """Parsed content of a plugin's ``ai-plugin.json`` file.

    Instances are built by ``_parse_ai_plugin_conf`` from the JSON document
    fetched from the addon url. Every field is declared without a default.
    """

    schema_version: str
    # Names and descriptions come in two flavours: one for the model and
    # one for the human.
    name_for_model: str
    name_for_human: str
    description_for_model: str
    description_for_human: str
    # Authentication settings of the plugin.
    auth: AuthConf
    # Where the plugin's API spec is located.
    api: ApiConf
    logo_url: str
    contact_email: str
    legal_info_url: str
class OpenAIPluginInfo(BaseModel):
    """Plugin configuration bundled with its OpenAPI spec.

    This is the return value of ``get_open_ai_plugin_info``.
    """

    # Parsed ai-plugin.json configuration.
    ai_plugin: AIPluginConf
    # OpenAPI spec fetched from ``ai_plugin.api.url``.
    open_api: OpenAPISpec
class AddonTokenSource:
def __init__(self, headers: Mapping[str, str], urls: Iterable[str]):
self.headers = headers
self.urls = {
url: f"x-addon-token-{index}" for index, url in enumerate(urls)
}
def get_token(self, url: str) -> str | None:
return self.headers.get(self.urls[url])
@property
def default_auth(self) -> str | None:
return self.headers.get(hdrs.AUTHORIZATION)
def get_plugin_auth(
    auth_type: str,
    authorization_type: str,
    url: str,
    token_source: AddonTokenSource,
) -> str | None:
    """Build the authorization value to use when calling a plugin.

    Args:
        auth_type: ``"none"`` or ``"service_http"``.
        authorization_type: Scheme placed in front of the service token.
        url: Addon url used to look up the service token.
        token_source: Provider of the service token and the default auth.

    Returns:
        For ``"none"`` the default authorization of ``token_source`` (may be
        ``None``); for ``"service_http"`` the string
        ``"<Authorization_type> <token>"``.

    Raises:
        HTTPException: With status 401 when the service token is missing or
            ``auth_type`` is not one of the supported values.
    """
    if auth_type == "service_http":
        token = token_source.get_token(url)
        if token is not None:
            # Capitalizing because Wolfram, for instance, doesn't like
            # lowercase bearer
            scheme = authorization_type.capitalize()
            return f"{scheme} {token}"

        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail=f"Missing token for {url}",
        )

    if auth_type != "none":
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail=f"Unknown auth type {auth_type}",
        )

    return token_source.default_auth
async def get_open_ai_plugin_info(addon_url: str) -> OpenAIPluginInfo:
    """Load a plugin's configuration together with its OpenAPI spec.

    Args:
        addon_url: Url pointing to the .well-known/ai-plugin.json file.

    Returns:
        The parsed configuration, with ``api.url`` resolved against
        ``addon_url``, and the spec fetched from that resolved url.
    """
    logger.info(f"Fetching plugin info from {addon_url}")
    plugin_conf = await _parse_ai_plugin_conf(addon_url)

    # The spec location may be given relative to the addon url.
    # NOTE(review): plugin_conf is produced by a @cached() helper, so this
    # in-place assignment may also alter the cached instance -- confirm
    # that this is intended.
    spec_url = urljoin(addon_url, plugin_conf.api.url)
    plugin_conf.api.url = spec_url

    logger.info(f"Fetching plugin spec from {spec_url}")
    spec = await _parse_openapi_spec(spec_url)

    return OpenAIPluginInfo(ai_plugin=plugin_conf, open_api=spec)
@cached()
async def _parse_ai_plugin_conf(url: str) -> AIPluginConf:
    """Fetch ``url`` and parse the JSON response into an ``AIPluginConf``.

    The coroutine is wrapped with ``@cached()``.
    """
    async with aget(url) as response:
        # content_type=None to disable validation, sometimes response comes
        # as text/json
        payload = await response.json(content_type=None)
        return parse_obj_as(AIPluginConf, payload)
@cached()
async def _parse_openapi_spec(url: str) -> OpenAPISpec:
    """Fetch ``url`` and build an ``OpenAPISpec`` from the response text.

    The coroutine is wrapped with ``@cached()``.
    """
    async with aget(url) as response:
        spec_text = await response.text()
        return OpenAPISpec.from_text(spec_text)  # type: ignore