batch/aiml-workloads/src/worker.py (34 lines of code) (raw):
#!/usr/bin/env python
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import rediswq
from model_training import FraudDetectionModelTrainer
# Module-level configuration for the worker.

# Redis work queue connection settings.
HOST = "redis"
QUEUE_NAME = "datasets"

# Name of the label column passed to the model trainer.
CLASS_LABEL = "TX_FRAUD_SCENARIO"

# Locations on the mounted file share; every path below hangs off this root,
# which already ends with a trailing slash.
FILESTORE_PATH = "/mnt/fileserver/"
TESTING_DATASET_PATH = f"{FILESTORE_PATH}datasets/testing/test_dataset.pkl"
OUTPUT_DIR = f"{FILESTORE_PATH}output/"
REPORT_PATH = f"{OUTPUT_DIR}report.txt"
def main():
    """
    Drain the Redis work queue, training the model on one dataset per item.

    Workload which:
      1. Claims a filename from a Redis Worker Queue
      2. Reads the dataset from the file
      3. Partially trains the model on the dataset
      4. Saves a model checkpoint and generates a report on
         the performance of the model after the partial training.
      5. Removes the filename from the Redis Worker Queue
      6. Repeats 1 through 5 till the Queue is empty

    The checkpoint path returned by one training round is passed to the
    trainer created for the next round; the first round starts with
    ``checkpoint_path=None``.
    """
    # Use the shared QUEUE_NAME constant instead of repeating the literal so
    # the queue name is configured in exactly one place.
    q = rediswq.RedisWQ(name=QUEUE_NAME, host=HOST)
    print("Worker with sessionID: " + q.sessionID())
    print("Initial queue state: empty=" + str(q.empty()))

    # No checkpoint exists before the first dataset has been processed.
    checkpoint_path = None

    while not q.empty():
        # Claim item in Redis Worker Queue.
        # NOTE(review): the lease lasts 20 seconds; if a training round can
        # take longer than that, confirm how rediswq treats an expired lease.
        item = q.lease(lease_secs=20, block=True, timeout=2)
        if item is None:
            # Nothing could be leased within the timeout; re-check the queue.
            print("Waiting for work")
            continue

        # The leased item is the dataset path as bytes, relative to the
        # file share root.
        dataset_path = item.decode("utf-8")
        print("Processing dataset: " + dataset_path)
        training_dataset_path = FILESTORE_PATH + dataset_path

        # Initialize the model training manager class
        model_trainer = FraudDetectionModelTrainer(
            training_dataset_path,
            TESTING_DATASET_PATH,
            CLASS_LABEL,
            checkpoint_path=checkpoint_path,
        )

        # Train model and save checkpoint + report; keep the returned
        # checkpoint path so the next round can pass it to its trainer.
        checkpoint_path = model_trainer.train_and_save(OUTPUT_DIR)
        model_trainer.generate_report(REPORT_PATH)

        # Remove item from Redis Worker Queue only after the round succeeded.
        q.complete(item)

    print("Queue empty, exiting")
# Run workload
# NOTE: main() is invoked unconditionally at module level, so the worker loop
# starts whenever this file is executed and also whenever it is imported.
main()