Python - PyTorch Lightning

PyTorch Lightning is a high-level framework built on top of PyTorch. It provides simplified interfaces and a clean code structure for building and training AI models. Detailed usage can be found in its official documentation.

Installation

Installing PyTorch Lightning is simple. With pip, for instance, the following command completes the procedure.

pip install pytorch-lightning

Usage

Neural network

The structure of a neural network can be defined as a subclass of pytorch_lightning.LightningModule, which is itself a subclass of torch.nn.Module. Besides the constructor __init__ and the function forward, a series of hook functions is defined and can simply be overridden, e.g.,

  • configure_optimizers returns the optimizer(s) and, optionally, the learning rate scheduler(s); this is where the initial learning rate and its decay are set.
  • on_train_start, on_validation_start, and on_test_start are the hook functions invoked at the beginning of training, validating, and testing, respectively.
  • on_train_end, on_validation_end, and on_test_end are the hook functions invoked at the end of training, validating, and testing, respectively.
  • training_step, validation_step, and test_step are the hook functions invoked for each batch in training, validating, and testing, respectively; this is typically where the loss function is calculated.
  • training_epoch_end, validation_epoch_end, and test_epoch_end are the hook functions invoked at the end of each epoch in training, validating, and testing, respectively. These hooks were removed in PyTorch Lightning 2.0, where on_train_epoch_end, on_validation_epoch_end, and on_test_epoch_end take their place.

Data module

In pytorch_lightning, the data module class pytorch_lightning.LightningDataModule provides a unified interface for loading data for training, validating, and testing. We can define our own data module as a subclass of pytorch_lightning.LightningDataModule and override the following functions.

  • setup creates the dataset(s) needed.
  • train_dataloader, val_dataloader, and test_dataloader return separate dataloaders for training, validating, and testing, respectively.

Procedure

The main procedure is driven by an instance of the class pytorch_lightning.Trainer. Its constructor accepts a series of keyword arguments for customization, e.g.,

  • default_root_dir indicates the output directory.
  • max_epochs is the maximum number of epochs.
  • check_val_every_n_epoch is the validation interval, in epochs.
  • logger is the object used for logging, e.g., pytorch_lightning.loggers.TensorBoardLogger.
  • enable_progress_bar is a boolean that turns the progress bar on or off.
  • strategy defines the training strategy, e.g., DDP or DDP2. DDP2 is no longer available in PyTorch Lightning 2.0.
  • gpus indicates the identities of the selected GPU(s). This argument was removed in PyTorch Lightning 2.0, where accelerator and devices are used instead.

Training and testing are then run by calling its member functions, e.g.,

  • fit for training.
  • test for testing.
  • save_checkpoint to save the state of the pytorch_lightning.Trainer object, including the weights of the trainable parameters in the neural network.