0%

Incremental Decoding

记录一下Fairseq当中对于CNN seq2seq,Transformer之类的并行解码模型,在推理阶段的增量解码处理。

Fairseq架构

  • 在Facebook推出的seq2seq库Fairseq当中,所有模型继承FairseqEncdoerDecoder类,所有的Encoder继承FairseqEncoder类,所有的Decoder继承FairseqIncrementalDecoder类,而FairseqIncrementalDecoder继承自FairseqDecoder类。
  • FairseqEncoder类只定义了forward,reorder_encoder_out,max_positions,upgrade_state_dict,最重要的就是forward,即定义编码的前向传播过程。reorder其实在decoder中更重要,但是这里也定义了。
  • FairseqDecoder类定义了forward,extract_features,output_layer,get_normalized_probs,max_positions,upgrade_state_dict,prepare_for_onnx_export_。forward=extract_features+output_layer,即forward定义了解码出序列的整个前向过程,而extract_features只定义到获得整个decoder的state sequence。
  • Incremental Decoder额外定义了reorder_incremental_state,set_beam_size。reorder是和incremental以及beam search密切相关的,后文将详细介绍。

训练并行,推理增量

  • 像CNN seq2seq, Transformer之类的模型打破了RNN模型的顺序性,使得seq2seq架构中的编码器和解码器在训练是都可以并行训练。
  • 编码器并行训练非常显然,而解码器实际上是一个语言模型,之所以可以并行是因为在训练时采用了teacher forcing,因此语言模型的每一时间步输入在训练时我们假设是已知的,就可以一整个(Batch,Length,Hidden)的decoder input输入模型,直接训练。
  • 但是在测试(推理)阶段,每一时间步的输入由上一时间步的输出决定,无法并行操作,如果反复运行整个decoder,那么就要运行Length次,且第i次只有前i个位置的信息是有用的,剩下部分的计算完全浪费掉了,推理的效率大大降低。
  • 这个时候就需要incremental decoding,即在推理阶段,无论是CNN还是Transformer,都想RNN一样一步一步解码,每一步使用之前推理得到的信息,而不是完全从头开始计算。

CNN

  • 对于CNN,可以发现,decoder无论哪一层,第i个位置都只需要该层[i-k,i)位置上内的信息,其中k为一维卷积的窗长。因此,只需要维护一个长度为k的队列,保存各层计算出来的state,就可以复用模型之前推理得到的信息,之后再把当前的state更新到队列中。
  • 每次计算时只需要对第i个位置进行decoding,即操作(Batch,1,Hidden)的数据Length次。
  • 在代码里,FConvDecoder将输入x和incremental_state一起传给了LinearizedConvolution,这里的介绍是
    1
    2
    3
    4
    5
    6
    """An optimized version of nn.Conv1d.
    At training time, this module uses ConvTBC, which is an optimized version
    of Conv1d. At inference time, it optimizes incremental generation (i.e.,
    one time step at a time) by replacing the convolutions with linear layers.
    Note that the input order changes from training to inference.
    """
  • 即训练时使用Time-First的形式组织数据进行卷积,充分利用GPU的并行性能,在推断时,将卷积层换成相同效果的线性层,逐帧进行推断
    1
    2
    3
    4
    5
    6
    if incremental_state is None:
    output = super().forward(input) # 这里 LinearizedConvolution的父类是ConvTBC,即没有推断时,直接将整个序列送入ConvTBC
    if self.kernel_size[0] > 1 and self.padding[0] > 0:
    # remove future timesteps added by padding
    output = output[:-self.padding[0], :, :]
    return output
  • 否则,就逐层用线性层推断,并更新input buffer进而更新incremental_state
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    # reshape weight
    weight = self._get_linearized_weight()
    kw = self.kernel_size[0]
    bsz = input.size(0) # input: bsz x len x dim
    if kw > 1:
    input = input.data
    input_buffer = self._get_input_buffer(incremental_state)
    if input_buffer is None:
    input_buffer = input.new(bsz, kw, input.size(2)).zero_()
    self._set_input_buffer(incremental_state, input_buffer)
    else:
    # shift buffer
    input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
    # append next input
    input_buffer[:, -1, :] = input[:, -1, :]
    input = input_buffer
    with torch.no_grad():
    output = F.linear(input.view(bsz, -1), weight, self.bias)
    return output.view(bsz, 1, -1)

Transformer

  • 同样的,我们看基于自注意力的模型如何去维护一个incremental state
  • 显然,在推断第i个位置的token时,不像CNN只与窗口大小的history相关,而是与前i-1个位置相关,但是注意,前i-1个位置计算出来的key和value是不变的,是可以复用的,第i位置只生成该位置的key,value以及query,并用query查询自己以及前i-1个位置复用的key,value,因此,incremental state应该包含了key与value的信息,且维护的不是窗口大小,而是整个序列。
  • 在代码里,TransformerDecoder将当前层输入和encoder输出传给TransformerDecoderLayer,更新buffer
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    if prev_self_attn_state is not None:
    prev_key, prev_value = prev_self_attn_state[:2]
    saved_state: Dict[str, Optional[Tensor]] = {
    "prev_key": prev_key,
    "prev_value": prev_value,
    }
    if len(prev_self_attn_state) >= 3:
    saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
    assert incremental_state is not None
    self.self_attn._set_input_buffer(incremental_state, saved_state)
    _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
  • 并在MultiHeadAttention里,假如incremental_state存在,将key和value设为None,后面的计算判断为None时就跳过计算
    1
    2
    3
    4
    5
    6
    7
    8
    if incremental_state is not None:
    saved_state = self._get_input_buffer(incremental_state)
    if saved_state is not None and "prev_key" in saved_state:
    # previous time steps are cached - no need to recompute
    # key and value if they are static
    if static_kv:
    assert self.encoder_decoder_attention and not self.self_attention
    key = value = None
  • 之后读取、计算、更新,代码写的很详细。
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    if saved_state is not None:
    # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
    if "prev_key" in saved_state:
    _prev_key = saved_state["prev_key"]
    assert _prev_key is not None
    prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
    if static_kv:
    k = prev_key
    else:
    assert k is not None
    k = torch.cat([prev_key, k], dim=1)
    if "prev_value" in saved_state:
    _prev_value = saved_state["prev_value"]
    assert _prev_value is not None
    prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
    if static_kv:
    v = prev_value
    else:
    assert v is not None
    v = torch.cat([prev_value, v], dim=1)
    prev_key_padding_mask: Optional[Tensor] = None
    if "prev_key_padding_mask" in saved_state:
    prev_key_padding_mask = saved_state["prev_key_padding_mask"]
    assert k is not None and v is not None
    key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
    key_padding_mask=key_padding_mask,
    prev_key_padding_mask=prev_key_padding_mask,
    batch_size=bsz,
    src_len=k.size(1),
    static_kv=static_kv,
    )

    saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
    saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
    saved_state["prev_key_padding_mask"] = key_padding_mask
    # In this branch incremental_state is never None
    assert incremental_state is not None
    incremental_state = self._set_input_buffer(incremental_state, saved_state)

Generate

  • Fairseq的模型定义了所有前向过程,至于具体选择哪个前向过程则依据训练还是推断来决定。推断使用了fairseq-generate。
  • 要完成一次seq2seq,需要指定task和model,以及其他学习超参数。其中task确定了数据集参数,建立评价指标、词典、data_batch、实例化模型等等。
  • 其中最重要的就是train_step和inference_step,我们看看inference_step
    1
    2
    3
    def inference_step(self, generator, models, sample, prefix_tokens=None):
    with torch.no_grad():
    return generator.generate(models, sample, prefix_tokens=prefix_tokens)
  • 这里的generator是一个sequence_generator对象,其中生成的部分
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    for step in range(max_len + 1):  # one extra step for EOS marker
    # reorder decoder internal states based on the prev choice of beams
    if reorder_state is not None:
    if batch_idxs is not None:
    # update beam indices to take into account removed sentences
    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
    model.reorder_incremental_state(reorder_state)
    encoder_outs = model.reorder_encoder_out(encoder_outs, reorder_state)

    lprobs, avg_attn_scores = model.forward_decoder(
    tokens[:, :step + 1], encoder_outs, temperature=self.temperature,
    )
  • 这里做了一层emsemble model的包装,假如我们只有一个decoder模型,那么实际上forward_decoder执行的是
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    def _decode_one(
    self, tokens, model, encoder_out, incremental_states, log_probs,
    temperature=1.,
    ):
    if self.incremental_states is not None:
    decoder_out = list(model.forward_decoder(
    tokens,
    encoder_out=encoder_out,
    incremental_state=self.incremental_states[model],
    ))
    else:
    decoder_out = list(model.forward_decoder(tokens, encoder_out=encoder_out))
  • 这里可以看到是用incremental decoding逐步解码出句子

Reorder

  • 更多的详细信息可以参考这篇博文,写的非常好,甚至被官方钦点加入了代码注释里understanding-incremental-decoding-in-fairseq
  • 还有一点,就是decoder中的reorder,在这篇博文里也有提到。
  • 在推断时和训练不同的另一点就是beam search。因此我们不止维护一个缓存队列,而是beam_size个队列。
  • 那么在挑选第i个词的时候,第k个beam缓存队列的存的输入token可能是第i-1个位置时第j个beam缓存队列beam search出来的,因此需要重新排序保证一致。