# ai-ml/gke-ray/rayserve/stable-diffusion/stable_diffusion_tpu.py (95 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: this file was inspired from
# https://github.com/ray-project/serve_config_examples/blob/master/stable_diffusion/stable_diffusion.py

"""Ray Serve Stable Diffusion example."""

from io import BytesIO
import logging
import time
from typing import List

from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
import ray
from ray import serve

app = FastAPI()

# All batches sent to the TPU are padded to exactly this size so the pmapped
# computation always sees the same input shape (avoids XLA recompilation).
_MAX_BATCH_SIZE = 64

logger = logging.getLogger("ray.serve")


@serve.deployment(num_replicas=1)
@serve.ingress(app)
class APIIngress:
    """HTTP front end that forwards prompts to the StableDiffusion deployment."""

    def __init__(self, diffusion_model_handle) -> None:
        # Ray Serve handle to the StableDiffusion deployment below.
        self.handle = diffusion_model_handle

    @app.get(
        "/imagine",
        responses={200: {"content": {"image/png": {}}}},
        response_class=Response,
    )
    async def generate(self, prompt: str):
        """Generates a PNG image for `prompt` and returns it as the response.

        Args:
            prompt: the text prompt to render.

        Returns:
            A FastAPI Response whose body is a raw PNG (produced by the
            StableDiffusion deployment).

        Raises:
            HTTPException: 400 if `prompt` is empty.
        """
        # Was `assert len(prompt)`: asserts are stripped under `python -O`
        # and otherwise surface as a 500; an explicit 400 is the right contract.
        if not prompt:
            raise HTTPException(
                status_code=400, detail="prompt parameter cannot be empty"
            )
        image = await self.handle.generate.remote(prompt)
        return image


@serve.deployment(
    ray_actor_options={
        "resources": {"TPU": 4},
    },
)
class StableDiffusion:
    """FLAX Stable Diffusion Ray Serve deployment running on TPUs.

    Attributes:
        run_with_profiler: Whether or not to run with the profiler. Note that
            this saves the profile to the separate TPU VM.
    """

    def __init__(
        self,
        run_with_profiler: bool = False,
        warmup: bool = False,
        warmup_batch_size: int = _MAX_BATCH_SIZE,
    ):
        # Heavy ML imports are deferred so only this TPU actor pays for them.
        from diffusers import FlaxStableDiffusionPipeline
        from flax.jax_utils import replicate
        import jax.numpy as jnp
        from jax import pmap

        model_id = "CompVis/stable-diffusion-v1-4"

        self._pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            model_id, revision="bf16", dtype=jnp.bfloat16
        )
        # Replicate the weights onto every local TPU device and pmap the
        # generation function so each device handles one shard of the batch.
        self._p_params = replicate(params)
        self._p_generate = pmap(self._pipeline._generate)
        self._run_with_profiler = run_with_profiler
        self._profiler_dir = "/tmp/tensorboard"

        if warmup:
            # Trigger the XLA compile up front so the first real request
            # doesn't pay for it.
            logger.info("Sending warmup requests.")
            warmup_prompts = ["A warmup request"] * warmup_batch_size
            self.generate_tpu(warmup_prompts)

    def generate_tpu(self, prompts: List[str]):
        """Generates a batch of images from Diffusion from a list of prompts.

        Args:
            prompts: a list of strings. Should be a factor of 4.

        Returns:
            A list of PIL Images.

        Raises:
            ValueError: if `prompts` is empty.
        """
        from flax.training.common_utils import shard
        import jax
        import numpy as np

        # Was an `assert`: asserts are stripped under `python -O`.
        if not prompts:
            raise ValueError("prompt parameter cannot be empty")

        # One PRNG key per device so each shard draws independent randomness.
        rng = jax.random.PRNGKey(0)
        rng = jax.random.split(rng, jax.device_count())

        logger.info("Prompts: %s", prompts)
        prompt_ids = self._pipeline.prepare_inputs(prompts)
        # Split the batch across the leading (device) axis for pmap.
        prompt_ids = shard(prompt_ids)
        logger.info("Sharded prompt ids has shape: %s", prompt_ids.shape)

        if self._run_with_profiler:
            jax.profiler.start_trace(self._profiler_dir)

        # try/finally so a failed generation can't leave the trace running
        # (the original leaked the trace on error).
        try:
            time_start = time.time()
            images = self._p_generate(prompt_ids, self._p_params, rng)
            # Dispatch is asynchronous; block so `elapsed` measures real work.
            images = images.block_until_ready()
            elapsed = time.time() - time_start
        finally:
            if self._run_with_profiler:
                jax.profiler.stop_trace()

        logger.info("Inference time (in seconds): %f", elapsed)
        logger.info("Shape of the predictions: %s", images.shape)
        # Collapse the (devices, per_device_batch) leading axes into one flat
        # batch axis, keeping the trailing image (H, W, C) axes.
        images = images.reshape(
            (images.shape[0] * images.shape[1],) + images.shape[-3:]
        )
        logger.info("Shape of images afterwards: %s", images.shape)
        return self._pipeline.numpy_to_pil(np.array(images))

    @serve.batch(batch_wait_timeout_s=10, max_batch_size=_MAX_BATCH_SIZE)
    async def batched_generate_handler(self, prompts: List[str]):
        """Sends a batch of prompts to the TPU model server.

        This takes advantage of @serve.batch, Ray Serve's built-in batching
        mechanism.

        Args:
            prompts: A list of input prompts

        Returns:
            A list of responses which contents are raw PNG.
        """
        logger.info("Number of input prompts: %d", len(prompts))
        num_requests = len(prompts)
        num_to_pad = _MAX_BATCH_SIZE - num_requests
        # Pad to the fixed batch size with throwaway prompts so the pmapped
        # computation always sees the same shape. Build a new list rather
        # than mutating the list owned by @serve.batch (original used `+=`).
        padded_prompts = prompts + ["Scratch request"] * num_to_pad

        images = self.generate_tpu(padded_prompts)

        results = []
        # Only the first `num_requests` images belong to real requests; the
        # rest are padding and are dropped.
        for image in images[:num_requests]:
            file_stream = BytesIO()
            image.save(file_stream, "PNG")
            results.append(
                Response(content=file_stream.getvalue(), media_type="image/png")
            )
        return results

    async def generate(self, prompt):
        """Single-prompt entry point; batched transparently by @serve.batch."""
        return await self.batched_generate_handler(prompt)


diffusion_bound = StableDiffusion.bind()
deployment = APIIngress.bind(diffusion_bound)