# -*- coding: utf-8 -*-
# @Time : 2020/11/14
# @Author : Junyi Li, Gaole He
# @Email : lijunyi@ruc.edu.cn
# UPDATE:
# @Time : 2020/11/15
# @Author : Tianyi Tang
# @Email : steventang@ruc.edu.cn
"""
textbox.utils.utils
################################
"""
import os
import datetime
import importlib
import random
import torch
import numpy as np
from textbox.utils.enum_type import ModelType, PLM_MODELS
def get_local_time():
    r"""Get current time

    The timestamp is taken from ``datetime.datetime.now()`` (naive local time)
    and rendered with the pattern ``'%b-%d-%Y_%H-%M-%S'``, i.e. abbreviated
    month name, day, year, then hour, minute and second, joined by dashes with
    a single underscore between the date and time parts.

    Returns:
        str: current time
    """
    return datetime.datetime.now().strftime('%b-%d-%Y_%H-%M-%S')
def ensure_dir(dir_path):
    r"""Make sure the directory exists, if it does not exist, create it

    Missing parent directories are created as well. Calling this on a
    directory that already exists is a no-op.

    Args:
        dir_path (str): directory path

    Raises:
        FileExistsError: if ``dir_path`` already exists but is not a directory
    """
    # ``exist_ok=True`` replaces a separate ``os.path.exists`` check: checking
    # first and creating afterwards can fail when another process creates the
    # directory in between the two calls.
    os.makedirs(dir_path, exist_ok=True)
def get_model(model_name):
    r"""Automatically select model class based on model name

    Names whose lower-cased form is listed in ``PLM_MODELS`` are mapped to the
    ``Transformers`` class. Each sub-package in ``model_submodule`` is then
    probed (relative to this module's package) for a module named
    ``model_name.lower()``, and the class is read from that module by
    attribute name.

    Args:
        model_name (str): model name

    Returns:
        Generator: model class

    Raises:
        TypeError: if ``model_name`` is not a string
        NotImplementedError: if no matching module or class can be found; the
            underlying import/attribute error, if any, is chained as the cause
    """
    # ``import importlib`` alone does not guarantee that the ``importlib.util``
    # submodule has been loaded, so import it explicitly before using it.
    import importlib.util

    if not isinstance(model_name, str):
        raise TypeError(
            "model_name must be a str, got {}".format(type(model_name).__name__)
        )

    model_submodule = ['GAN', 'LM', 'VAE', 'Seq2Seq', 'Attribute', 'Kb2Text']
    model_name = 'Transformers' if model_name.lower() in PLM_MODELS else model_name
    model_file_name = model_name.lower()
    not_found_msg = "{} can't be found".format(model_file_name)

    model_module = None
    try:
        for submodule in model_submodule:
            module_path = '.'.join(['...model', submodule, model_file_name])
            if importlib.util.find_spec(module_path, __name__):
                # No ``break`` on purpose: when several sub-packages contain a
                # module with this name, the last one listed is used, which is
                # how this lookup has always behaved.
                model_module = importlib.import_module(module_path, __name__)
        if model_module is not None:
            return getattr(model_module, model_name)
    except (ImportError, AttributeError, ValueError) as err:
        # Keep the exception type callers expect, but preserve the real cause
        # (missing dependency, missing class, ...) for debugging.
        raise NotImplementedError(not_found_msg) from err

    # No sub-package contained a module with the requested name.
    raise NotImplementedError(not_found_msg)
def get_trainer(model_type, model_name):
    r"""Automatically select trainer class based on model type and model name

    A trainer named ``model_name + 'Trainer'`` in ``textbox.trainer`` takes
    precedence. When no such attribute exists, a trainer is chosen from
    ``model_type``: ``GANTrainer`` for ``ModelType.GAN``, ``Seq2SeqTrainer`` for
    ``ModelType.SEQ2SEQ`` and ``ModelType.ATTRIBUTE``, and the generic
    ``Trainer`` for ``ModelType.UNCONDITIONAL`` and every other value.

    Args:
        model_type (~textbox.utils.enum_type.ModelType): model type
        model_name (str): model name

    Returns:
        ~textbox.trainer.trainer.Trainer: trainer class
    """
    # Import the trainer package once and reuse it for both lookups.
    trainer_package = importlib.import_module('textbox.trainer')

    try:
        return getattr(trainer_package, model_name + 'Trainer')
    except AttributeError:
        # No model-specific trainer; fall back to one chosen by model type.
        pass

    if model_type in [ModelType.UNCONDITIONAL]:
        fallback_name = 'Trainer'
    elif model_type == ModelType.GAN:
        fallback_name = 'GANTrainer'
    elif model_type in [ModelType.SEQ2SEQ, ModelType.ATTRIBUTE]:
        fallback_name = 'Seq2SeqTrainer'
    else:
        fallback_name = 'Trainer'
    return getattr(trainer_package, fallback_name)
def early_stopping(value, best, cur_step, max_step, bigger=True):
    r""" validation-based early stopping

    A result counts as an improvement only when it is strictly better than
    ``best`` (``>`` when ``bigger`` is true, ``<`` otherwise); a tie is treated
    as no improvement. Stopping is signalled once the incremented step counter
    is strictly greater than ``max_step``.

    Args:
        value (float): current result
        best (float): best result
        cur_step (int): the number of consecutive steps that did not exceed the best result
        max_step (int): threshold steps for stopping
        bigger (bool, optional): whether the bigger the better

    Returns:
        tuple:
        - float,
          best result after this step
        - int,
          the number of consecutive steps that did not exceed the best result after this step
        - bool,
          whether to stop
        - bool,
          whether to update
    """
    improved = value > best if bigger else value < best
    if improved:
        # New best: reset the patience counter and ask the caller to update.
        return value, 0, False, True

    # No improvement: spend one step of patience and check the threshold.
    cur_step += 1
    stop_flag = bool(cur_step > max_step)
    return best, cur_step, stop_flag, False
def init_seed(seed, reproducibility):
    r""" init random seed for random functions in numpy, torch, cuda and cudnn

    Seeds Python's ``random`` module, ``numpy.random`` and torch (CPU plus the
    ``torch.cuda`` seeding calls), then sets the two cudnn backend flags from
    ``reproducibility``.

    Args:
        seed (int): random seed
        reproducibility (bool): Whether to require reproducibility
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # The two cudnn flags are always opposites of each other: reproducible runs
    # set deterministic=True / benchmark=False, otherwise the reverse.
    reproducible = bool(reproducibility)
    torch.backends.cudnn.benchmark = not reproducible
    torch.backends.cudnn.deterministic = reproducible