3.3 线性回归简洁实现

要点

  • 一个 epoch 的训练过程:
flowchart LR
    A("开始") --> B["随机yield一批样本"]
   B --> C("计算该批下的平均损失(向前传播)")
     C --> D("计算参数梯度(向后传播)")
     D --> E("更新参数")
     E --> |"把整个样本都遍历一遍为止"|B

关键代码

  • 随机 yield 一批样本:data.DataLoader
  • 计算该批下的损失(向前传播):nn.MSELoss()
  • 计算参数梯度(向后传播):l.backward() 这个时候参数梯度就算好了
  • 学习:优化器(这里是 SDG) trainer.step() trainer 通过传递参数构造,按照算好的梯度更新

1. 生成数据集

3.2 线性回归从零开始实现#^6ae8e0 一样,根据预设定的数据增加噪声生成训练数据

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

2. 读取数据集

利用 Pytorch 的 data 模块(Pytorch 用法#^4a13ad)读取,并创造一个可以按 batch_size yield 出样本的迭代器:

def load_array(data_arrays, batch_size, is_train=True):  #@save
    """构造一个PyTorch数据迭代器"""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array((features, labels), batch_size)

提示

第三行代码这里是一个解包(Python中的与**用法#^67ec1f),data.TensorDataset 是允许输入多个张量构造数据集(一般一个是训练样本 X,另一个是标签 Y ),这里相当于 data_arrays 的内容当做输入的参数

3. 定义模型

我们首先定义一个模型变量 net,它是一个 Sequential 类的实例。 Sequential 类将多个层串联在一起。这里只要一层:
3.3 线性回归简洁实现-1.png|center|300

 # nn是神经网络的缩写
from torch import nn

net = nn.Sequential(nn.Linear(2, 1))
print(net) 
# 输出:
# Sequential(
#  (0): Linear(in_features=2, out_features=1, bias=True)
#) 

4. 初始化参数

通过net[0]选择网络中的第一个图层, 然后使用weight.databias.data方法访问参数

net[0].weight.data.normal_(0, 0.01) #从正态分布中选参数
net[0].bias.data.fill_(0)

以下划线结尾的方法修改参数 (2.1 数据操作#^cfee62)

5. 定义损失函数

loss = nn.MSELoss()

均方误差(Pytorch 中的损失函数#^b31bac)默认情况下,它返回所有样本损失的平均值(平均的目的是保持一致的学习率),是一个标量

6. 定义优化算法

trainer = torch.optim.SGD(net.parameters(), lr=0.03)

7. 训练

在每个迭代周期里,我们将完整遍历一次数据集(train_data),不停地从中获取一个小批量的输入和相应的标签。对于每一个小批量,我们会进行以下步骤:

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X) ,y)
        trainer.zero_grad() 
        l.backward()
        trainer.step()
    l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {l:f}')

参考文献



© 2023 yanghn. All rights reserved. Powered by Obsidian