utils/pytorch_util.py (34 lines of code) (raw):

# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0
"""Small PyTorch helpers for copying parameters between modules and for
initializing weight tensors."""
import torch
import numpy as np


def _paired_parameters(source, target):
    """Pair ``target`` and ``source`` parameters in ``.parameters()`` order.

    Returns:
        An iterator of ``(target_param, source_param)`` tuples.

    Raises:
        ValueError: if the two modules expose a different number of
            parameters. A plain ``zip`` would silently stop at the shorter
            one and leave the remaining target parameters untouched.
    """
    target_params = list(target.parameters())
    source_params = list(source.parameters())
    if len(target_params) != len(source_params):
        raise ValueError(
            "Parameter count mismatch: target has {} parameters, "
            "source has {}.".format(len(target_params), len(source_params))
        )
    return zip(target_params, source_params)


def soft_update_from_to(source, target, tau):
    """Blend ``source``'s parameters into ``target`` in place.

    Each target parameter becomes ``target * (1 - tau) + source * tau``,
    so ``tau=1`` copies ``source`` and ``tau=0`` leaves ``target`` unchanged.

    Raises:
        ValueError: if the modules have different numbers of parameters.
    """
    for target_param, param in _paired_parameters(source, target):
        target_param.data.copy_(
            target_param.data * (1.0 - tau) + param.data * tau
        )


def copy_model_params_from_to(source, target):
    """Copy every parameter of ``source`` into ``target`` in place.

    Raises:
        ValueError: if the modules have different numbers of parameters.
    """
    for target_param, param in _paired_parameters(source, target):
        target_param.data.copy_(param.data)


def _fan_in(tensor):
    """Return the fan-in value used by the ``fanin_*`` initializers.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    size = tensor.size()
    if len(size) < 2:
        raise ValueError(
            "Shape must have at least 2 dimensions, got {}.".format(tuple(size))
        )
    if len(size) == 2:
        # NOTE(review): this uses dim 0, whereas the >2-D branch below treats
        # dims 1: as fan-in. For an nn.Linear weight of shape
        # (out_features, in_features) dim 0 is out_features -- confirm this
        # choice is intentional before relying on it.
        return size[0]
    # Product of every dim except the first (e.g. in_ch * kh * kw for a conv).
    return np.prod(size[1:])


def fanin_init(tensor, scale=1):
    """Fill ``tensor`` in place from U(-bound, bound), bound = scale / sqrt(fan_in).

    Args:
        tensor: tensor with at least 2 dimensions.
        scale: multiplier applied to the bound.

    Returns:
        ``tensor.data`` after the in-place fill.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    bound = scale / np.sqrt(_fan_in(tensor))
    return tensor.data.uniform_(-bound, bound)


def orthogonal_init(tensor, gain=0.01):
    """Fill ``tensor`` in place with a (semi-)orthogonal matrix scaled by ``gain``.

    Returns:
        ``tensor``, as returned by ``torch.nn.init.orthogonal_``.
    """
    return torch.nn.init.orthogonal_(tensor, gain=gain)


def fanin_init_weights_like(tensor, scale=1.):
    """Return a new float32 CPU tensor shaped like ``tensor``, drawn from
    U(-bound, bound) with bound = scale / sqrt(fan_in).

    ``tensor`` itself is not modified; only its shape is used.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    bound = scale / np.sqrt(_fan_in(tensor))
    new_tensor = torch.empty(tensor.size(), dtype=torch.float32)
    new_tensor.uniform_(-bound, bound)
    return new_tensor