# @Time : 2020/11/14
# @Author : Junyi Li
# @Email : lijunyi@ruc.edu.cn
# UPDATE:
# @Time : 2022/05/06
# @Author : Hu Yiwen
# @Email : huyiwen@ruc.edu.cn
r"""
Optimizer
#####################
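
A minimal usage sketch (the optimizer and hyper-parameter values below are
illustrative assumptions, not values taken from this module)::

    >>> import torch
    >>> model = torch.nn.Linear(8, 8)
    >>> base = torch.optim.Adam(model.parameters())  # any torch.optim optimizer works
    >>> # illustrative hyper-parameters
    >>> optim = InverseSquareRootOptim(base, init_lr=1e-7, max_lr=1e-3, n_warmup_steps=4000)
    >>> optim.zero_grad()  # delegated to the wrapped optimizer via __getattr__
    >>> torch.nn.functional.mse_loss(model(torch.randn(4, 8)), torch.randn(4, 8)).backward()
    >>> optim.step()  # updates the learning rate, then calls base.step()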
"""
import numpy as np
from torch.optim import Optimizer as torch_optim
class AbstractOptim:
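    """Abstract base class that wraps a ``torch.optim`` optimizer and schedules its learning rate.

    The wrapper counts optimization steps; on every :meth:`step` it writes the value of the
    :attr:`lr` property into all parameter groups before calling the wrapped optimizer's
    ``step()``. Subclasses define the schedule by implementing :attr:`lr`.
    """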
def __init__(self, base_optimizer: torch_optim, init_lr: float):
self.optimizer = base_optimizer
self.init_lr = init_lr
self.n_steps = 0
    def step(self):
self._update_learning_rate()
self.optimizer.step()
def _update_learning_rate(self):
"""Update learning rate. One just need to implement `lr` property."""
self.n_steps += 1
lr = self.lr
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
@property
def lr(self):
"""Get learning rate for current step."""
raise NotImplementedError
    def state_dict(self):
return self.optimizer.state_dict(), self.n_steps
    def load_state_dict(self, state_dict: tuple):
opt, self.n_steps = state_dict
self.optimizer.load_state_dict(opt)
def __getattr__(self, item: str):
"""Pass method calls e.g. `zero_grad()`.
One can override these methods by simply implementing new methods.
"""
return getattr(self.optimizer, item)
class InverseSquareRootOptim(AbstractOptim):
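    """Linear warmup followed by inverse-square-root decay.

    For step ``n`` (as implemented below)::

        lr = init_lr + (max_lr - init_lr) * n / n_warmup_steps   if n <= n_warmup_steps
        lr = max_lr * sqrt(n_warmup_steps / n)                   otherwise
    """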
def __init__(self, base_optimizer: torch_optim, init_lr: float, max_lr: float, n_warmup_steps: int):
super().__init__(base_optimizer, init_lr)
self.n_warmup_steps = n_warmup_steps
self.warmup_k = (max_lr - init_lr) / n_warmup_steps
self.decay_k = max_lr * (n_warmup_steps ** 0.5)
@property
def lr(self):
if self.n_steps <= self.n_warmup_steps:
return self.init_lr + self.warmup_k * self.n_steps
else:
return self.decay_k * self.n_steps ** -0.5
class CosineOptim(AbstractOptim):
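    """Linear warmup followed by cosine annealing.

    The learning rate rises linearly from ``init_lr`` to ``max_lr`` over ``n_warmup_steps``
    steps, then follows half a cosine period back down, reaching ``init_lr`` at ``max_steps``::

        lr = init_lr + (max_lr - init_lr) / 2
             * (1 + cos(pi * (n - n_warmup_steps) / (max_steps - n_warmup_steps)))
    """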
def __init__(self, base_optimizer: torch_optim, init_lr: float, max_lr: float, n_warmup_steps: int, max_steps: int):
super().__init__(base_optimizer, init_lr)
self.n_warmup_steps = n_warmup_steps
self.half_delta = (max_lr - init_lr) / 2
self.warmup_k = (max_lr - init_lr) / n_warmup_steps
self.decay_k = np.pi / (max_steps - n_warmup_steps)
@property
def lr(self):
if self.n_steps <= self.n_warmup_steps:
return self.init_lr + self.warmup_k * self.n_steps
else:
return self.init_lr + self.half_delta * (1. + np.cos(self.decay_k * (self.n_steps - self.n_warmup_steps)))
class LinearOptim(AbstractOptim):
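    """Linear warmup followed by linear decay.

    The learning rate rises linearly from ``init_lr`` to ``max_lr`` over ``n_warmup_steps``
    steps, then decreases linearly, reaching ``init_lr`` again at ``max_steps``.
    """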
def __init__(self, base_optimizer: torch_optim, init_lr: float, max_lr: float, n_warmup_steps: int, max_steps: int):
super().__init__(base_optimizer, init_lr)
self.n_warmup_steps = n_warmup_steps
self.init_lr = init_lr
self.max_lr = max_lr
self.warmup_k = (max_lr - init_lr) / n_warmup_steps
        self.decay_k = (max_lr - init_lr) / (max_steps - n_warmup_steps)  # linear decay back to init_lr at max_steps
@property
def lr(self):
if self.n_steps <= self.n_warmup_steps:
return self.init_lr + self.warmup_k * self.n_steps
else:
return self.max_lr - self.decay_k * (self.n_steps - self.n_warmup_steps)
class ConstantOptim(AbstractOptim):
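    """Linear warmup to a constant learning rate.

    The learning rate rises linearly from ``init_lr`` to ``max_lr`` over ``n_warmup_steps``
    steps and stays at ``max_lr`` afterwards.
    """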
def __init__(self, base_optimizer: torch_optim, init_lr: float, max_lr: float, n_warmup_steps: int):
super().__init__(base_optimizer, init_lr)
self.n_warmup_steps = n_warmup_steps
self.max_lr = max_lr
self.warmup_k = (max_lr - init_lr) / n_warmup_steps
@property
def lr(self):
if self.n_steps <= self.n_warmup_steps:
return self.init_lr + self.warmup_k * self.n_steps
else:
return self.max_lr