aidial_adapter_bedrock/llm/model/ai21.py:
from typing import Any, Dict, List, Optional

from pydantic import BaseModel

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_emulator import default_emulator
from aidial_adapter_bedrock.llm.chat_model import (
    PseudoChatModel,
    trivial_partitioner,
)
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_AI21
from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string
from aidial_adapter_bedrock.llm.tools.default_emulator import (
    default_tools_emulator,
)
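

# Pydantic models mirroring the response schema of the AI21 Jurassic-2 models
# exposed through Bedrock (see the j2-instruct reference linked further down).
# Field names follow the camelCase keys returned by the AI21 API.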
class TextRange(BaseModel):
    start: int
    end: int


class GeneratedToken(BaseModel):
    token: str
    logprob: float
    raw_logprob: float


class Token(BaseModel):
    generatedToken: GeneratedToken
    topTokens: Optional[Any]
    textRange: TextRange


class TextAndTokens(BaseModel):
    text: str
    tokens: List[Token]


class FinishReason(BaseModel):
    reason: str  # Literal["length", "endoftext"]
    length: Optional[int]


class Completion(BaseModel):
    data: TextAndTokens
    finishReason: FinishReason


class AI21Response(BaseModel):
    id: int
    prompt: TextAndTokens
    completions: List[Completion]

    def content(self) -> str:
        assert (
            len(self.completions) == 1
        ), "AI21Response should only have one completion"
        return self.completions[0].data.text

    def usage(self) -> TokenUsage:
        assert (
            len(self.completions) == 1
        ), "AI21Response should only have one completion"
        return TokenUsage(
            prompt_tokens=len(self.prompt.tokens),
            completion_tokens=len(self.completions[0].data.tokens),
        )
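
# Illustrative sketch (not part of the original module): a response shaped like
# the models above parses as follows. The payload values are invented for
# demonstration and the token lists are elided.
#
#   resp = AI21Response.parse_obj(
#       {
#           "id": 1234,
#           "prompt": {"text": "Hello", "tokens": [...]},
#           "completions": [
#               {
#                   "data": {"text": "Hi there!", "tokens": [...]},
#                   "finishReason": {"reason": "endoftext", "length": None},
#               }
#           ],
#       }
#   )
#   resp.content()  # -> "Hi there!"
#   resp.usage()    # -> TokenUsage counting prompt and completion tokens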


# NOTE: See https://docs.ai21.com/reference/j2-instruct-ref
def convert_params(params: ModelParameters) -> Dict[str, Any]:
    ret = {}

    if params.max_tokens is not None:
        ret["maxTokens"] = params.max_tokens
    else:
        # The default for max tokens is 16, which is too small for most use
        # cases, so choose a more reasonable default.
        ret["maxTokens"] = DEFAULT_MAX_TOKENS_AI21

    if params.temperature is not None:
        # AI21 temperature ranges from 0.0 to 1.0.
        # OpenAI temperature ranges from 0.0 to 2.0.
        # Thus, scale down by 2x to match the AI21 range.
        ret["temperature"] = params.temperature / 2.0

    if params.top_p is not None:
        ret["topP"] = params.top_p

    if params.stop:
        ret["stopSequences"] = params.stop

    # NOTE: AI21 has a "numResults" parameter; however, we emulate multiple
    # results via multiple calls to support all models uniformly.
    return ret


def create_request(prompt: str, params: Dict[str, Any]) -> Dict[str, Any]:
    return {"prompt": prompt, **params}


class AI21Adapter(PseudoChatModel):
    model: str
    client: Bedrock

    @classmethod
    def create(cls, client: Bedrock, model: str):
        return cls(
            client=client,
            model=model,
            tokenize_string=default_tokenize_string,
            tools_emulator=default_tools_emulator,
            chat_emulator=default_emulator,
            partitioner=trivial_partitioner,
        )

    async def predict(
        self, consumer: Consumer, params: ModelParameters, prompt: str
    ):
        args = create_request(prompt, convert_params(params))

        response, _headers = await self.client.ainvoke_non_streaming(
            self.model, args
        )

        resp = AI21Response.parse_obj(response)

        consumer.append_content(resp.content())
        consumer.close_content()
        consumer.add_usage(resp.usage())
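

# Usage sketch (assumptions: "ai21.j2-ultra-v1" is only an example Bedrock
# model id, and the Bedrock client, Consumer implementation and ModelParameters
# are constructed elsewhere in the adapter):
#
#   adapter = AI21Adapter.create(client, "ai21.j2-ultra-v1")
#   await adapter.predict(consumer, params, prompt)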