© 2023 yanghn. All rights reserved. Powered by Obsidian
9.6 编码器-解码器架构
要点
1. 编码器与解码器
不管是 CNN,还是 RNN,可以统一设计为两个组件:
- 第一个组件是一个_编码器_(encoder): 它接受一个长度可变的序列作为输入,并将其转换为具有固定形状的编码状态。
- 第二个组件是_解码器_(decoder): 它将固定形状的编码状态映射到长度可变的序列。
编码器与解码器的架构,其中解码器还可以以当前变量作为输入
2. 代码实现
编码器接口中,我们只指定长度可变的序列作为编码器的输入 X
。任何继承这个 Encoder
基类的模型将完成代码实现。
from torch import nn
#@save
class Encoder(nn.Module):
"""编码器-解码器架构的基本编码器接口"""
def __init__(self, **kwargs):
super(Encoder, self).__init__(**kwargs)
def forward(self, X, *args):
raise NotImplementedError
在解码的过程中需要实现状态的初始化
#@save
class Decoder(nn.Module):
"""编码器-解码器架构的基本解码器接口"""
def __init__(self, **kwargs):
super(Decoder, self).__init__(**kwargs)
def init_state(self, enc_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError
合并编码与解码器
#@save
class EncoderDecoder(nn.Module):
"""编码器-解码器架构的基类"""
def __init__(self, encoder, decoder, **kwargs):
super(EncoderDecoder, self).__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_outputs, *args)
return self.decoder(dec_X, dec_state)