ai-ml/t5-model-serving/model/handler.py (68 lines of code) (raw):

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import os
import logging
import json
from abc import ABC

from ts.torch_handler.base_handler import BaseHandler
from transformers import T5Tokenizer, T5ForConditionalGeneration

logger = logging.getLogger(__name__)


class TransformersSeqGeneration(BaseHandler, ABC):
    """TorchServe handler that translates text with a T5 model.

    Each request body must provide ``text``, ``from`` and ``to`` fields.
    ``from``/``to`` are language codes looked up in ``_LANG_MAP`` to build
    the prompt ``"translate <Source> to <Target>: <text>"``.
    """

    # Supported language codes for the request "from"/"to" fields, mapped to
    # the language names used inside the translation prompt.
    _LANG_MAP = {
        "ro": "Romanian",
        "fr": "French",
        "de": "German",
        "en": "English",
    }

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load the tokenizer and model from the model directory in ``ctx``.

        ``setup_config.json`` in the model directory selects how the model
        is loaded via its ``save_mode`` key: ``"torchscript"`` loads the
        manifest's ``serializedFile`` with ``torch.jit.load``;
        ``"pretrained"`` loads ``T5ForConditionalGeneration`` from the
        directory.

        Args:
            ctx: TorchServe context providing ``manifest`` and
                ``system_properties`` (``model_dir``, ``gpu_id``).

        Raises:
            ValueError: if ``save_mode`` is missing (including when
                ``setup_config.json`` itself is absent) or unsupported.
        """
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        if torch.cuda.is_available():
            gpu_id = properties.get("gpu_id")
            # Avoid building the invalid device string "cuda:None" when no
            # GPU id was assigned to this worker.
            self.device = torch.device(
                "cuda" if gpu_id is None else "cuda:" + str(gpu_id)
            )
        else:
            self.device = torch.device("cpu")

        # read configs for the mode, model_name, etc. from setup_config.json
        setup_config_path = os.path.join(model_dir, "setup_config.json")
        if os.path.isfile(setup_config_path):
            with open(setup_config_path) as setup_config_file:
                self.setup_config = json.load(setup_config_file)
        else:
            logger.warning("Missing the setup_config.json file.")
            # Keep the attribute defined so the save_mode check below reports
            # a clear error instead of an AttributeError.
            self.setup_config = {}

        # Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode
        # further setup config can be added.
        self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
        save_mode = self.setup_config.get("save_mode")
        if save_mode == "torchscript":
            # serializedFile is only needed for TorchScript loading.
            serialized_file = self.manifest["model"]["serializedFile"]
            model_pt_path = os.path.join(model_dir, serialized_file)
            # map_location lets a CUDA-saved archive load on a CPU-only host.
            # NOTE(review): inference() calls .generate(); confirm the
            # TorchScript module exposes it.
            self.model = torch.jit.load(model_pt_path, map_location=self.device)
        elif save_mode == "pretrained":
            self.model = T5ForConditionalGeneration.from_pretrained(model_dir)
        else:
            logger.warning("Missing the checkpoint or state_dict.")
            raise ValueError(
                f"Unsupported save_mode {save_mode!r} in setup_config.json; "
                "expected 'torchscript' or 'pretrained'."
            )

        self.model.to(self.device)
        self.model.eval()
        logger.info("Transformer model from path %s loaded successfully", model_dir)
        self.initialized = True

    @staticmethod
    def _as_text(value):
        """Return ``value`` decoded as UTF-8 if it is bytes, else unchanged."""
        if isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8")
        return value

    def preprocess(self, requests):
        """Turn a batch of TorchServe requests into padded prompt token ids.

        Args:
            requests: list of request dicts. Each has a ``body`` (or, as a
                fallback, ``data``) entry that is either an already-parsed
                mapping or raw JSON bytes with ``text``/``from``/``to``
                fields.

        Returns:
            Tensor of input ids on ``self.device``, padded to the longest
            prompt in the batch.

        Raises:
            KeyError: if a required field or a language code is unknown.
        """
        texts_batch = []
        for data in requests:
            body = data["body"] if "body" in data else data["data"]
            # A non-JSON content type delivers the raw payload as bytes.
            if isinstance(body, (bytes, bytearray)):
                body = json.loads(body)
            input_text = self._as_text(body["text"])
            src_lang = self._as_text(body["from"])
            tgt_lang = self._as_text(body["to"])
            texts_batch.append(
                f"translate {self._LANG_MAP[src_lang]} to {self._LANG_MAP[tgt_lang]}: {input_text}"
            )

        # Pad to the longest prompt. Without padding the tokenizer cannot
        # build a single tensor from prompts of different token lengths.
        inputs = self.tokenizer(texts_batch, return_tensors="pt", padding=True)
        return inputs["input_ids"].to(self.device)

    def inference(self, input_batch):
        """Generate translations for ``input_batch`` and decode them to strings.

        No attention mask is passed. For padded input ids, transformers'
        ``generate`` derives one from the pad tokens (T5's pad id differs
        from its eos id).
        """
        generations = self.model.generate(input_batch)
        return self.tokenizer.batch_decode(generations, skip_special_tokens=True)

    def postprocess(self, inference_output):
        """Wrap each generated string as a ``{"text": ...}`` response item."""
        return [{"text": text} for text in inference_output]