aidial_adapter_openai/dial_api/embedding_inputs.py (70 lines of code) (raw):

from typing import ( Any, AsyncIterator, Callable, Coroutine, List, TypeVar, assert_never, cast, ) from aidial_sdk.chat_completion.request import Attachment from aidial_sdk.embeddings.request import EmbeddingsRequest from aidial_sdk.exceptions import RequestValidationError _T = TypeVar("_T") _Coro = Coroutine[_T, Any, Any] _Tokens = List[int] async def reject_tokens(tokens: _Tokens): raise RequestValidationError( "Tokens in an embedding input are not supported. Provide text instead. " "When Langchain AzureOpenAIEmbeddings class is used, set 'check_embedding_ctx_length=False' to disable tokenization." ) async def reject_mixed(input: List[str | Attachment]): raise RequestValidationError( "Embedding inputs composed of multiple texts and/or attachments aren't supported" ) async def collect_embedding_inputs( request: EmbeddingsRequest, *, on_text: Callable[[str], _Coro[_T]], on_attachment: Callable[[Attachment], _Coro[_T]], on_tokens: Callable[[_Tokens], _Coro[_T]] = reject_tokens, on_mixed: Callable[[List[str | Attachment]], _Coro[_T]] = reject_mixed, ) -> AsyncIterator[_T]: async def _on_str_or_attachment(input: str | Attachment) -> _T: if isinstance(input, str): return await on_text(input) elif isinstance(input, Attachment): return await on_attachment(input) else: assert_never(input) if isinstance(request.input, str): yield await on_text(request.input) elif isinstance(request.input, list): is_list_of_tokens = False for input in request.input: if isinstance(input, str): yield await on_text(input) elif isinstance(input, list): yield await on_tokens(input) else: is_list_of_tokens = True break if is_list_of_tokens: yield await on_tokens(cast(_Tokens, request.input)) else: assert_never(request.input) if request.custom_input is None: return for input in request.custom_input: if isinstance(input, (str, Attachment)): yield await _on_str_or_attachment(input) elif isinstance(input, list): if len(input) == 0: pass elif len(input) == 1: yield await _on_str_or_attachment(input[0]) else: yield await on_mixed(input) else: assert_never(input)