def inference_interface()

in ai-ml/llm-serving-gemma/gradio/app/app.py
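
Builds the engine-specific JSON request from the Gradio chat inputs (the new message, the chat history, and the temperature / top-p / max-tokens sliders), POSTs it to the serving endpoint via post_request, and returns the generated text. The payload and response shapes depend on the LLM_ENGINE environment variable: "max", "vllm", "tgi", or the default OpenAI-compatible chat API.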


def inference_interface(message, history, model_temperature, top_p, max_tokens):

    json_message = {}

    # The serving engine decides the request/response format; default to the
    # OpenAI-compatible chat API when LLM_ENGINE is not set.
    llm_engine = os.environ.get("LLM_ENGINE", "openai-chat")

    match llm_engine:
        case "max":
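            # "max" engine: flat JSON body with the sampling options plus
            # "prompt"; the generated text comes back under "response".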
            json_message.update(
                {
                    "temperature": model_temperature,
                    "top_p": top_p,
                    "max_tokens": max_tokens,
                }
            )
            final_message = process_message(message, history)
            json_message.update({"prompt": final_message})

            json_data = post_request(json_message)
            output = json_data["response"]
        case "vllm":
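            # vLLM endpoint: same flat body with "prompt"; the prediction echoes
            # the prompt, so keep only the text after the "Output:\n" marker.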
            json_message.update(
                {
                    "temperature": model_temperature,
                    "top_p": top_p,
                    "max_tokens": max_tokens,
                }
            )
            final_message = process_message(message, history)
            json_message.update({"prompt": final_message})

            json_data = post_request(json_message)
            temp_output = json_data["predictions"][0]
            output = temp_output.split("Output:\n", 1)[1]
        case "tgi":
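            # Hugging Face TGI: sampling options are nested under "parameters"
            # (note "max_new_tokens" rather than "max_tokens"), the prompt goes
            # in "inputs", and the result is returned as "generated_text".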
            json_message.update(
                {
                    "parameters": {
                        "temperature": model_temperature,
                        "top_p": top_p,
                        "max_new_tokens": max_tokens,
                    }
                }
            )
            final_message = process_message(message, history)
            json_message.update({"inputs": final_message})

            json_data = post_request(json_message)
            output = json_data["generated_text"]
        case _:
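            # Default ("openai-chat"): OpenAI-compatible chat completions payload.
            # model_id and disable_system_message are module-level settings defined
            # elsewhere in app.py; the Gradio history is replayed as alternating
            # user/assistant messages.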
            print("* History: " + str(history))
            json_message.update({"model": model_id})
            json_message.update({"messages": []})
            # The system prompt is on by default; it is only skipped when the
            # user has explicitly set disable_system_message.
            if not disable_system_message:
                system_message = {
                    "role": "system",
                    "content": "You are a helpful assistant.",
                }
                json_message["messages"].append(system_message)

            json_message["temperature"] = model_temperature

            if len(history) > 0:
                # we have history
                print(
                    "** Before adding additional messages: "
                    + str(json_message["messages"])
                )
                for item in history:
                    user_message = {"role": "user", "content": item[0]}
                    assistant_message = {"role": "assistant", "content": item[1]}
                    json_message["messages"].append(user_message)
                    json_message["messages"].append(assistant_message)

            new_user_message = {"role": "user", "content": message}
            json_message["messages"].append(new_user_message)

            json_data = post_request(json_message)
            output = json_data["choices"][0]["message"]["content"]

    return output
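

The helpers referenced above, process_message (which folds the chat history and the new message into a single prompt string for the prompt-style engines) and post_request (which sends the payload to the serving endpoint), plus the Gradio wiring, live elsewhere in app.py and are not shown here. Below is a minimal sketch of what they might look like, assuming a requests-based POST to an endpoint read from a HOST environment variable and a standard gr.ChatInterface with slider inputs; the environment variable name, defaults, and slider ranges are assumptions, not the repository's actual code.

import os
import requests
import gradio as gr


def post_request(json_message):
    # Hypothetical helper: POST the engine-specific payload to the serving
    # endpoint (URL assumed to come from a HOST env var) and return the JSON body.
    url = os.environ.get("HOST", "http://localhost:8080")
    response = requests.post(url, json=json_message, timeout=120)
    response.raise_for_status()
    return response.json()


# Hypothetical Gradio wiring: the three sliders map onto the extra parameters of
# inference_interface(message, history, model_temperature, top_p, max_tokens).
demo = gr.ChatInterface(
    fn=inference_interface,
    additional_inputs=[
        gr.Slider(0.0, 1.0, value=0.9, label="Temperature"),
        gr.Slider(0.0, 1.0, value=1.0, label="Top-p"),
        gr.Slider(1, 1024, value=256, step=1, label="Max tokens"),
    ],
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)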