aidial_assistant/utils/state.py (91 lines of code) (raw):

import json
from typing import TypedDict

from aidial_sdk.chat_completion.request import CustomContent, Message, Role

from aidial_assistant.chain.command_result import (
    CommandInvocation,
    commands_to_text,
)
from aidial_assistant.chain.history import MessageScope, ScopedMessage
from aidial_assistant.utils.exceptions import RequestParameterValidationError
from aidial_assistant.utils.open_ai import (
    assistant_message,
    system_message,
    user_message,
)


class Invocation(TypedDict):
    """One command round-trip recorded in an assistant message's state."""

    # Ordering key; may be stored as str or int and is compared via int().
    index: str | int
    # Command text issued by the assistant (possibly in the legacy format
    # handled by _convert_old_commands).
    request: str
    # Text returned for the request.
    response: str


class State(TypedDict, total=False):
    """Shape of the ``custom_content.state`` payload read by this module."""

    invocations: list[Invocation]


def _get_invocations(custom_content: CustomContent | None) -> list[Invocation]:
    """Return the invocations stored in a message's custom-content state.

    Args:
        custom_content: The message's custom content, or ``None``.

    Returns:
        A new list of invocations ordered by ``int(invocation["index"])``.
        An empty list is returned when there is no custom content, no state,
        or no ``"invocations"`` entry in the state.

    Raises:
        ValueError: If an invocation's ``index`` cannot be converted to int.
    """
    if custom_content is None:
        return []

    state: State | None = custom_content.state
    if state is None:
        return []

    invocations: list[Invocation] | None = state.get("invocations")
    if invocations is None:
        return []

    # Use sorted() rather than list.sort(): the list belongs to the incoming
    # request's state, and reordering it in place would mutate caller data.
    return sorted(invocations, key=lambda invocation: int(invocation["index"]))


def _convert_old_commands(string: str) -> str:
    """Convert a serialized command list from the legacy format to the current one.

    Conversations saved before the format change store their commands in the
    old shape; without this conversion they would stop working.

    Old format:
    {"commands": [{"command": "run-addon", "args": ["<addon-name>", "<query>"]}]}

    New format:
    {"commands": [{"command": "<addon-name>", "arguments": {"query": "<query>"}}]}

    Commands that are not ``run-addon``/``run-plugin`` are passed through
    unchanged.

    Args:
        string: JSON text containing a top-level ``"commands"`` list.

    Returns:
        The commands re-serialized via ``commands_to_text``.
    """
    commands = json.loads(string)
    result: list[CommandInvocation] = []
    for command in commands["commands"]:
        command_name = command["command"]
        # run-addon was previously called run-plugin
        if command_name in ("run-addon", "run-plugin"):
            args = command["args"]
            result.append(
                CommandInvocation(command=args[0], arguments={"query": args[1]})
            )
        else:
            result.append(command)

    return commands_to_text(result)


def parse_history(history: list[Message]) -> list[ScopedMessage]:
    """Convert request messages into a flat list of scoped messages.

    For each assistant message, the invocations stored in its custom content
    are emitted first, as INTERNAL-scoped assistant (request) / user
    (response) message pairs. The assistant's own content follows them.
    User and system messages map one-to-one. Every produced message carries
    ``user_index`` equal to the position of its source message in
    ``history``. A ``None`` content is treated as an empty string.

    Args:
        history: Messages from the chat completion request.

    Returns:
        The scoped message list, in history order.

    Raises:
        RequestParameterValidationError: If a message has a role other than
            assistant, user, or system.
    """
    messages: list[ScopedMessage] = []
    for index, message in enumerate(history):
        if message.role == Role.ASSISTANT:
            invocations = _get_invocations(message.custom_content)
            for invocation in invocations:
                messages.append(
                    ScopedMessage(
                        scope=MessageScope.INTERNAL,
                        message=assistant_message(
                            _convert_old_commands(invocation["request"])
                        ),
                        user_index=index,
                    )
                )
                messages.append(
                    ScopedMessage(
                        scope=MessageScope.INTERNAL,
                        message=user_message(invocation["response"]),
                        user_index=index,
                    )
                )

            messages.append(
                ScopedMessage(
                    message=assistant_message(message.content or ""),
                    user_index=index,
                )
            )
        elif message.role == Role.USER:
            messages.append(
                ScopedMessage(
                    message=user_message(message.content or ""),
                    user_index=index,
                )
            )
        elif message.role == Role.SYSTEM:
            messages.append(
                ScopedMessage(
                    message=system_message(message.content or ""),
                    user_index=index,
                )
            )
        else:
            raise RequestParameterValidationError(
                f"Role {message.role} is not supported.", param="messages"
            )

    return messages