Source code for textbox.utils.utils

# -*- coding: utf-8 -*-
# @Time   : 2020/11/14
# @Author : Junyi Li, Gaole He
# @Email  : lijunyi@ruc.edu.cn

# UPDATE:
# @Time   : 2020/11/15
# @Author : Tianyi Tang
# @Email  : steventang@ruc.edu.cn

"""
textbox.utils.utils
################################
"""

import os
import datetime
import importlib
import random
import torch
import numpy as np
from textbox.utils.enum_type import ModelType, PLM_MODELS


def get_local_time():
    r"""Return the current local time as a formatted string.

    Returns:
        str: the current time rendered with the pattern ``'%b-%d-%Y_%H-%M-%S'``
    """
    return datetime.datetime.now().strftime('%b-%d-%Y_%H-%M-%S')
def ensure_dir(dir_path):
    r"""Make sure the directory exists, creating it (and any parents) if needed.

    Args:
        dir_path (str): directory path

    Raises:
        FileExistsError: if ``dir_path`` exists but is not a directory
    """
    # ``exist_ok=True`` avoids the check-then-create race: another process may
    # create the directory between an ``os.path.exists`` test and ``makedirs``.
    os.makedirs(dir_path, exist_ok=True)
def get_model(model_name):
    r"""Automatically select model class based on model name

    The model is looked up as ``...model.<submodule>.<model_name.lower()>``
    (relative to this module) in each known submodule in turn, and the class
    named ``model_name`` is taken from the first module that is found. Any name
    whose lower-cased form is in ``PLM_MODELS`` is mapped to ``'Transformers'``.

    Args:
        model_name (str): model name

    Returns:
        Generator: model class

    Raises:
        NotImplementedError: if no submodule provides the model module, or the
            module cannot be imported, or it does not define the model class.
            The underlying error, if any, is chained as ``__cause__``.
    """
    # ``import importlib`` alone does not guarantee that the ``util`` submodule
    # is loaded, so import it explicitly before calling ``find_spec``.
    import importlib.util

    model_submodule = ['GAN', 'LM', 'VAE', 'Seq2Seq', 'Attribute', 'Kb2Text']

    if model_name.lower() in PLM_MODELS:
        model_name = 'Transformers'
    model_file_name = model_name.lower()

    for submodule in model_submodule:
        module_path = '.'.join(['...model', submodule, model_file_name])
        try:
            spec = importlib.util.find_spec(module_path, __name__)
        except ModuleNotFoundError:
            # find_spec imports the parent package; a missing submodule package
            # only means the model is not there, so keep searching.
            continue
        if spec is None:
            continue

        # Stop at the first match so that a problem in a later submodule
        # cannot invalidate a lookup that already succeeded.
        try:
            model_module = importlib.import_module(module_path, __name__)
            return getattr(model_module, model_name)
        except (ImportError, AttributeError) as err:
            raise NotImplementedError("{} can't be found".format(model_file_name)) from err

    raise NotImplementedError("{} can't be found".format(model_file_name))
def get_trainer(model_type, model_name):
    r"""Pick the trainer class for a model.

    A class called ``<model_name>Trainer`` in ``textbox.trainer`` is preferred.
    When no such attribute exists, the choice falls back on ``model_type``:
    ``GANTrainer`` for GAN models, ``Seq2SeqTrainer`` for seq2seq and attribute
    models, and the generic ``Trainer`` for everything else.

    Args:
        model_type (~textbox.utils.enum_type.ModelType): model type
        model_name (str): model name

    Returns:
        ~textbox.trainer.trainer.Trainer: trainer class
    """
    try:
        # A trainer named after the model takes precedence over the defaults.
        return getattr(importlib.import_module('textbox.trainer'), model_name + 'Trainer')
    except AttributeError:
        if model_type in (ModelType.UNCONDITIONAL,):
            fallback_name = 'Trainer'
        elif model_type == ModelType.GAN:
            fallback_name = 'GANTrainer'
        elif model_type in (ModelType.SEQ2SEQ, ModelType.ATTRIBUTE):
            fallback_name = 'Seq2SeqTrainer'
        else:
            fallback_name = 'Trainer'
        return getattr(importlib.import_module('textbox.trainer'), fallback_name)
def early_stopping(value, best, cur_step, max_step, bigger=True):
    r"""Validation-based early stopping.

    Args:
        value (float): current result
        best (float): best result
        cur_step (int): the number of consecutive steps that did not exceed the best result
        max_step (int): threshold steps for stopping
        bigger (bool, optional): whether the bigger the better

    Returns:
        tuple:
        - float, best result after this step
        - int, the number of consecutive steps that did not exceed the best result after this step
        - bool, whether to stop
        - bool, whether to update
    """
    # Strict comparison: a tie with the current best does not count as progress.
    improved = value > best if bigger else value < best
    if improved:
        # New best: the stall counter restarts and the caller should update.
        return value, 0, False, True

    stalled_steps = cur_step + 1
    should_stop = bool(stalled_steps > max_step)
    return best, stalled_steps, should_stop, False
def init_seed(seed, reproducibility):
    r"""Seed the random number generators of ``random``, numpy, torch and CUDA,
    and configure the cuDNN flags.

    Args:
        seed (int): random seed
        reproducibility (bool): Whether to require reproducibility
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # When reproducibility is requested, cuDNN benchmarking is switched off and
    # deterministic mode is switched on; otherwise the two flags are reversed.
    reproducible = bool(reproducibility)
    torch.backends.cudnn.benchmark = not reproducible
    torch.backends.cudnn.deterministic = reproducible