# ai-ml/t5-model-serving/model/handler.py
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
from abc import ABC

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class TransformersSeqGeneration(BaseHandler, ABC):
    """TorchServe handler that translates text with a T5 sequence-to-sequence model."""

    # Maps the ISO 639-1 language codes used in requests to the language names
    # T5 expects in its "translate X to Y:" prompt.
    _LANG_MAP = {
        "ro": "Romanian",
        "fr": "French",
        "de": "German",
        "en": "English",
    }

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else "cpu"
        )

        # Read configuration (save_mode, model name, etc.) from setup_config.json;
        # default to an empty config so a missing file does not crash below.
        self.setup_config = {}
        setup_config_path = os.path.join(model_dir, "setup_config.json")
        if os.path.isfile(setup_config_path):
            with open(setup_config_path) as setup_config_file:
                self.setup_config = json.load(setup_config_file)
        else:
            logger.warning("Missing the setup_config.json file.")

        # Load the tokenizer and the model from the model archive based on the
        # configured save_mode; further setup options can be added here.
        self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
        if self.setup_config.get("save_mode") == "torchscript":
            self.model = torch.jit.load(model_pt_path)
        elif self.setup_config.get("save_mode") == "pretrained":
            self.model = T5ForConditionalGeneration.from_pretrained(model_dir)
        else:
            raise ValueError(
                "Unknown save_mode; expected 'torchscript' or 'pretrained'."
            )
        self.model.to(self.device)
        self.model.eval()
        logger.info("Transformer model from path %s loaded successfully", model_dir)

        self.initialized = True
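
    # For reference, a minimal setup_config.json that satisfies the logic above
    # might look like the following. This is an illustrative sketch: only
    # "save_mode" (with the values "torchscript" or "pretrained") is actually
    # read by this handler, and the real config file may carry more fields.
    #
    #   {
    #       "save_mode": "pretrained"
    #   }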

    def preprocess(self, requests):
        texts_batch = []
        for data in requests:
            data = data["body"]
            input_text = data["text"]
            src_lang = data["from"]
            tgt_lang = data["to"]
            # Request fields may arrive as raw bytes; decode each one independently.
            if isinstance(input_text, (bytes, bytearray)):
                input_text = input_text.decode("utf-8")
            if isinstance(src_lang, (bytes, bytearray)):
                src_lang = src_lang.decode("utf-8")
            if isinstance(tgt_lang, (bytes, bytearray)):
                tgt_lang = tgt_lang.decode("utf-8")
            texts_batch.append(
                f"translate {self._LANG_MAP[src_lang]} to "
                f"{self._LANG_MAP[tgt_lang]}: {input_text}"
            )
        # Pad so variable-length prompts can be batched into one tensor, and
        # keep the attention mask so generate() can ignore the padding tokens.
        inputs = self.tokenizer(texts_batch, padding=True, return_tensors="pt")
        return inputs.to(self.device)
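
    # Example request body this method expects (illustrative values):
    #   {"text": "The house is wonderful.", "from": "en", "to": "de"}
    # which becomes the T5 translation prompt:
    #   "translate English to German: The house is wonderful."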

    def inference(self, input_batch):
        # Run generation without tracking gradients, masking out padding tokens;
        # generate() uses the model's default generation settings unless
        # overridden here.
        with torch.inference_mode():
            generations = self.model.generate(
                input_batch["input_ids"],
                attention_mask=input_batch["attention_mask"],
            )
        return self.tokenizer.batch_decode(generations, skip_special_tokens=True)

    def postprocess(self, inference_output):
        return [{"text": text} for text in inference_output]