aidial_interceptors_sdk/embeddings/base.py (35 lines of code) (raw):
import logging
from abc import ABC
from typing import List
from aidial_sdk.pydantic_v1 import BaseModel
from aidial_interceptors_sdk.dial_client import DialClient
# Module-scoped logger, named after this module; used below to warn about
# request inputs the interceptor leaves untouched.
_log = logging.getLogger(name=__name__)
class EmbeddingsInterceptor(ABC, BaseModel):
    """Base interceptor for embeddings requests and responses.

    Subclasses customise behaviour by overriding the two per-item hooks:

    * ``modify_input`` - applied to every input string of a request;
    * ``modify_embedding`` - applied to every embedding of a response.

    ``modify_request`` and ``modify_response`` walk the request/response
    dicts and call those hooks. Both hooks return their argument unchanged
    by default. Request and response dicts are updated in place and also
    returned.
    """

    class Config:
        # ``DialClient`` is not a pydantic model, so it must be explicitly
        # allowed as a field type.
        arbitrary_types_allowed = True

    dial_client: DialClient

    async def modify_input(self, input: str) -> str:
        """Transform one input string; the default is the identity."""
        return input

    async def modify_embedding(
        self, embedding: str | List[float]
    ) -> str | List[float]:
        """Transform one embedding (a string or a list of floats).

        The default is the identity.
        """
        return embedding

    async def modify_request(self, request: dict) -> dict:
        """Apply ``modify_input`` to the request's ``"input"`` field.

        * A string input is replaced by the hook's result.
        * A list consisting only of strings is replaced by a new list holding
          the hook's result for each element, awaited one after another in
          list order.
        * Any other list (presumably tokenized input, per the log message) is
          left unchanged and a warning is logged.
        * A missing ``"input"`` key, or an input of any other type, leaves
          the request unchanged.

        Returns the same ``request`` object, mutated in place.
        """
        if "input" not in request:
            return request

        payload = request["input"]

        if isinstance(payload, str):
            request["input"] = await self.modify_input(payload)
            return request

        if not isinstance(payload, list):
            # Neither a string nor a list: pass through silently.
            return request

        if not all(isinstance(entry, str) for entry in payload):
            _log.warning("Tokenized input isn't yet supported")
            return request

        rewritten: List[str] = []
        for entry in payload:
            # Sequential awaits keep hook calls in input order.
            rewritten.append(await self.modify_input(entry))
        request["input"] = rewritten
        return request

    async def modify_response(self, response: dict) -> dict:
        """Apply ``modify_embedding`` to every item of ``response["data"]``.

        A missing or falsy ``"data"`` value is treated as an empty list.
        Each item's ``"embedding"`` value is overwritten in place, and the
        same ``response`` object is returned.
        """
        items = response.get("data") or []
        for record in items:
            current = record["embedding"]
            record["embedding"] = await self.modify_embedding(current)
        return response
class EmbeddingsNoOpInterceptor(EmbeddingsInterceptor):
    """Concrete interceptor that keeps the base class's identity hooks.

    Inputs and embeddings are passed through unchanged.
    """