ai-ml/gke-ray/rayserve/llm/tpu/serve_tpu.py (87 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: this file was inspired by: https://github.com/richardsliu/vllm/blob/rayserve/examples/rayserve_tpu.py

import os
import json
import logging
from typing import Dict, List, Optional

import ray
from fastapi import FastAPI
from ray import serve
from starlette.requests import Request
from starlette.responses import Response
from vllm import LLM, SamplingParams

logger = logging.getLogger("ray.serve")

app = FastAPI()


@serve.deployment(name="VLLMDeployment")
@serve.ingress(app)
class VLLMDeployment:
    def __init__(
        self,
        model_id,
        num_tpu_chips,
        max_model_len,
        tokenizer_mode,
        dtype,
    ):
        self.llm = LLM(
            model=model_id,
            tensor_parallel_size=num_tpu_chips,
            max_model_len=max_model_len,
            dtype=dtype,
            download_dir=os.environ['VLLM_XLA_CACHE_PATH'],  # Error if not provided.
            tokenizer_mode=tokenizer_mode,
            enforce_eager=True,
        )

    # Expects a JSON body with "prompt" and "max_tokens"; returns the prompt,
    # the generated text, and the generated token IDs.
    @app.post("/v1/generate")
    async def generate(self, request: Request):
        request_dict = await request.json()
        prompts = request_dict.pop("prompt")
        max_toks = int(request_dict.pop("max_tokens"))
        print("Processing prompt ", prompts)
        sampling_params = SamplingParams(temperature=0.7,
                                         top_p=1.0,
                                         n=1,
                                         max_tokens=max_toks)

        outputs = self.llm.generate(prompts, sampling_params)
        for output in outputs:
            prompt = output.prompt
            generated_text = ""
            token_ids = []
            for completion_output in output.outputs:
                generated_text += completion_output.text
                token_ids.extend(list(completion_output.token_ids))

            print("Generated text: ", generated_text)
            ret = {
                "prompt": prompt,
                "text": generated_text,
                "token_ids": token_ids,
            }

        return Response(content=json.dumps(ret))


def get_num_tpu_chips() -> int:
    if "TPU" not in ray.cluster_resources():
        # Pass in TPU chips when the current Ray cluster resources can't be
        # auto-detected (i.e. for autoscaling).
        if os.environ.get('TPU_CHIPS') is not None:
            return int(os.environ.get('TPU_CHIPS'))
        return 0
    return int(ray.cluster_resources()["TPU"])


def get_max_model_len() -> Optional[int]:
    if 'MAX_MODEL_LEN' not in os.environ or os.environ['MAX_MODEL_LEN'] == "":
        return None
    return int(os.environ['MAX_MODEL_LEN'])


def get_tokenizer_mode() -> str:
    if 'TOKENIZER_MODE' not in os.environ or os.environ['TOKENIZER_MODE'] == "":
        return "auto"
    return os.environ['TOKENIZER_MODE']


def get_dtype() -> str:
    if 'DTYPE' not in os.environ or os.environ['DTYPE'] == "":
        return "auto"
    return os.environ['DTYPE']


def build_app(cli_args: Dict[str, str]) -> serve.Application:
    """Builds the Serve app based on CLI arguments."""
    ray.init(ignore_reinit_error=True, address="ray://localhost:10001")

    model_id = os.environ['MODEL_ID']
    num_tpu_chips = get_num_tpu_chips()

    pg_resources = []
    pg_resources.append({"CPU": 1})  # for the deployment replica
    for i in range(num_tpu_chips):
        pg_resources.append({"CPU": 1, "TPU": 1})  # for the vLLM actors

    # Use PACK strategy since the deployment may use more than one TPU node.
    return VLLMDeployment.options(
        placement_group_bundles=pg_resources,
        placement_group_strategy="PACK").bind(model_id, num_tpu_chips,
                                              get_max_model_len(),
                                              get_tokenizer_mode(), get_dtype())


model = build_app({})
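
# Example client request (a sketch): once the app is deployed (e.g. with
# `serve run serve_tpu:model`), the /v1/generate endpoint can be queried over
# HTTP. This assumes Ray Serve's default HTTP port of 8000; adjust the host
# and port to match your deployment.
#
#   import requests
#
#   response = requests.post(
#       "http://localhost:8000/v1/generate",
#       json={"prompt": "What are TPUs?", "max_tokens": 128},
#   )
#   print(response.json()["text"])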