aidial_adapter_bedrock/llm/model/cohere.py:

from typing import Any, AsyncIterator, Dict, List, Optional

from aidial_sdk.chat_completion import Message
from pydantic import BaseModel, Field
from typing_extensions import override

from aidial_adapter_bedrock.bedrock import (
    Bedrock,
    ResponseWithInvocationMetricsMixin,
)
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_emulator import (
    BasicChatEmulator,
    CueMapping,
)
from aidial_adapter_bedrock.llm.chat_model import (
    PseudoChatModel,
    trivial_partitioner,
)
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_COHERE
from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string
from aidial_adapter_bedrock.llm.tools.default_emulator import (
    default_tools_emulator,
)
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


class CohereResult(BaseModel):
    tokenCount: int
    outputText: str
    completionReason: Optional[str]


class Likelihood(BaseModel):
    likelihood: float
    token: str


class CohereGeneration(BaseModel):
    id: str
    text: str
    likelihood: float
    finish_reason: str
    token_likelihoods: List[Likelihood] = Field(repr=False)


class CohereResponse(ResponseWithInvocationMetricsMixin):
    id: str
    prompt: Optional[str]
    generations: List[CohereGeneration]

    def content(self) -> str:
        return self.generations[0].text

    @property
    def tokens(self) -> List[str]:
        """Includes prompt and completion tokens"""
        return [lh.token for lh in self.generations[0].token_likelihoods]

    def usage_by_tokens(self) -> TokenUsage:
        special_tokens = 7
        total_tokens = len(self.tokens) - special_tokens

        # The structure for the response:
        # ["<BOS_TOKEN>", "User", ":", *<prompt>, "\n", "Chat", "bot", ":", "<EOP_TOKEN>", *<completion>]
        # prompt_tokens = len(<prompt>)
        # completion_tokens = len(["<EOP_TOKEN>"] + <completion>)
        separator = "<EOP_TOKEN>"
        if separator in self.tokens:
            prompt_tokens = self.tokens.index(separator) - special_tokens
        else:
            log.error(f"Separator '{separator}' not found in tokens")
            prompt_tokens = total_tokens // 2

        return TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=total_tokens - prompt_tokens,
        )


def convert_params(params: ModelParameters) -> Dict[str, Any]:
    ret = {}

    if params.temperature is not None:
        ret["temperature"] = params.temperature

    if params.max_tokens is not None:
        ret["max_tokens"] = params.max_tokens
    else:
        # Choosing reasonable default
        ret["max_tokens"] = DEFAULT_MAX_TOKENS_COHERE

    ret["return_likelihoods"] = "ALL"

    # NOTE: num_generations is supported

    return ret


def create_request(prompt: str, params: Dict[str, Any]) -> Dict[str, Any]:
    return {"prompt": prompt, **params}


async def chunks_to_stream(
    chunks: AsyncIterator[dict], usage: TokenUsage
) -> AsyncIterator[str]:
    async for chunk in chunks:
        resp = CohereResponse.parse_obj(chunk)
        usage.accumulate(resp.usage_by_metrics())
        log.debug(f"tokens: {'|'.join(resp.tokens)!r}")
        yield resp.content()


async def response_to_stream(
    response: dict, usage: TokenUsage
) -> AsyncIterator[str]:
    resp = CohereResponse.parse_obj(response)
    usage.accumulate(resp.usage_by_tokens())
    log.debug(f"tokens: {'|'.join(resp.tokens)!r}")
    yield resp.content()


cohere_emulator = BasicChatEmulator(
    prelude_template=None,
    add_cue=lambda _, idx: idx > 0,
    add_invitation_cue=False,
    fallback_to_completion=False,
    cues=CueMapping(
        system="User:",
        human="User:",
        ai="Chatbot:",
    ),
    separator="\n",
)


class CohereAdapter(PseudoChatModel):
    model: str
    client: Bedrock
    @classmethod
    def create(cls, client: Bedrock, model: str):
        return cls(
            client=client,
            model=model,
            tokenize_string=default_tokenize_string,
            chat_emulator=cohere_emulator,
            tools_emulator=default_tools_emulator,
            partitioner=trivial_partitioner,
        )

    @override
    def preprocess_messages(self, messages: List[Message]) -> List[Message]:
        messages = super().preprocess_messages(messages)

        # Cohere doesn't support empty messages,
        # so replace them with a single space.
        for msg in messages:
            msg.content = msg.content or " "

        return messages

    async def predict(
        self, consumer: Consumer, params: ModelParameters, prompt: str
    ):
        args = create_request(prompt, convert_params(params))

        usage = TokenUsage()

        if params.stream:
            chunks = self.client.ainvoke_streaming(self.model, args)
            stream = chunks_to_stream(chunks, usage)
        else:
            response, _headers = await self.client.ainvoke_non_streaming(
                self.model, args
            )
            stream = response_to_stream(response, usage)

        stream = self.post_process_stream(stream, params, self.chat_emulator)

        async for content in stream:
            consumer.append_content(content)

        consumer.close_content()
        consumer.add_usage(usage)
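
For reference, here is a small, self-contained sketch (not part of the file above; the token list is made up) of the accounting that CohereResponse.usage_by_tokens performs on the combined prompt/completion token stream:

# Hypothetical token stream following the structure documented in usage_by_tokens:
# ["<BOS_TOKEN>", "User", ":", *<prompt>, "\n", "Chat", "bot", ":", "<EOP_TOKEN>", *<completion>]
tokens = [
    "<BOS_TOKEN>", "User", ":",   # wrapper tokens before the prompt
    "Hi",                         # <prompt>
    "\n", "Chat", "bot", ":",     # wrapper tokens cueing the reply
    "<EOP_TOKEN>", "Hello", "!",  # <completion> (EOP token is counted as completion)
]

special_tokens = 7                                             # the 7 wrapper tokens above
total_tokens = len(tokens) - special_tokens                    # 11 - 7 = 4
prompt_tokens = tokens.index("<EOP_TOKEN>") - special_tokens   # 8 - 7 = 1
completion_tokens = total_tokens - prompt_tokens               # 4 - 1 = 3

assert (prompt_tokens, completion_tokens) == (1, 3)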