Incremental Decoding

Notes on how Fairseq handles incremental decoding at inference time for parallel-decoding models such as CNN seq2seq and Transformer.
Fairseq Architecture
- In Facebook's seq2seq library Fairseq, encoder-decoder models inherit from the FairseqEncoderDecoderModel class, encoders inherit from the FairseqEncoder class, decoders that support incremental decoding inherit from the FairseqIncrementalDecoder class, and FairseqIncrementalDecoder in turn inherits from the FairseqDecoder class.
- The FairseqEncoder class only defines forward, reorder_encoder_out, max_positions, and upgrade_state_dict, with forward being the most important method, defining the forward propagation process of encoding. Reorder is actually more important in the decoder, but it is defined here as well.
- The FairseqDecoder class defines forward, extract_features, output_layer, get_normalized_probs, max_positions, upgrade_state_dict, and prepare_for_onnx_export_. forward = extract_features + output_layer, which means forward defines the entire forward process of decoding a sequence, while extract_features only defines obtaining the decoder's state sequence.
- The Incremental Decoder additionally defines reorder_incremental_state and set_beam_size. Reorder is closely related to incremental decoding and beam search, which will be detailed later.
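The relationship between these methods can be sketched with toy classes. This is only an illustration: ToyDecoder and ToyIncrementalDecoder are hypothetical names, the arithmetic is a placeholder, and the method signatures are assumptions rather than Fairseq's actual ones.

```python
class ToyDecoder:
    """Hypothetical stand-in for a FairseqDecoder-style class."""

    def extract_features(self, prev_output_tokens):
        # Pretend each token id maps to a one-number "hidden state".
        return [float(tok) * 0.5 for tok in prev_output_tokens]

    def output_layer(self, features):
        # Pretend projection from hidden states to output scores.
        return [feat + 1.0 for feat in features]

    def forward(self, prev_output_tokens):
        # forward = extract_features followed by output_layer
        return self.output_layer(self.extract_features(prev_output_tokens))


class ToyIncrementalDecoder(ToyDecoder):
    """Adds the two hooks an incremental decoder needs (assumed signatures)."""

    def __init__(self):
        self.beam_size = 1

    def reorder_incremental_state(self, incremental_state, new_order):
        # Reindex every cached entry along the (batch * beam) dimension.
        for key, rows in incremental_state.items():
            incremental_state[key] = [rows[i] for i in new_order]

    def set_beam_size(self, beam_size):
        self.beam_size = beam_size


decoder = ToyIncrementalDecoder()
scores = decoder.forward([2, 4])
```

Here forward simply chains the two smaller methods, which is the point of the split: callers that only need the decoder states can stop after extract_features.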
Training Parallelization and Inference Incrementation
- Models like CNN seq2seq and Transformer break the sequentiality of RNN models, enabling the encoder and decoder in the seq2seq architecture to be trained in parallel during training.
- Parallel training of the encoder is quite obvious, and the decoder is essentially a language model that can be parallelized during training because of teacher forcing, where the input at each time step is assumed to be known. Thus, the entire decoder input of (Batch, Length, Hidden) can be directly input into the model for training.
- However, during testing (inference), the input at each time step is determined by the output of the previous time step, which cannot be parallelized. If the entire decoder is run repeatedly, it would run Length times, and only the information of the first i positions is useful in the i-th run, with the remaining calculations completely wasted, significantly reducing inference efficiency.
- At this point, incremental decoding becomes necessary. During the inference phase, whether it's CNN or Transformer, decoding is done step by step like an RNN, using information previously inferred at each step, rather than starting from scratch.
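To get a rough feel for the waste, the sketch below counts how many positions pass through the decoder in each strategy. It assumes the naive approach re-runs the decoder on the current prefix at every step and ignores the extra cost of attention inside each run, so the numbers are only a lower bound on the gap. Both function names are hypothetical.

```python
def positions_processed_naive(length):
    """Re-run the decoder on the whole prefix at every step."""
    return sum(step for step in range(1, length + 1))


def positions_processed_incremental(length):
    """Process only the newest position at every step."""
    return length


for length in (10, 100):
    naive = positions_processed_naive(length)
    incremental = positions_processed_incremental(length)
    print(f"length={length}: naive={naive}, incremental={incremental}")
```

The naive count grows quadratically with the output length, while the incremental count grows linearly.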
CNN
- For CNN, it can be observed that at each layer of the decoder, the i-th position only needs information from the [i-k, i) positions, where k is the window size of the one-dimensional convolution. Therefore, by maintaining a queue of length k to save the states calculated by each layer, the model can reuse information previously inferred. 
- Each calculation only needs to decode the i-th position, i.e., operate on (Batch, 1, Hidden) data Length times. 
- In the code, FConvDecoder passes input x and incremental_state to LinearizedConvolution, which is described as:
    """An optimized version of nn.Conv1d.
    At training time, this module uses ConvTBC, which is an optimized version
    of Conv1d. At inference time, it optimizes incremental generation (i.e.,
    one time step at a time) by replacing the convolutions with linear layers.
    Note that the input order changes from training to inference.
    """
- During training, data is organized in a Time-First format for convolution to fully utilize GPU parallel performance. During inference, convolution layers are replaced with equivalent linear layers for frame-by-frame inference:
    if incremental_state is None:
        # LinearizedConvolution's parent class is ConvTBC, so when there is no
        # incremental state the entire sequence is sent to ConvTBC
        output = super().forward(input)
        if self.kernel_size[0] > 1 and self.padding[0] > 0:
            # remove future timesteps added by padding
            output = output[:-self.padding[0], :, :]
        return output
- Otherwise, inference runs one time step at a time through the linear layer, and the input buffer stored in incremental_state is updated:
    # reshape weight
    weight = self._get_linearized_weight()
    kw = self.kernel_size[0]
    bsz = input.size(0)  # input: bsz x len x dim
    if kw > 1:
        input = input.data
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is None:
            input_buffer = input.new(bsz, kw, input.size(2)).zero_()
            self._set_input_buffer(incremental_state, input_buffer)
        else:
            # shift buffer
            input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
        # append next input
        input_buffer[:, -1, :] = input[:, -1, :]
        input = input_buffer
    with torch.no_grad():
        output = F.linear(input.view(bsz, -1), weight, self.bias)
    return output.view(bsz, 1, -1)
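The buffer idea can be checked without PyTorch. The following is a minimal sketch on scalar sequences, not Fairseq's implementation: causal_conv_full and IncrementalConv are hypothetical names, and a deque stands in for the shifted input buffer. It shows that a causal convolution over the whole sequence and a step-by-step dot product over a length-k buffer give the same outputs.

```python
from collections import deque


def causal_conv_full(xs, weights, bias=0.0):
    """Causal 1-D convolution over the whole sequence (training-style)."""
    k = len(weights)
    padded = [0.0] * (k - 1) + list(xs)  # left padding keeps it causal
    return [
        sum(w * x for w, x in zip(weights, padded[i:i + k])) + bias
        for i in range(len(xs))
    ]


class IncrementalConv:
    """Same convolution, one time step at a time, using a length-k buffer."""

    def __init__(self, weights, bias=0.0):
        self.weights = list(weights)
        self.bias = bias
        # The zero-initialised buffer plays the role of the left padding.
        self.buffer = deque([0.0] * len(weights), maxlen=len(weights))

    def step(self, x):
        # Appending to a full deque drops the oldest entry (the "shift").
        self.buffer.append(x)
        # Over a flattened window the convolution is a plain dot product,
        # which is what a linear layer computes.
        return sum(w * b for w, b in zip(self.weights, self.buffer)) + self.bias


xs = [1.0, 2.0, 3.0, 4.0, 5.0]
weights = [0.5, -1.0, 2.0]

full = causal_conv_full(xs, weights, bias=0.1)
conv = IncrementalConv(weights, bias=0.1)
stepwise = [conv.step(x) for x in xs]
```

Each call to step touches only k values, regardless of how long the sequence has become.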
Transformer
- Similarly, let's look at how self-attention-based models maintain an incremental state 
- Clearly, when inferring the token at the i-th position, it's not just related to the history of a window size like CNN, but to the first i-1 positions. However, note that the key and value computed for the first i-1 positions remain unchanged and can be reused. The i-th position only generates its own key, value, and query, and uses the query to query itself and the reusable key and value of the first i-1 positions. Therefore, the incremental state should include key and value information, maintaining not a window size, but the entire sequence. 
- In the code, TransformerDecoder passes the current layer input and encoder output to TransformerDecoderLayer, updating the buffer:
    if prev_self_attn_state is not None:
        prev_key, prev_value = prev_self_attn_state[:2]
        saved_state: Dict[str, Optional[Tensor]] = {
            "prev_key": prev_key,
            "prev_value": prev_value,
        }
        if len(prev_self_attn_state) >= 3:
            saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
        assert incremental_state is not None
        self.self_attn._set_input_buffer(incremental_state, saved_state)
    _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
- In MultiheadAttention, if incremental_state already holds cached keys and values and they are static (static_kv, which is the encoder-decoder attention case), key and value are set to None so that the subsequent projections are skipped:
    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if saved_state is not None and "prev_key" in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
- Then read, calculate, and update, with detailed code:
    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        prev_key_padding_mask: Optional[Tensor] = None
        if "prev_key_padding_mask" in saved_state:
            prev_key_padding_mask = saved_state["prev_key_padding_mask"]
        assert k is not None and v is not None
        key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
            key_padding_mask=key_padding_mask,
            prev_key_padding_mask=prev_key_padding_mask,
            batch_size=bsz,
            src_len=k.size(1),
            static_kv=static_kv,
        )
        saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_key_padding_mask"] = key_padding_mask
        # In this branch incremental_state is never None
        assert incremental_state is not None
        incremental_state = self._set_input_buffer(incremental_state, saved_state)
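The key/value caching described above can be illustrated with a small single-head example in plain Python. This is a sketch under simplifying assumptions, not Fairseq's MultiheadAttention: there are no learned projections (query, key, and value are all the raw input vector), there is one head and no batch dimension, and the function names are hypothetical. Only the "prev_key" and "prev_value" dictionary keys mirror the code above.

```python
import math


def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    top = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [
        sum(e / total * value[d] for e, value in zip(exps, values))
        for d in range(dim)
    ]


def decode_full(inputs):
    """Recompute attention over the whole prefix at every position."""
    return [
        attend(inputs[i], inputs[: i + 1], inputs[: i + 1])
        for i in range(len(inputs))
    ]


def decode_step(x, saved_state):
    """Process only the newest position, reusing cached keys and values."""
    saved_state.setdefault("prev_key", []).append(x)
    saved_state.setdefault("prev_value", []).append(x)
    return attend(x, saved_state["prev_key"], saved_state["prev_value"])


inputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

saved_state = {}
incremental = [decode_step(x, saved_state) for x in inputs]
full = decode_full(inputs)
```

Unlike the convolution buffer, the cache here grows by one entry per step, because every new position attends to the entire history.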
Generate
- Fairseq's models define all forward processes, and which forward process is used depends on whether it's training or inference. Inference uses fairseq-generate. 
- To run a seq2seq job, you need to specify the task and the model, along with the other training hyperparameters. The task determines the dataset parameters and sets up the evaluation metrics, vocabulary, data batches, and model instantiation.
- The most important parts of a task are train_step and inference_step. Let's look at inference_step:
    def inference_step(self, generator, models, sample, prefix_tokens=None):
        with torch.no_grad():
            return generator.generate(models, sample, prefix_tokens=prefix_tokens)
- Here, the generator is a sequence_generator object, with the generation part:
    for step in range(max_len + 1):  # one extra step for EOS marker
        # reorder decoder internal states based on the prev choice of beams
        if reorder_state is not None:
            if batch_idxs is not None:
                # update beam indices to take into account removed sentences
                corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
                reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
            model.reorder_incremental_state(reorder_state)
            encoder_outs = model.reorder_encoder_out(encoder_outs, reorder_state)
        lprobs, avg_attn_scores = model.forward_decoder(
            tokens[:, :step + 1], encoder_outs, temperature=self.temperature,
        )
- This wraps an ensemble model. If we only have one decoder model, then forward_decoder actually executes:
    def _decode_one(
        self, tokens, model, encoder_out, incremental_states, log_probs,
        temperature=1.,
    ):
        if self.incremental_states is not None:
            decoder_out = list(model.forward_decoder(
                tokens,
                encoder_out=encoder_out,
                incremental_state=self.incremental_states[model],
            ))
        else:
            decoder_out = list(model.forward_decoder(tokens, encoder_out=encoder_out))
- Here you can see that incremental decoding is used to decode the sentence step by step 
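The way the incremental state is threaded through the generation loop can be sketched with a toy greedy decoder. Everything here is hypothetical: toy_forward_decoder, the "running_sum" cache key, and the arithmetic rule that picks the next token are placeholders, and beam search is left out. The only point is that the loop passes the full prefix either way, while the incremental branch reads just the newest token and takes the rest from the cache.

```python
def toy_forward_decoder(tokens, incremental_state=None):
    """Hypothetical decoder: picks the next token from the sum of the prefix."""
    if incremental_state is None:
        total = sum(tokens)  # full pass over the prefix
    else:
        # Incremental pass: only the newest token is read; the rest is cached.
        total = incremental_state.get("running_sum", 0) + tokens[-1]
        incremental_state["running_sum"] = total
    return (total * 3 + 1) % 7  # arbitrary rule standing in for argmax


def generate(bos, max_len, eos=0, incremental=True):
    tokens = [bos]
    incremental_state = {} if incremental else None
    for step in range(max_len):
        # The prefix is passed in both modes, as in the loop above.
        next_token = toy_forward_decoder(tokens[: step + 1], incremental_state)
        tokens.append(next_token)
        if next_token == eos:
            break
    return tokens


with_cache = generate(bos=1, max_len=6, incremental=True)
without_cache = generate(bos=1, max_len=6, incremental=False)
```

Both modes produce the same sequence; the difference is only in how much of the prefix each step has to read.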
Reorder
- For more detail, refer to the blog post understanding-incremental-decoding-in-fairseq, which is very well written and is even referenced in the Fairseq code comments.
- There's another point about reorder in the decoder, also mentioned in this blog post.
- During inference, unlike training, beam search is used. So we maintain not just one cache queue, but beam_size number of queues.
- When selecting the i-th word, the input token stored in the k-th beam's cache queue might have come from the j-th beam's cache queue during beam search at the i-1 position. Therefore, reordering is needed to ensure consistency.
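A small example makes the reordering concrete. This is a sketch on plain Python lists, assuming one cache row per beam; the function below is a hypothetical stand-in, and "prev_tokens" is an illustrative cache key rather than one Fairseq uses.

```python
def reorder_incremental_state(incremental_state, new_order):
    """Reindex every cached buffer so row k holds the history of its parent beam."""
    for name, rows in incremental_state.items():
        # Copy each selected row so beams that share a parent do not alias.
        incremental_state[name] = [list(rows[parent]) for parent in new_order]


# One cache row per beam (beam_size = 3); each row is the history so far.
incremental_state = {"prev_tokens": [["a"], ["b"], ["c"]]}

# Suppose beam search kept two continuations of beam 0 and one of beam 2.
new_order = [0, 0, 2]
reorder_incremental_state(incremental_state, new_order)

# Append the newly chosen token for each surviving beam.
for row, token in zip(incremental_state["prev_tokens"], ["x", "y", "z"]):
    row.append(token)
```

After the reorder, beams 0 and 1 both continue from the old beam 0, and the history of the old beam 1 is discarded. Without this step, the new beam 1 would extend a cached history that does not match its tokens.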