© 2023 yanghn. All rights reserved. Powered by Obsidian
5.4 自定义层
要点
- 自定义参数要善于使用
nn.Parameter
类
1. 不带参数的层
import torch
import torch.nn.functional as F
from torch import nn
class CenteredLayer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, X):
return X - X.mean()
2. 带参数的层
下面自定义实现了一个全连接层,支持自定义参数矩阵:
class MyLinear(nn.Module):
def __init__(self, in_units, units):
super().__init__()
self.weight = nn.Parameter(torch.randn(in_units, units))
self.bias = nn.Parameter(torch.randn(units,))
def forward(self, X):
linear = torch.matmul(X, self.weight.data) + self.bias.data
return F.relu(linear)
提示
用 nn.Parameter
来自定义参数,这样可以保留梯度,不要直接 self.weight = torch.randn(in_units, units)