utils/utils.py (117 lines of code) (raw):

# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import time
import math
import torch
import numpy as np


def print_banner(s, separator="-", num_star=60):
    """Print ``s`` between two rules of ``separator`` repeated ``num_star`` times.

    Every line is flushed immediately.
    """
    print(separator * num_star, flush=True)
    print(s, flush=True)
    print(separator * num_star, flush=True)


class Progress:
    """Terminal progress bar that redraws itself in place.

    Each redraw shows a percentage bar (or an iteration counter), a speed
    readout in Hz, and a grid of ``key : value`` params. Redrawing works by
    moving the cursor up with the ANSI ``ESC [ F`` sequence and overwriting the
    previously printed lines with spaces.

    Args:
        total: Expected number of steps. A falsy value (0 / None) shows a plain
            ``"<n> iterations"`` counter instead of a bar.
        name: Label printed by ``stamp``.
        ncol: Number of params per row.
        max_length: Maximum characters per formatted param. The bar width is
            ``ncol * max_length``.
        indent: Number of spaces placed before each params row.
        line_width: Width of the blank string used to erase a previous line.
        speed_update_freq: ``update`` restarts the speed-measurement window
            whenever the step counter is a multiple of this value.
    """

    def __init__(self, total, name='Progress', ncol=3, max_length=20,
                 indent=0, line_width=100, speed_update_freq=100):
        self.total = total
        self.name = name
        self.ncol = ncol
        self.max_length = max_length
        self.indent = indent
        self.line_width = line_width
        self._speed_update_freq = speed_update_freq

        self._step = 0
        # Last speed string measured over a non-empty window. It is initialised
        # here so that ``stamp`` works even before any step has been measured.
        self._speed = ''
        self._prev_line = '\033[F'  # ANSI: move cursor to start of previous line
        self._clear_line = ' ' * self.line_width

        self._pbar_size = self.ncol * self.max_length
        self._complete_pbar = '#' * self._pbar_size
        self._incomplete_pbar = ' ' * self._pbar_size

        self.lines = ['']
        self.fraction = '{} / {}'.format(0, self.total)

        self.resume()

    def update(self, description, n=1):
        """Advance the counter by ``n`` and redraw with ``description`` as params.

        NOTE(review): the speed window is restarted *before* the redraw. On a
        step where the counter is a multiple of ``speed_update_freq``, the
        displayed speed is therefore ``0.0 Hz``.
        """
        self._step += n
        if self._step % self._speed_update_freq == 0:
            self._time0 = time.time()
            self._step0 = self._step
        self.set_description(description)

    def resume(self):
        """Start a fresh output line and restart the speed-measurement window."""
        self._skip_lines = 1
        print('\n', end='')
        self._time0 = time.time()
        self._step0 = self._step

    def pause(self):
        """Erase the currently drawn bar so other output can be printed."""
        self._clear()
        self._skip_lines = 1

    def set_description(self, params=None):
        """Erase the previous drawing and print bar, speed and params.

        Args:
            params: A dict, rendered sorted by key, or a sequence of
                ``(key, value)`` pairs rendered in order. ``None`` (the
                default) means no params.
        """
        if params is None:
            params = []
        elif isinstance(params, dict):
            # Accept dict subclasses too (e.g. OrderedDict); they would
            # otherwise reach ``_chunk``, which needs a sliceable sequence.
            params = sorted(params.items())

        # Position: move back over the previous drawing and blank it.
        self._clear()

        # Percent
        percent, fraction = self._format_percent(self._step, self.total)
        self.fraction = fraction

        # Speed
        speed = self._format_speed(self._step)

        # Params
        nrow = math.ceil(len(params) / self.ncol)
        params_split = self._chunk(params, self.ncol)
        params_string, lines = self._format(params_split)
        self.lines = lines

        description = '{} | {}{}'.format(percent, speed, params_string)
        print(description)
        # The bar line plus one line per params row must be erased next time.
        self._skip_lines = nrow + 1

    def append_description(self, descr):
        """Append an extra line that ``stamp`` will include."""
        self.lines.append(descr)

    def _clear(self):
        """Move up ``_skip_lines`` lines, blank them, and return the cursor there."""
        position = self._prev_line * self._skip_lines
        empty = '\n'.join([self._clear_line for _ in range(self._skip_lines)])
        print(position, end='')
        print(empty)
        print(position, end='')

    def _format_percent(self, n, total):
        """Return ``(display_string, fraction_string)`` for step ``n`` of ``total``."""
        if total:
            percent = n / float(total)
            # Clamp so the bar keeps its fixed width when n overshoots total
            # (or is negative); unclamped, the incomplete slice goes negative.
            complete_entries = min(max(int(percent * self._pbar_size), 0), self._pbar_size)
            incomplete_entries = self._pbar_size - complete_entries
            pbar = self._complete_pbar[:complete_entries] + self._incomplete_pbar[:incomplete_entries]
            fraction = '{} / {}'.format(n, total)
            string = '{} [{}] {:3d}%'.format(fraction, pbar, int(percent * 100))
        else:
            fraction = '{}'.format(n)
            string = '{} iterations'.format(n)
        return string, fraction

    def _format_speed(self, n):
        """Return steps-per-second since the window start, formatted as ``'<x> Hz'``.

        ``self._speed`` is only refreshed when at least one step was taken in
        the current window.
        """
        num_steps = n - self._step0
        t = time.time() - self._time0
        # time.time() may have coarse resolution, so t can be 0; a clock step
        # backwards can even make it negative. Report 0.0 instead of failing.
        speed = num_steps / t if t > 0 else 0.0
        string = '{:.1f} Hz'.format(speed)
        if num_steps > 0:
            self._speed = string
        return string

    def _chunk(self, l, n):
        """Split sequence ``l`` into consecutive slices of at most ``n`` items."""
        return [l[i:i + n] for i in range(0, len(l), n)]

    def _format(self, chunks):
        """Render rows of params.

        Returns ``(string, lines)``. ``lines`` starts with an empty entry, and
        ``string`` joins the lines with a newline plus ``indent`` spaces.
        """
        lines = [self._format_chunk(chunk) for chunk in chunks]
        lines.insert(0, '')
        padding = '\n' + ' ' * self.indent
        string = padding.join(lines)
        return string, lines

    def _format_chunk(self, chunk):
        """Join one row of params with ``' | '``."""
        return ' | '.join([self._format_param(param) for param in chunk])

    def _format_param(self, param):
        """Format a ``(key, value)`` pair, truncated to ``max_length`` chars."""
        k, v = param
        return '{} : {}'.format(k, v)[:self.max_length]

    def stamp(self):
        """Replace the live drawing with a one-line summary that stays on screen.

        If no params have been set, the drawing is just erased.
        """
        if self.lines != ['']:
            params = ' | '.join(self.lines)
            string = '[ {} ] {}{} | {}'.format(self.name, self.fraction, params, self._speed)
            self._clear()
            print(string, end='\n')
            self._skip_lines = 1
        else:
            self._clear()
            self._skip_lines = 0

    def close(self):
        """Erase the bar; equivalent to ``pause``."""
        self.pause()


class Silent:
    """No-op stand-in for ``Progress``.

    Any attribute lookup yields a function that accepts arbitrary positional
    and keyword arguments and returns ``None``.
    """

    def __init__(self, *args, **kwargs):
        pass

    def __getattr__(self, attr):
        # Accept keyword arguments too, so calls like ``update(desc, n=2)``
        # that work on Progress also work here.
        return lambda *args, **kwargs: None