# ai-ml/llm-serving-gemma/gradio/app/app.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
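#
# Minimal Gradio chat front end for a Gemma model served behind an HTTP endpoint.
# Configuration is read from environment variables (example values below are illustrative only):
#   HOST                        - base URL of the serving backend, e.g. "http://vllm-service:8000"
#   CONTEXT_PATH                - request path appended to HOST, e.g. "/v1/chat/completions"
#   LLM_ENGINE                  - "max", "vllm", or "tgi"; any other value uses the OpenAI-style chat API
#   MODEL_ID                    - model name shown in the UI and sent with chat requests (defaults to "gradio")
#   DISABLE_SYSTEM_MESSAGE      - set to "true" to omit the default system prompt
#   USER_PROMPT / SYSTEM_PROMPT - prompt templates used by the raw-prompt engines ("max", "vllm", "tgi")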
import requests
import gradio as gr
import os
if "MODEL_ID" in os.environ:
model_id = os.environ["MODEL_ID"]
else:
model_id = "gradio"
disable_system_message = False
if "DISABLE_SYSTEM_MESSAGE" in os.environ:
    # Parse as a boolean; otherwise any non-empty string (including "false") would count as truthy.
    disable_system_message = os.environ["DISABLE_SYSTEM_MESSAGE"].lower() in ("true", "1", "yes")
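# Gradio ChatInterface callback: builds an engine-specific request payload from the
# user message, chat history, and sampling sliders, posts it to the backend, and
# returns the generated text to render in the chat window.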
def inference_interface(message, history, model_temperature, top_p, max_tokens):
json_message = {}
    # The serving engine determines the request and response formats.
if "LLM_ENGINE" in os.environ:
llm_engine = os.environ["LLM_ENGINE"]
else:
llm_engine = "openai-chat"
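    # "max", "vllm", and "tgi" take a raw prompt string assembled by process_message();
    # any other engine value falls through to the OpenAI-style chat-completions format.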
match llm_engine:
case "max":
json_message.update({"temperature": model_temperature})
json_message.update({"top_p": top_p})
json_message.update({"max_tokens": max_tokens})
final_message = process_message(message, history)
json_message.update({"prompt": final_message})
json_data = post_request(json_message)
            output = json_data["response"]
case "vllm":
json_message.update({"temperature": model_temperature})
json_message.update({"top_p": top_p})
json_message.update({"max_tokens": max_tokens})
final_message = process_message(message, history)
json_message.update({"prompt": final_message})
json_data = post_request(json_message)
            temp_output = json_data["predictions"][0]
            # Keep only the text after the "Output:" marker when the backend echoes the prompt.
            parts = temp_output.split("Output:\n", 1)
            output = parts[-1]
case "tgi":
json_message.update({"parameters": {}})
json_message["parameters"].update({"temperature": model_temperature})
json_message["parameters"].update({"top_p": top_p})
json_message["parameters"].update({"max_new_tokens": max_tokens})
final_message = process_message(message, history)
json_message.update({"inputs": final_message})
json_data = post_request(json_message)
            output = json_data["generated_text"]
case _:
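            # Default path: OpenAI-compatible chat API. Rebuild the full message list
            # (optional system prompt, prior turns, then the new user message).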
print("* History: " + str(history))
json_message.update({"model": model_id})
json_message.update({"messages": []})
            # The system prompt is included by default; set DISABLE_SYSTEM_MESSAGE to opt out.
if not disable_system_message:
system_message = {
"role": "system",
"content": "You are a helpful assistant.",
}
json_message["messages"].append(system_message)
json_message["temperature"] = model_temperature
if len(history) > 0:
                # Replay earlier turns so the backend sees the full conversation.
print(
"** Before adding additional messages: "
+ str(json_message["messages"])
)
for item in history:
user_message = {"role": "user", "content": item[0]}
assistant_message = {"role": "assistant", "content": item[1]}
json_message["messages"].append(user_message)
json_message["messages"].append(assistant_message)
new_user_message = {"role": "user", "content": message}
json_message["messages"].append(new_user_message)
json_data = post_request(json_message)
output = json_data["choices"][0]["message"]["content"]
return output
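# Flattens the chat history and the new message into a single prompt string using the
# USER_PROMPT / SYSTEM_PROMPT templates, in which the literal substring "prompt" is a
# placeholder for the actual text. For example (illustrative only), a Gemma-style user
# template might look like "<start_of_turn>user\nprompt<end_of_turn>\n<start_of_turn>model\n".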
def process_message(message, history):
user_prompt_format = ""
system_prompt_format = ""
# if env prompts are set, use those
if "USER_PROMPT" in os.environ:
user_prompt_format = os.environ["USER_PROMPT"]
if "SYSTEM_PROMPT" in os.environ:
system_prompt_format = os.environ["SYSTEM_PROMPT"]
print("* History: " + str(history))
user_message = ""
system_message = ""
history_message = ""
if len(history) > 0:
        # Fold each earlier turn into the prompt using the configured templates.
for item in history:
user_message = user_prompt_format.replace("prompt", item[0])
system_message = system_prompt_format.replace("prompt", item[1])
history_message = history_message + user_message + system_message
new_user_message = user_prompt_format.replace("prompt", message)
    # Append the new user message after the accumulated history; the model produces the next turn.
aggregated_message = history_message + new_user_message
return aggregated_message
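# Sends the JSON payload to HOST + CONTEXT_PATH and returns the decoded JSON response.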
def post_request(json_message):
print("*** Request" + str(json_message), flush=True)
response = requests.post(
os.environ["HOST"] + os.environ["CONTEXT_PATH"], json=json_message
)
json_data = response.json()
print("*** Output: " + str(json_data), flush=True)
return json_data
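# Gradio UI: a chat interface plus sliders for temperature, top_p, and max tokens,
# which Gradio passes to inference_interface as additional inputs.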
with gr.Blocks(fill_height=True) as app:
html_text = "You are chatting with: " + model_id
gr.HTML(value=html_text)
model_temperature = gr.Slider(
minimum=0.1, maximum=1.0, value=0.9, label="Temperature", render=False
)
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top_p", render=False)
max_tokens = gr.Slider(
minimum=1, maximum=4096, value=256, label="Max Tokens", render=False
)
gr.ChatInterface(
inference_interface, additional_inputs=[model_temperature, top_p, max_tokens]
)
app.launch(server_name="0.0.0.0")