Source code for fedflow.utils.trainer.supervised_trainer

"""
Supervised Trainer
======================

A trainer used for supervised training with PyTorch.
"""

__all__ = [
    "SupervisedTrainer"
]

import os
import json
import sys
import time
from collections import namedtuple

import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torch.utils.data import random_split, DataLoader


class SupervisedTrainer(object):
    """
    A trainer used for supervised training with PyTorch.

    After training, this trainer writes four files to the current directory:

    * history.json: the history data (train/validation loss and accuracy, and learning rate)
    * history.png: a chart of the history
    * parameter.pth: the model parameters
    * optimizer.pth: the optimizer parameters
    """

    History = namedtuple("History", ["train_loss", "train_acc", "val_loss", "val_acc", "lr"])
    History.__doc__ = "Record history data during training"
    History.train_acc.__doc__ = "train accuracy of every epoch"
    History.val_acc.__doc__ = "validation accuracy of every epoch"
    History.train_loss.__doc__ = "train loss of every epoch"
    History.val_loss.__doc__ = "validation loss of every epoch"
    History.lr.__doc__ = "learning rate of every epoch"
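
    # After ``train()`` finishes, the saved parameters can be restored in the usual
    # PyTorch way. Illustrative snippet only; ``model`` and ``optimizer`` stand for
    # instances compatible with the ones that were trained:
    #
    #     model.load_state_dict(torch.load("parameter.pth"))
    #     optimizer.load_state_dict(torch.load("optimizer.pth"))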

    def __init__(self, model, optimizer, criterion, lr_scheduler=None, *,
                 init_model_path=None, init_optim_path=None,
                 dataset=None, batch_size=32, epoch=50, epoch_action=None,
                 checkpoint_interval=10, device="cuda:0", console_out=None):
        """
        Construct a trainer.

        :param model: an instance of ``torch.nn.Module``.
        :param optimizer: an instance of a pytorch optimizer.
        :param criterion: the loss function.
        :param lr_scheduler: an instance of ``torch.optim.lr_scheduler._LRScheduler`` or
            ``torch.optim.lr_scheduler.ReduceLROnPlateau``.
        :param init_model_path: path to the initial model parameters.
        :param init_optim_path: path to the initial optimizer parameters.
        :param dataset: the dataset used by this trainer.
        :param batch_size: the batch size.
        :param epoch: the number of epochs.
        :param epoch_action: called after every epoch finishes. In this callable you can
            update the learning rate etc. The following is an example of an epoch_action:

            >>> class EpochAction(object):
            >>>     def __init__(self, optim):
            >>>         super(EpochAction, self).__init__()
            >>>         self.reduce = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode="max")
            >>>
            >>>     # The complete method signature is:
            >>>     # def __call__(self, *, model, optimizer, criterion, lr_scheduler,
            >>>     #              train_loss, train_acc, val_loss, val_acc, lr):
            >>>     def __call__(self, *, val_acc, **kwargs):
            >>>         self.reduce.step(val_acc)

        :param checkpoint_interval: the interval (in epochs) at which parameters are saved;
            the trainer does not save checkpoints if this parameter is 0.
        :param device: the device used for training.
        :param console_out: redirect print output; either a writable stream or a file path.
        """
        super(SupervisedTrainer, self).__init__()
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.lr_scheduler = lr_scheduler
        self.init_model_path = init_model_path
        self.init_optim_path = init_optim_path
        self.batch_size = batch_size
        self.train_dataloader, self.val_dataloader = self.__split_dataset(dataset)
        self.epoch = epoch
        self.epoch_action = epoch_action
        self.checkpoint_interval = checkpoint_interval
        self.device = device
        if console_out is None:
            self.console_out = sys.stdout
        else:
            if type(console_out) == str:
                self.console_out = open(console_out, "w")
            else:
                self.console_out = console_out
        self.start_time = int(time.time())
        self.history = self.History([], [], [], [], [])
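
    # A typical construction might look like the following. Illustrative only;
    # ``model``, ``dataset`` and the hyper-parameters are placeholders, not part
    # of this module:
    #
    #     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    #     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    #     trainer = SupervisedTrainer(model, optimizer, torch.nn.CrossEntropyLoss(),
    #                                 scheduler, dataset=dataset, batch_size=64,
    #                                 epoch=30, device="cuda:0")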

    def mount_dataset(self, dataset, val_dataset=None, *, val_ratio=0.3, batch_size=32) -> None:
        """
        Mount a dataset on this trainer.

        :param dataset: the complete dataset, or the train dataset if ``val_dataset`` is given.
        :param val_dataset: the validation dataset; if it is None, this method splits a
            validation dataset from ``dataset``.
        :param val_ratio: the ratio of the validation dataset when splitting.
        :param batch_size: the batch size.
        :return:
        """
        if dataset is None:
            raise ValueError("dataset cannot be None.")
        if batch_size is not None:
            if type(batch_size) != int:
                raise TypeError("batch_size only accepts int")
            self.batch_size = batch_size
        if val_dataset is None:
            self.train_dataloader, self.val_dataloader = self.__split_dataset(dataset, val_ratio)
        else:
            self.train_dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
            self.val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=True)
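
    # For example (illustrative only; ``full_ds``, ``train_ds`` and ``val_ds`` are
    # placeholder ``torch.utils.data.Dataset`` instances):
    #
    #     trainer.mount_dataset(full_ds, val_ratio=0.2, batch_size=64)   # split 20% off for validation
    #     trainer.mount_dataset(train_ds, val_ds, batch_size=64)         # or provide both splits yourself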

    def mount_dataloader(self, train_dataloader, val_dataloader) -> None:
        """
        Generally, this method is not recommended. Only when ``mount_dataset`` does not
        meet your needs should you directly mount a ``train_dataloader`` and a
        ``val_dataloader``.

        :param train_dataloader: dataloader used for training.
        :param val_dataloader: dataloader used for validation.
        :return:
        """
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

    def __split_dataset(self, dataset, val_ratio=0.3) -> tuple:
        if dataset is None:
            return None, None
        dataset_len = len(dataset)
        val_len = int(val_ratio * dataset_len)
        train_len = dataset_len - val_len
        t, v = random_split(dataset, (train_len, val_len))
        return (DataLoader(t, batch_size=self.batch_size, shuffle=True),
                DataLoader(v, batch_size=self.batch_size, shuffle=True))

    def train(self) -> dict:
        """
        Run the full training procedure.

        :return: a dict ``{"train_acc": ..., "val_acc": ...}`` with the accuracies of the last epoch.
        """
        self.__pre_train()
        self.__train()
        self.__post_train()
        return {
            "train_acc": self.history.train_acc[-1],
            "val_acc": self.history.val_acc[-1]
        }
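
    # e.g. (illustrative):
    #
    #     result = trainer.train()
    #     print(result["val_acc"])          # validation accuracy of the last epoch
    #     print(trainer.history.val_acc)    # per-epoch validation accuracies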

    def test(self, init_model_path, dataset=None, *, dataloader=None) -> tuple:
        """
        Calculate the prediction accuracy on a dataset.

        :param init_model_path: path to the model parameters to load before testing.
        :param dataset: the dataset used for prediction.
        :param dataloader: if ``dataloader`` is not None, the ``dataset`` param will be ignored.
        :return: a tuple ``(loss, correct, total)``
        """
        self.console_out.write("[INFO] Test started.\n")
        if dataloader is None:
            dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        if init_model_path is None:
            self.console_out.write("[WARN] test model has no pre-trained parameters.\n")
        else:
            if os.path.exists(init_model_path):
                self.console_out.write("[INFO] load model parameters.\n")
                model_parameters = torch.load(init_model_path, map_location=self.device)
                self.model.load_state_dict(model_parameters)
            else:
                self.console_out.write("[INFO] model parameters not found.\n")
        self.model = self.model.to(self.device)
        with torch.no_grad():
            loss, correct, total = self._epoch_train(dataloader)
        self.console_out.write("[INFO] Test ended.\n")
        return loss, correct, total
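
    # e.g. (illustrative; ``test_ds`` is a placeholder dataset):
    #
    #     loss, correct, total = trainer.test("parameter.pth", test_ds)
    #     print("test acc: %.2f%%" % (100 * correct / total))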

    def __load_parameters(self):
        # Load initial model/optimizer parameters from disk if paths were given.
        if self.init_model_path is not None:
            if os.path.exists(self.init_model_path):
                self.console_out.write("[INFO] load model parameters.\n")
                model_parameters = torch.load(self.init_model_path, map_location=self.device)
                self.model.load_state_dict(model_parameters)
            else:
                self.console_out.write("[INFO] model parameters not found.\n")
        if self.init_optim_path is not None:
            if os.path.exists(self.init_optim_path):
                self.console_out.write("[INFO] load optim parameters.\n")
                optim_parameters = torch.load(self.init_optim_path, map_location=self.device)
                self.optimizer.load_state_dict(optim_parameters)
            else:
                self.console_out.write("[INFO] optim parameters not found.\n")

    def __pre_train(self):
        # Prepare for training: create the checkpoint directory, load initial
        # parameters and move the model to the target device.
        os.makedirs("checkpoint", exist_ok=True)
        self.__load_parameters()
        self.model = self.model.to(self.device)

    def __train(self):
        # Main training loop: one training pass and one validation pass (without
        # gradients) per epoch, followed by logging, history bookkeeping,
        # epoch_action / lr_scheduler hooks and optional checkpointing.
        for e in range(self.epoch):
            t_loss, t_correct, t_total = self._epoch_train(self.train_dataloader)
            t_acc = t_correct / t_total
            with torch.no_grad():
                v_loss, v_correct, v_total = self._epoch_train(self.val_dataloader)
            v_acc = v_correct / v_total
            lr = self.optimizer.param_groups[0]["lr"]
            self.console_out.write("[%s] EPOCH %d of %d\n"
                                   % (self.__time_format(int(time.time()) - self.start_time),
                                      e + 1, self.epoch))
            self.console_out.write("\tTrain Loss: %.4f, Acc: %.2f%%\n" % (t_loss, 100 * t_acc))
            self.console_out.write("\tVal Loss: %.4f, Acc: %.2f%%\n" % (v_loss, 100 * v_acc))
            self.console_out.write("\tLR: %f\n" % lr)
            self.console_out.flush()
            self.history.train_loss.append(t_loss)
            self.history.train_acc.append(t_acc)
            self.history.val_loss.append(v_loss)
            self.history.val_acc.append(v_acc)
            self.history.lr.append(lr)
            if self.epoch_action is not None:
                self.epoch_action(model=self.model, optimizer=self.optimizer,
                                  criterion=self.criterion, lr_scheduler=self.lr_scheduler,
                                  train_loss=t_loss, train_acc=t_acc,
                                  val_loss=v_loss, val_acc=v_acc, lr=lr)
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            if self.checkpoint_interval > 0 and (e + 1) % self.checkpoint_interval == 0:
                torch.save(self.model.state_dict(),
                           "checkpoint/parameter-%d.checkpoint" % (e + 1))
                torch.save(self.optimizer.state_dict(),
                           "checkpoint/optimizer-%d.checkpoint" % (e + 1))

    def __post_train(self):
        # Persist the history and the final parameters, then render the history chart.
        with open("history.json", "w") as f:
            f.write(json.dumps({
                "train_loss": self.history.train_loss,
                "val_loss": self.history.val_loss,
                "train_acc": self.history.train_acc,
                "val_acc": self.history.val_acc,
                "lr": self.history.lr
            }, indent=4))
        torch.save(self.model.state_dict(), "parameter.pth")
        torch.save(self.optimizer.state_dict(), "optimizer.pth")
        self.__draw_png()

    def _epoch_train(self, dataloader):
        # Run one pass over ``dataloader``. Weights are only updated while gradients
        # are enabled, so the same method also serves validation and testing when
        # called under ``torch.no_grad()``.
        correct = 0
        total = 0
        loss_total = 0
        iter_num = 0
        for i, data in enumerate(dataloader):
            inputs, labels = data
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            if torch.is_grad_enabled():
                self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            if torch.is_grad_enabled():
                loss.backward()
                self.optimizer.step()
            loss_total += loss.item()
            iter_num += 1
            _, pred = torch.max(outputs, 1)
            c = (pred == labels)
            for j, label in enumerate(labels):
                correct += c[j].item()
                total += 1
        return loss_total / iter_num, correct, total

    def __time_format(self, seconds):
        # Format a duration in seconds as HH:MM:SS.
        minutes = seconds // 60
        seconds = seconds % 60
        hours = minutes // 60
        minutes = minutes % 60
        return "%02d:%02d:%02d" % (hours, minutes, seconds)

    def __draw_png(self):
        # Draw loss, accuracy and learning-rate curves and save them as history.png.
        xdata = range(len(self.history.train_acc))
        fig = plt.figure(figsize=(16, 14))
        spec = gridspec.GridSpec(ncols=2, nrows=2, wspace=0.3, hspace=0.4)
        # Loss
        fig.add_subplot(spec[0, 0])
        plt.title("Loss")
        plt.plot(xdata, self.history.train_loss, label="train")
        plt.plot(xdata, self.history.val_loss, label="val")
        plt.grid()
        plt.legend()
        # Acc
        fig.add_subplot(spec[0, 1])
        plt.title("Accuracy")
        plt.plot(xdata, self.history.train_acc, label="train")
        plt.plot(xdata, self.history.val_acc, label="val")
        plt.grid()
        plt.legend()
        # LR
        fig.add_subplot(spec[1, :])
        plt.title("LR")
        plt.plot(xdata, self.history.lr)
        plt.grid()
        plt.savefig("history.png")
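

# Minimal usage sketch: trains the class above on a toy problem. The linear model,
# random tensors and hyper-parameters below are placeholders chosen only so the
# snippet is self-contained; substitute any real model and dataset.
if __name__ == "__main__":
    from torch import nn
    from torch.utils.data import TensorDataset

    toy_model = nn.Linear(16, 4)                      # 16 features -> 4 classes
    toy_dataset = TensorDataset(torch.randn(256, 16),
                                torch.randint(0, 4, (256,)))

    toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.01)
    trainer = SupervisedTrainer(toy_model, toy_optimizer, nn.CrossEntropyLoss(),
                                dataset=toy_dataset, batch_size=32, epoch=2,
                                device="cpu", checkpoint_interval=0)
    print(trainer.train())                            # e.g. {'train_acc': ..., 'val_acc': ...}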