Python - PyTorch Lightning
PyTorch Lightning is a high-level framework built on top of PyTorch. It provides simplified interfaces and a clean code structure for deep-learning projects. Detailed usage can be found in its official documentation.
Installation
Installing PyTorch Lightning is simple. With pip, for instance, the following command completes the procedure.
pip install pytorch-lightning
Usage
Neural network
The structure of a neural network is defined as a subclass of pytorch_lightning.LightningModule, which is itself a subclass of torch.nn.Module. Besides the constructor __init__ and the forward method, a series of hook functions is predefined and can simply be overridden, e.g.,
configure_optimizers returns the optimizer(s) and, optionally, learning-rate schedulers, so it controls, e.g., the initial learning rate and how it decays.
on_train_start, on_validation_start, and on_test_start are the hooks called at the beginning of training, validation, and testing, respectively.
on_train_end, on_validation_end, and on_test_end are the hooks called at the end of training, validation, and testing, respectively.
training_step, validation_step, and test_step are the hooks called for each batch in training, validation, and testing, respectively, e.g., to compute the loss.
training_epoch_end, validation_epoch_end, and test_epoch_end are the hooks called at the end of each epoch in training, validation, and testing, respectively. Lightning 2.0 removed these three hooks in favor of on_train_epoch_end, on_validation_epoch_end, and on_test_epoch_end.
Data module
In pytorch_lightning, a data module class is defined as pytorch_lightning.LightningDataModule, which provides a unified interface for loading data for training, validation, and testing. We can simply define our own data module as a subclass of pytorch_lightning.LightningDataModule and override the following functions:
setup creates one or more of the datasets needed; it receives a stage argument such as "fit" or "test".
train_dataloader, val_dataloader, and test_dataloader return separate dataloaders for training, validation, and testing, respectively.
Procedure
The main procedure is driven by an instance of the class pytorch_lightning.Trainer. Its constructor can be customized via a series of keyword arguments, e.g.,
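default_root_dir is the default output directory for logs and checkpoints.
max_epochs is the maximum number of training epochs.
check_val_every_n_epoch is how often, in epochs, validation is run.
logger is the logging object, e.g., pytorch_lightning.loggers.TensorBoardLogger.
enable_progress_bar is a boolean that turns the progress bar on or off.
strategy selects the distributed training strategy, e.g., DDP (strategy="ddp") or DDP2; DDP2 is no longer available in recent releases.
gpus selects the GPU(s) to use, as a count or a list of device indices; Lightning 2.0 replaced it with accelerator="gpu" together with devices.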
Training and testing are then run by calling its methods, e.g.,
fit for training.
test for testing.
save_checkpoint to write a checkpoint file containing the trained weights of the neural network together with training state such as optimizer states and the current epoch.