# ai-ml/gke-ray/rayserve/stable-diffusion/stable_diffusion_tpu_req.py (113 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from concurrent import futures
import functools
from io import BytesIO
import numpy as np
from PIL import Image
import requests
from tqdm import tqdm
# Pool of sample text prompts; send_requests() draws its request prompts from
# this list at random.
_PROMPTS = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
    "A cat sitting on a windowsill",
    "A dog playing fetch in a park",
    "A city skyline at night",
    "A field of flowers in bloom",
    "A tropical beach with palm trees",
    "A snowy mountain range",
    "A waterfall cascading into a pool",
    "A forest at sunset",
    "A desert landscape with cacti",
    "A volcano erupting",
    "A lightning storm in the distance",
    "A rainbow over a rainbow",
    "A unicorn grazing in a meadow",
    "A dragon flying through the sky",
    "A mermaid swimming in the ocean",
    "A robot walking down the street",
    "A UFO landing in a field",
    "A portal to another dimension",
    "A time traveler from the future",
    "A talking cat",
    "A bowl of fruit on a table",
    "A group of friends laughing",
    "A family sitting down for dinner",
    "A couple kissing in the rain",
    "A child playing with a toy",
    "A musician playing an instrument",
    "A painter painting a picture",
    "A writer writing a book",
    "A scientist conducting an experiment",
    "A construction worker building a house",
    "A doctor operating on a patient",
    "A teacher teaching a class",
    "A police officer arresting a suspect",
    "A firefighter putting out a fire",
    "A soldier fighting in a war",
    "A farmer working in a field",
    "A pilot flying a plane",
    "An astronaut in space",
    "A unicorn eating a rainbow",
]
def send_request_and_receive_image(prompt: str, url: str) -> "BytesIO | None":
    """Sends a single prompt request and returns the image bytes.

    Args:
      prompt: Text prompt, sent as the `prompt` query parameter.
      url: Endpoint the GET request is issued against.

    Returns:
      A BytesIO wrapping the raw response body, or None when the request
      failed (the error is printed rather than raised, so one bad request
      does not abort the others).
    """
    try:
        # Let requests build the query string so that every reserved
        # character (&, #, +, ?, ...) is escaped, not only spaces.
        resp = requests.get(url, params={"prompt": prompt})
        resp.raise_for_status()
    except requests.RequestException as e:
        print(f"An error occurred while sending the request: {e}")
        return None
    return BytesIO(resp.content)
def image_grid(imgs, rows, cols):
    """Tiles the given images onto one new RGB canvas, filling row by row.

    The cell size is taken from the first image; image number i is pasted
    into row i // cols, column i % cols.

    Args:
      imgs: Non-empty sequence of images exposing `.size`.
      rows: Number of grid rows.
      cols: Number of grid columns.

    Returns:
      The assembled grid image.
    """
    cell_w, cell_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for index, picture in enumerate(imgs):
        x_offset = index % cols * cell_w
        y_offset = index // cols * cell_h
        canvas.paste(picture, box=(x_offset, y_offset))
    return canvas
def send_requests(num_requests: int, batch_size: int, save_pictures: bool,
                  url: str = "http://localhost:8000/imagine"):
    """Sends a list of requests and processes the responses.

    Args:
      num_requests: How many prompts to sample and send.
      batch_size: Size of the thread pool, i.e. the maximum number of
        requests in flight at once.
      save_pictures: When true, the received images are tiled into a grid
        and written to ./diffusion_results.png.
      url: Endpoint every request is sent to.
    """
    print("num_requests: ", num_requests)
    print("batch_size: ", batch_size)
    print("url: ", url)
    print("save_pictures: ", save_pictures)

    prompts = _PROMPTS
    if num_requests > len(_PROMPTS):
        # Repeat the pool until it holds at least num_requests entries:
        # sampling without replacement cannot draw more items than exist.
        prompts = _PROMPTS * int(np.ceil(num_requests / len(_PROMPTS)))
    prompts = np.random.choice(prompts, num_requests, replace=False)

    with futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
        raw_images = list(
            tqdm(
                executor.map(
                    functools.partial(send_request_and_receive_image, url=url),
                    prompts,
                ),
                total=len(prompts),
            )
        )

    if not save_pictures:
        return

    # A failed request comes back as None; skip it instead of crashing in
    # Image.open.
    images = [Image.open(raw) for raw in raw_images if raw is not None]
    if not images:
        print("No images were received; nothing to save.")
        return

    print("Saving pictures to diffusion_results.png")
    # Size the grid from the images actually received so that an odd count
    # (or a single image) still fits: at most two rows, columns rounded up.
    rows = min(2, len(images))
    cols = (len(images) + rows - 1) // rows
    grid = image_grid(images, rows, cols)
    grid.save("./diffusion_results.png")
if __name__ == "__main__":
    # Command-line entry point: parse the options and fire the requests.
    parser = argparse.ArgumentParser(description="Sends requests to Diffusion.")
    parser.add_argument(
        "--num_requests", type=int, default=8,
        help="Number of requests to send.")
    parser.add_argument(
        "--batch_size", type=int, default=8,
        help="The number of requests to send at a time.")
    parser.add_argument(
        "--save_pictures", default=False, action="store_true",
        help="Whether to save the generated pictures to disk.")
    parser.add_argument(
        "--ip", help="The IP address to send the requests to.")
    args = parser.parse_args()

    request_kwargs = {
        "num_requests": args.num_requests,
        "batch_size": args.batch_size,
        "save_pictures": args.save_pictures,
    }
    if args.ip:
        # Honor --ip, which was previously parsed but never used. The port
        # and path mirror send_requests' default URL.
        # NOTE(review): assumes the service listens on :8000/imagine on the
        # given host - confirm against the deployment.
        request_kwargs["url"] = f"http://{args.ip}:8000/imagine"
    send_requests(**request_kwargs)