Python - Pytorch
张量
在pytorch中,神经网络的输入、输出、以及网络的参数等都采用了“张量”的数据结构。张量与numpy中的多维数组(numpy.ndarray)非常类似,区别在于张量可以在GPU或其他硬件上运行。
- 成员函数名称末尾加上
_,会采用替换模式,也即该成员函数直接修改张量对象本身。 .item()返回其具体数值,该方法仅适用于大小为1(也即只含有一个元素)的张量。.view()功能为reshape该张量(保持元素个数不变),其参数为新的shape。特别地,值为-1的维度可以通过元素总数不变推导得出。- 定义一个张量时,可以通过参数
device参数指定该张量存储于CPU还是GPU。也可以通过其成员函数.to()来转换。 - 定义一个张量时,参数
requires_grad=True表示需要计算并存储对于该张量的梯度。该张量的函数均带有grad_fn属性。对于其任一函数,当调用其成员函数.backward()时,该函数对于该张量自变量的梯度将会被计算并保存在该张量的.grad属性中。 - 有时候需要停止跟踪记录某个张量(记为
x)的梯度(比如在更新网络参数的权重系数时),可以采用以下任一方式:- 调用该张量的成员函数
x.requires_grad_(False)。 - 调用该张量的成员函数
x.detach()会产生一个新的张量,其属性requires_grad=False。 - 采用with语句,也即
with torch.no_grad():。
- 调用该张量的成员函数
与 numpy.ndarray 相互转换
- 张量的成员函数
.numpy()返回对应的numpy.ndarray。但是,该方法仅适用于存储在CPU的张量。 - 函数
torch.from_numpy()可以将一个numpy.ndarray转换为张量。
值得注意的是,当一个张量存储在CPU上时,该张量与其对应的 numpy.ndarray 是连接着的,也即对其中一个的修改会导致另一个也会被修改。
神经网络
典型训练过程
- 构建含有一些可学习参数(或权重)的神经网络。
- 遍历训练数据集。
- 让输入通过网络得到输出。
- 计算预测误差(或损失函数)。
- 通过反向传播,计算并存储网络各个参数的梯度。
- 根据各个参数的梯度对其进行调整,如
weight = weight - learning_rate x gradient。
正向传播与反向传播
神经网络的训练包括两个步骤
- 正向传播
- 神经网络根据输入对输出进行最佳预测。
- 反向传播
- 神经网络根据预测误差调整其参数。
当调用误差张量的成员函数 .backward() 时,开始反向传播,整个网络将被微分。Pytorch的自动差分引擎 torch.autograd 会对计算每个属性 requires_grad=True 的网络参数的梯度并将其累积在该网络参数的 .grad 属性中。所以,在反向传播开始之前,需要清空神经网络各个参数和反向传播的梯度缓冲区。
然后,需要从 torch.optim 中加载一个优化器(如随机梯度下降,SGD)。通过调用该优化器的 .step() 方法启动梯度下降,优化器会根据每个网络参数的梯度(存储于其 .grad 属性中)来对其进行调整。
定义网络
可以使用 torch.nn 包来构建神经网络。
- 基类
torch.nn.Module包含- 网络的各个层(作为类属性),如输入层、输出层、隐藏层等。
- 成员函数
.forward。 - 成员函数
.zero_grad,用于清空神经网络中所有参数的梯度缓冲区。相当于整个网络中属性requires_grad=True的张量均调用成员函数.grad.zero_()。
- 构建神经网络可以通过定义以
torch.nn.Module为基类的派生类完成。 - 只需要在派生类中重载成员函数
.forward,torch.autograd就会自动为网络各个参数定义.backward()函数。 - 所构建的神经网络的可学习参数由所述派生类的成员函数
.parameters()返回。 torch.nn包中定义了很多损失函数(如torch.nn.MSELoss)。
下面是一个非常简单的神经网络。该神经网络只包含一个全连接层和一个激活函数relu。包括两个成员函数, TRAIN 和 TEST 。顾名思义,前者用于从训练数据集中加载数据,并训练神经网络中的可学习参数,并把所有可学习参数的最终权重(训练结果)保存在文件中;后者从权重文件中读取并加载所有可学习参数的权重,然后从测试数据集中加载数据,计算误差,从而完成对训练结果(可学习参数的权重)的测试。
class NEURAL_NETWORK(torch.nn.Module):
def __init__(self, IN_FEATURES, OUT_FEATURES):
super(NEURAL_NETWORK, self).__init__()
self.FC = torch.nn.Module.Linear(IN_FEATURES, OUT_FEATURES)
def forward(self, INPUT):
return torch.nn.function.relu(self.FC(INPUT))
def TRAIN(self, DATALOADER_FOR_TRAIN):
OPT = torch.optim.Adam(self.parameters(), LEARNING_RATE)
LEARNING_RATE_SCHEDULER = torch.optim.lr_scheduler.ExponentialLR(OPT, gamma=0.9)
for EPOCH_INDEX in range(EPOCH_NUM):
for SAMPLE in DATALOADER_FOR_TRAIN:
SAMPLE = SAMPLE.to(torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
VALUE_PREDICTED = self.forward(SAMPLE)
VALUE_IDEAL = CALC_VALUE_IDEAL(SAMPLE)
LOSS = torch.nn.MSELoss(VALUE_PREDICTED, VALUE_IDEAL)
LOSS.requires_grad_(True)
OPT.zero_grad()
LOSS.backward()
OPT.step()
LEARNING_RATE_SCHEDULER.step()
torch.save(self.state_dict(), "NEURAL_NETWORK_WEIGHT.pth")
def TEST(self, DATALOADER_FOR_TEST):
self.load_state_dict(torch.load("NEURAL_NETWORK_WEIGHT.pth", False))
with torch.no_grad():
for SAMPLE in DATALOADER_FOR_TEST:
SAMPLE = SAMPLE.to(torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
VALUE_PREDICTED = self.forward(SAMPLE)
VALUE_IDEAL = CALC_VALUE_IDEAL(SAMPLE)
LOSS = torch.nn.MSELoss(VALUE_PREDICTED, VALUE_IDEAL)
数据集
自定义的数据集可以通过定义以(torch.utils.data.Dataset)为基类的派生类完成。具体地,需要重载如下三个函数:
__init__(self): 载入数据;__getitem__(self, index): 根据索引返回数据集中相应的样本点;__len(self)__: 返回数据集的大小,也即数据集中样本点的数量。
Pytorch支持多种格式的数据集,下面是一个从HDF5文件中加载数据集的例子。
class DATASET(torch.utils.data.Dataset):
def __init__(self, FILE_NAME, SAMPLE_NUM):
super(DATASET, self).__init__()
f = h5py.File(FILE_NAME, 'r')
self.sample = torch.tensor(f.get("GROUP_NAME"))
self.sample_num = SAMPLE_NUM
f.close()
def __getitem__(self, INDEX):
return self.sample[INDEX]
def __len__(self):
return self.sample_num
数据加载器
基于自定义的数据集,可以使用类 torch.utils.data.DataLoader 来生成相应的数据加载器。如
from torch.utils.data import Dataset, DataLoader dataloader = DataLoader(dataset=DATASET, batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=NUM_WORKERS)
其中,
DATASET为自定义的数据集;BATCH_SIZE为每个batch中含有的样本点个数;SHUFFLE为布尔型变量,用于指定数据集是否会被随机打乱顺序;NUM_WORKERS为用于数据加载的线程个数。