Python - PyTorch Lightning
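PyTorch Lightning is a high-level framework built on top of PyTorch. It hides the engineering boilerplate of training (the loops, device placement, logging, and checkpointing) behind simplified interfaces and gives deep-learning code a clean, consistent structure. Detailed usage can be found in its official documentation.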
Installation
Installing PyTorch Lightning is simple. With pip, for instance, the following command completes the procedure:
pip install pytorch-lightning
Usage
Neural network
The neural network is defined as a subclass of pytorch_lightning.LightningModule, which is itself a subclass of torch.nn.Module. Besides the constructor __init__ and the forward method, LightningModule defines a series of hooks that can simply be overridden, e.g.,
- configure_optimizers returns the optimizer(s) and, optionally, learning-rate schedulers, e.g., to set the initial learning rate and how it decays.
- on_train_start, on_validation_start, and on_test_start are called at the beginning of training, validation, and testing, respectively.
- on_train_end, on_validation_end, and on_test_end are called at the end of training, validation, and testing, respectively.
- training_step, validation_step, and test_step are called for each batch in training, validation, and testing, respectively, e.g., to compute the loss.
- training_epoch_end, validation_epoch_end, and test_epoch_end are called at the end of each epoch in training, validation, and testing, respectively. Lightning 2.0 removed them in favor of on_train_epoch_end, on_validation_epoch_end, and on_test_epoch_end.
Data module
In PyTorch Lightning, a data module is defined by the class pytorch_lightning.LightningDataModule, which provides a unified interface for loading data for training, validation, and testing. We can define our own data module as a subclass of pytorch_lightning.LightningDataModule and override the following methods:
- setup(stage) creates the dataset(s) needed; stage indicates which phase ("fit", "validate", "test", or "predict") is being prepared.
- train_dataloader, val_dataloader, and test_dataloader return separate dataloaders for training, validation, and testing, respectively.
Procedure
The main procedure is driven by an instance of pytorch_lightning.Trainer, which can be customized through a series of keyword arguments to its constructor, e.g.,
- default_root_dir is the output directory for logs and checkpoints.
- max_epochs is the maximum number of epochs.
- check_val_every_n_epoch runs validation once every n training epochs.
- logger is the logging object, e.g., pytorch_lightning.loggers.TensorBoardLogger.
- enable_progress_bar is a boolean that turns the progress bar on or off.
- strategy selects the distributed training strategy, e.g., DDP or DDP2.
- gpus selects the GPU(s) to use, as a count or a list of device indices. Lightning 2.0 removed it in favor of accelerator="gpu" combined with devices.
Training and testing are then run by calling the trainer's methods, e.g.,
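- fit for training (including validation, if a validation dataloader is provided).
- test for testing.
- save_checkpoint to write a checkpoint file to a given path, containing the trainable parameters (weights) of the neural network together with the training state, such as the optimizer states.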