ai-ml/gke-ray/rayserve/llm/serve_chat_completion.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: this file was adapted from:
# https://github.com/ray-project/ray/blob/master/doc/source/serve/doc_code/vllm_example.py

import logging
import os
from typing import Dict, List, Optional

from fastapi import FastAPI
from starlette.requests import Request
from starlette.responses import JSONResponse, StreamingResponse

from ray import serve

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ErrorResponse,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_engine import LoRAModulePath

logger = logging.getLogger("ray.serve")

app = FastAPI()


@serve.deployment(name="VLLMDeployment")
@serve.ingress(app)
class VLLMDeployment:
    def __init__(
        self,
        engine_args: AsyncEngineArgs,
        response_role: str,
        lora_modules: Optional[List[LoRAModulePath]] = None,
        chat_template: Optional[str] = None,
    ):
        logger.info(f"Starting with engine args: {engine_args}")
        # The OpenAI-compatible chat handler is created lazily on the first
        # request, once the engine's model config is available.
        self.openai_serving_chat = None
        self.engine_args = engine_args
        self.response_role = response_role
        self.lora_modules = lora_modules
        self.chat_template = chat_template
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

    @app.post("/v1/chat/completions")
    async def create_chat_completion(
        self, request: ChatCompletionRequest, raw_request: Request
    ):
        """OpenAI-compatible HTTP endpoint.

        API reference:
        - https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
        """
        if not self.openai_serving_chat:
            model_config = await self.engine.get_model_config()
            # Determine the name of the served model for the OpenAI client.
            if self.engine_args.served_model_name is not None:
                served_model_names = self.engine_args.served_model_name
            else:
                served_model_names = [self.engine_args.model]
            self.openai_serving_chat = OpenAIServingChat(
                self.engine,
                model_config,
                served_model_names,
                self.response_role,
                self.lora_modules,
                self.chat_template,
            )
        logger.info(f"Request: {request}")
        generator = await self.openai_serving_chat.create_chat_completion(
            request, raw_request
        )
        if isinstance(generator, ErrorResponse):
            return JSONResponse(
                content=generator.model_dump(), status_code=generator.code
            )
        if request.stream:
            return StreamingResponse(content=generator, media_type="text/event-stream")
        else:
            assert isinstance(generator, ChatCompletionResponse)
            return JSONResponse(content=generator.model_dump())


def parse_vllm_args(cli_args: Dict[str, str]):
    """Parses vLLM args based on CLI inputs.

    Currently uses argparse because vLLM doesn't expose Python models for all of
    the config options we want to support.
    """
    parser = make_arg_parser()
    arg_strings = []
    for key, value in cli_args.items():
        arg_strings.extend([f"--{key}", str(value)])
    logger.info(arg_strings)
    parsed_args = parser.parse_args(args=arg_strings)
    return parsed_args


def build_app(cli_args: Dict[str, str]) -> serve.Application:
    """Builds the Serve app based on CLI arguments.

    See https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server
    for the complete set of arguments.

    Supported engine arguments: https://docs.vllm.ai/en/latest/models/engine_args.html.
    """  # noqa: E501
    parsed_args = parse_vllm_args(cli_args)
    engine_args = AsyncEngineArgs.from_cli_args(parsed_args)
    # Run the vLLM workers as Ray actors so they can be scheduled across the cluster.
    engine_args.worker_use_ray = True

    return VLLMDeployment.bind(
        engine_args,
        parsed_args.response_role,
        parsed_args.lora_modules,
        parsed_args.chat_template,
    )


# Serve application entry point; the model and tensor parallelism degree are
# read from environment variables set in the deployment manifest.
model = build_app(
    {
        "model": os.environ["MODEL_ID"],
        "tensor-parallel-size": os.environ["TENSOR_PARALLELISM"],
    }
)