aidial_assistant/open_api/requester.py (82 lines of code) (raw):

import json
import logging
from typing import Dict, List, NamedTuple, Optional

import aiohttp.client_exceptions
from aiohttp import hdrs
from langchain.tools.openapi.utils.api_models import APIOperation

from aidial_assistant.commands.base import JsonResult, ResultObject, TextResult
from aidial_assistant.utils.requests import arequest

logger = logging.getLogger(__name__)


class _ParamMapping(NamedTuple):
    """Names of an operation's parameters, grouped by where they are sent."""

    query_params: List[str]
    body_params: List[str]
    path_params: List[str]


class OpenAPIEndpointRequester:
    """Execute a single OpenAPI operation from already-structured arguments.

    Based on ``OpenAPIEndpointChain`` from LangChain, but without the natural
    language layer: callers pass a dict of argument values, which is split
    into path, query and JSON-body parameters according to ``operation``.
    """

    def __init__(self, operation: APIOperation, plugin_auth: str | None):
        """Create a requester for ``operation``.

        Args:
            operation: The parsed OpenAPI operation; its base URL, path
                template, HTTP method and parameter names are used.
            plugin_auth: Value sent verbatim as the ``Authorization`` header,
                or ``None`` to send no such header.
        """
        self.operation = operation
        self.param_mapping = _ParamMapping(
            query_params=operation.query_params,  # type: ignore
            body_params=operation.body_params,  # type: ignore
            path_params=operation.path_params,  # type: ignore
        )
        self.plugin_auth = plugin_auth

    def _construct_path(self, args: Dict[str, str]) -> str:
        """Build the request URL, substituting path parameters into the template.

        Each path parameter is popped from ``args`` and replaces its
        ``{name}`` placeholder. A missing parameter is replaced with an
        empty string.
        """
        path = self.operation.base_url.rstrip("/") + self.operation.path  # type: ignore
        for param in self.param_mapping.path_params:
            # NOTE(review): values are inserted without URL-encoding, so a value
            # containing "/" or "?" changes the request path. Confirm whether
            # any caller relies on that before adding quoting.
            path = path.replace(f"{{{param}}}", str(args.pop(param, "")))
        return path

    def _extract_query_params(self, args: Dict[str, str]) -> Dict[str, str]:
        """Pop the operation's query parameters that are present in ``args``."""
        query_params = {}
        for param in self.param_mapping.query_params:
            if param in args:
                query_params[param] = args.pop(param)
        return query_params

    def _extract_body_params(
        self, args: Dict[str, str]
    ) -> Optional[Dict[str, str]]:
        """Pop the operation's body parameters that are present in ``args``.

        Returns:
            ``None`` when the operation declares no body parameters, so that
            no JSON body is sent. Otherwise, a dict of the body parameters
            that were supplied (possibly empty).
        """
        if not self.param_mapping.body_params:
            return None
        body_params = {}
        for param in self.param_mapping.body_params:
            if param in args:
                body_params[param] = args.pop(param)
        return body_params

    def deserialize_json_input(self, args: dict) -> dict:
        """Split call arguments into keyword arguments for the HTTP request.

        Path parameters are substituted into the URL template. Body
        parameters become the JSON request body, which is ``None`` when the
        operation declares no body parameters. Query parameters become the
        query string. Arguments the operation does not declare are dropped.

        The caller's ``args`` dict is not modified.

        Returns:
            A dict with the keys ``"url"``, ``"json"`` and ``"params"``.
        """
        # The helpers consume (pop) the keys they use. Work on a copy so the
        # caller's dict survives, e.g. when the same args are executed twice.
        remaining = dict(args)
        # Order matters: a name listed under several locations is claimed by
        # the first extractor that sees it (path, then body, then query).
        path = self._construct_path(remaining)
        body_params = self._extract_body_params(remaining)
        query_params = self._extract_query_params(remaining)
        return {
            "url": path,
            "json": body_params,
            "params": query_params,
        }

    async def execute(
        self,
        args: dict,
    ) -> ResultObject:
        """Send the operation's HTTP request and wrap the response.

        Args:
            args: Argument values keyed by parameter name.

        Returns:
            On a non-200 status: the response body as ``JsonResult`` if it
            parses as JSON. Otherwise, a ``JsonResult`` summarising the reason,
            status code, method, URL and query params.

            On 200: a ``TextResult`` when the Content-Type contains "text".
            Otherwise, the body re-serialised as a ``JsonResult``, falling back
            to a ``TextResult`` if the body is not JSON.
        """
        request_args = self.deserialize_json_input(args)
        headers = (
            None
            if self.plugin_auth is None
            else {hdrs.AUTHORIZATION: self.plugin_auth}
        )
        logger.debug("Request args: %s", request_args)
        async with arequest(
            self.operation.method.value, headers=headers, **request_args  # type: ignore
        ) as response:
            if response.status != 200:
                try:
                    return JsonResult(json.dumps(await response.json()))
                except (aiohttp.ContentTypeError, json.JSONDecodeError):
                    # The body is not usable JSON (wrong content type or
                    # malformed), so describe the failed request instead.
                    method_str = str(self.operation.method.value)  # type: ignore
                    error_object = {
                        "reason": response.reason,
                        "status_code": response.status,
                        "method": method_str.upper(),
                        "url": request_args["url"],
                        "params": request_args["params"],
                    }
                    return JsonResult(json.dumps(error_object))

            # A missing Content-Type header must not raise KeyError.
            content_type = response.headers.get(hdrs.CONTENT_TYPE, "")
            if "text" in content_type:
                return TextResult(await response.text())
            try:
                return JsonResult(json.dumps(await response.json()))
            except (aiohttp.ContentTypeError, json.JSONDecodeError):
                # Successful response that is neither text/* nor JSON: return
                # the raw body (already read and cached by json()) as text.
                return TextResult(await response.text())