ai-ml/gke-ray/rayserve/stable-diffusion/stable_diffusion.py (44 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: this file was inspired from https://github.com/ray-project/serve_config_examples/blob/master/stable_diffusion/stable_diffusion.py
"""Ray Serve application exposing Stable Diffusion v2 behind a FastAPI route.

Two deployments are composed: ``APIIngress`` handles HTTP and forwards each
request to ``StableDiffusionV2``, which holds the pipeline on a GPU.
"""

from io import BytesIO

from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
import torch

from ray import serve
from ray.serve.handle import DeploymentHandle


app = FastAPI()


@serve.deployment(num_replicas=1)
@serve.ingress(app)
class APIIngress:
    """HTTP front end that turns ``GET /imagine`` requests into PNG images."""

    def __init__(self, diffusion_model_handle: DeploymentHandle) -> None:
        """Store the handle used to reach the image-generation deployment.

        Args:
            diffusion_model_handle: Handle to a deployment exposing a
                ``generate(prompt, img_size=...)`` method.
        """
        self.handle = diffusion_model_handle

    @app.get(
        "/imagine",
        responses={200: {"content": {"image/png": {}}}},
        response_class=Response,
    )
    async def generate(self, prompt: str, img_size: int = 512):
        """Generate an image for ``prompt`` and return it PNG-encoded.

        Args:
            prompt: Text description of the image; must be non-empty.
            img_size: Height and width, in pixels, passed to the model.

        Returns:
            A ``Response`` whose body is the PNG bytes (``image/png``).

        Raises:
            HTTPException: With status 400 when ``prompt`` is empty or
                ``img_size`` is not a positive multiple of 8.
        """
        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O`` and would otherwise surface to clients as a 500.
        if not prompt:
            raise HTTPException(
                status_code=400, detail="prompt parameter cannot be empty"
            )
        # The diffusers pipeline rejects sizes that are not divisible by 8;
        # fail fast here so the client gets a 400 rather than a 500.
        if img_size <= 0 or img_size % 8 != 0:
            raise HTTPException(
                status_code=400,
                detail="img_size parameter must be a positive multiple of 8",
            )

        image = await self.handle.generate.remote(prompt, img_size=img_size)

        file_stream = BytesIO()
        image.save(file_stream, "PNG")
        return Response(content=file_stream.getvalue(), media_type="image/png")


@serve.deployment(
    ray_actor_options={"num_gpus": 1},
    num_replicas=1,
)
class StableDiffusionV2:
    """Deployment that loads the Stable Diffusion 2 pipeline onto a GPU."""

    def __init__(self):
        """Load the scheduler and fp16 pipeline, then move the pipeline to CUDA."""
        # Imported here rather than at module level, so importing this module
        # does not itself import diffusers.
        from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

        model_id = "stabilityai/stable-diffusion-2"

        scheduler = EulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler"
        )
        # NOTE(review): newer diffusers releases appear to prefer
        # ``variant="fp16"`` over ``revision="fp16"`` -- confirm against the
        # pinned diffusers version before changing.
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id, scheduler=scheduler, revision="fp16", torch_dtype=torch.float16
        )
        self.pipe = self.pipe.to("cuda")

    def generate(self, prompt: str, img_size: int = 512):
        """Run the pipeline once and return the first generated image.

        Args:
            prompt: Text description of the image; must be non-empty.
            img_size: Used as both ``height`` and ``width`` for the pipeline.

        Returns:
            The first element of the pipeline output's ``images``.

        Raises:
            ValueError: If ``prompt`` is empty.
        """
        # Raise explicitly; an ``assert`` would disappear under ``python -O``.
        if not prompt:
            raise ValueError("prompt parameter cannot be empty")

        with torch.autocast("cuda"):
            image = self.pipe(prompt, height=img_size, width=img_size).images[0]
            return image


entrypoint = APIIngress.bind(StableDiffusionV2.bind())