Reformer - Paper Reading
Reading notes for Reformer.
Efficiently and Economically
- The author proposes two main techniques to reduce the memory usage of the Transformer, especially when processing extremely long sequences, which also significantly cuts computation and improves speed.
LSH Attention
The original idea is that in Transformer's self-attention, each token as a query needs to calculate attention with all tokens in the sequence, and then weight them to obtain a representation of the current token. However, we know that attention is generally very sparse, with weights concentrated on just a few tokens. So why not calculate weights and apply weighting only on those few tokens, thereby greatly reducing the \(O(N^2)\) computational and memory overhead in self-attention?
How can we know which few tokens these are? If the only way to find out were to compute the attention itself, we could never know which tokens have high weights before computing attention. But in self-attention the weight between a query and a key is simply an inner product: keys similar to the query get higher weights. The model learns attention by learning to produce the right query and key representations, and only compares queries with keys when attention is computed.
So the problem becomes: for each query, find a few similar keys and compute attention only over those. How? Certainly not by computing everything and taking the top k, as that would contradict our initial goal of reducing computational complexity. Here, the author uses Locality-Sensitive Hashing (LSH), which means that similar vectors are more likely to be mapped to the same hash value, so multiple similar vectors end up in the same "bucket". We only need to calculate self-attention within each bucket. More specifically, for two vectors \(q_1, q_2\), an LSH hash function \(h\) satisfies:
\[
\begin{array}{l}
\text{if } \operatorname{dis}(q_1, q_2) \le d_1, \ \text{then } \Pr\left[h(q_1) = h(q_2)\right] \ge p_1 \\
\text{if } \operatorname{dis}(q_1, q_2) \ge d_2, \ \text{then } \Pr\left[h(q_1) = h(q_2)\right] \le p_2
\end{array}
\]
Existing research in related fields provides various hash functions \(h\) for different distance metrics \(dis\). Evidently, the distance that matters here is the angular (cosine) distance, which corresponds to spherical-projection LSH: vectors are randomly projected onto a hypersphere partitioned into \(n_{buckets}\) regions, and vectors landing in the same region fall into the same bucket. The specific projection hash is:
\[ h(x) = \operatorname{argmax}\left(\left[xR; -xR\right]\right) \]
where \(R\) is a random projection matrix of shape \([d_k, b/2]\); concatenating \(xR\) and \(-xR\) yields \(b\) scores, so the argmax assigns one of \(b\) buckets.
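As a concrete illustration, here is a minimal NumPy sketch of this hash; the function name and seed handling are mine, not from the paper:

```python
import numpy as np

def lsh_hash(x, n_buckets, seed=0):
    """Sketch of the angular-LSH hash h(x) = argmax([xR; -xR]).

    x         : [n_tokens, d_k] query/key vectors (shared QK in Reformer)
    n_buckets : number of buckets b (assumed even here)
    Returns a bucket id in [0, n_buckets) for every token.
    """
    rng = np.random.default_rng(seed)
    d_k = x.shape[-1]
    # Random projection matrix R of shape [d_k, b/2]
    R = rng.standard_normal((d_k, n_buckets // 2))
    xR = x @ R                                    # [n_tokens, b/2]
    scores = np.concatenate([xR, -xR], axis=-1)   # [n_tokens, b]
    return np.argmax(scores, axis=-1)             # bucket id per token
```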
The next challenge is that the numbers of queries and keys in a bucket are generally not equal, and a bucket may even contain queries with no keys at all. So the author simply shares QK: queries and keys come from the same linear transformation, with the keys merely normalized: \(k_{j}=\frac{q_{j}}{\left\|q_{j}\right\|}\)
Chunk Operation: Instead of performing self-attention separately inside each bucket, the author chunks the computation: tokens are rearranged by bucket into a sequence, the sequence is cut into equal-length segments, self-attention is computed within each segment, and attention is also computed between adjacent segments. One doubt here: the diagram in the paper looks ideal, with buckets of nearly equal size whose spill-over is absorbed by attending to the neighboring segment, but in practice the bucket sizes are unknown. Perhaps by fixing the segment length the author is imposing a prior on attention learning, nudging bucket sizes toward being roughly equal and matching the segment length.
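A minimal sketch of the sort-and-chunk bookkeeping, assuming the sequence length is divisible by the chunk size (index logic only; the attention math itself is omitted, and the names are mine):

```python
import numpy as np

def sort_and_chunk(bucket_ids, chunk_size):
    """Sketch of the chunking step over bucket-sorted token indices.

    Tokens are stably sorted by bucket id, the sorted order is cut into
    equal-length chunks, and each chunk attends to itself plus the previous
    chunk, which catches buckets that straddle a chunk boundary.
    Returns (query_positions, key_positions) per chunk, in original indices.
    """
    order = np.argsort(bucket_ids, kind="stable")   # rearrange tokens by bucket
    n_chunks = len(order) // chunk_size             # assume divisibility
    chunks = order[: n_chunks * chunk_size].reshape(n_chunks, chunk_size)
    pairs = []
    for i, chunk in enumerate(chunks):
        prev = chunks[i - 1] if i > 0 else np.empty(0, dtype=chunks.dtype)
        pairs.append((chunk, np.concatenate([prev, chunk])))
    return pairs
```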
Multi-round LSH: LSH is probabilistic and therefore makes mistakes. The author designed a clever experiment to check how well LSH attention recovers the full attention, and found a single hash round unsatisfactory. So multiple hash rounds are used, and the union of the rounds' buckets is taken to raise the probability that similar vectors end up attending to each other. The union is taken rather than the intersection presumably because, with many buckets, hashing is sparse: the probability that dissimilar vectors collide in the same bucket is far lower than the probability that similar vectors get separated into different buckets. Some details here remain to be worked out.
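A rough way to see why taking the union over rounds helps (my own back-of-the-envelope argument, not from the paper): if one round puts a similar pair into the same bucket with probability at least \(p\), then over \(n_{rounds}\) independent rounds

\[ \Pr\left[\text{the pair never shares a bucket}\right] \le (1 - p)^{n_{rounds}}, \]

so the miss probability decays exponentially with the number of rounds, while each extra round only adds a bounded number of spurious (dissimilar) pairs to the union.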
Causal Masking: A standard Transformer applies a temporal (causal) mask in the decoder, but LSH bucketing scrambles the sequence order, so the original positions must be tracked to keep the causal mask correct.
Notably, most self-attention implementations let a token attend to itself, but with LSH attention this must be forbidden: since queries and keys are shared, a query's dot product with itself is always the largest and would dominate all other keys.
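Putting the two masking rules together, here is a minimal sketch of the mask applied to a chunk's attention scores after bucket-sorting (function and argument names are mine; the paper's special case where a token has no other valid target is ignored here):

```python
import numpy as np

def lsh_attention_mask(q_pos, k_pos):
    """Sketch of masking in sorted order, using the original positions.

    q_pos : [n_q] original sequence positions of the queries in a chunk
    k_pos : [n_k] original positions of the keys that chunk may attend to
    Returns a boolean mask, True where attention is allowed.
    """
    q = q_pos[:, None]
    k = k_pos[None, :]
    causal = k <= q     # no attending to future positions
    not_self = k != q   # with shared QK, q . q would always dominate
    return causal & not_self
```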
Reversible Transformer
This section's idea references the paper: "The Reversible Residual Network: Backpropagation Without Storing Activations".
The basic idea is to replace the residual structure with a reversible residual structure to save GPU memory. During backpropagation, a network normally has to store each layer's activations so that automatic differentiation can compute per-layer derivatives and chain them together, and storing these activations consumes significant GPU memory. The reversible residual idea splits the channels into two paths that act as residuals for each other, changing the topology of the computational graph so that a layer's input activations can be recomputed from its output activations, as shown below:
Forward propagation process:
\[ \begin{array}{l}{y_{1}=x_{1}+\mathcal{F}\left(x_{2}\right)} \\ {y_{2}=x_{2}+\mathcal{G}\left(y_{1}\right)}\end{array} \]
Backward propagation:
\[ \begin{array}{l}{x_{2}=y_{2}-\mathcal{G}\left(y_{1}\right)} \\ {x_{1}=y_{1}-\mathcal{F}\left(x_{2}\right)}\end{array} \]
Note that computing \(x_2\) uses only the block's output activations \(y_1, y_2\), and computing \(x_1\) then uses the just-recovered \(x_2\), so no input activations need to be stored. Although space is saved, \(\mathcal{F}\) and \(\mathcal{G}\) must be recomputed during the backward pass, essentially trading time for space.
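A minimal, framework-agnostic sketch of this reversible pattern (function names are mine; a real implementation also needs a custom backward pass that calls the inverse during backpropagation instead of reading stored activations):

```python
def reversible_forward(x1, x2, f, g):
    """Forward pass of a reversible residual block: y1 = x1 + F(x2), y2 = x2 + G(y1)."""
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def reversible_inverse(y1, y2, f, g):
    """Recompute the inputs from the outputs, so x1, x2 need not be stored."""
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2
```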
The original paper applied this to ResNet, saving GPU memory to enable larger batch sizes. In transformers, it can be used to train longer sequences.
In Reformer, functions \(\mathcal{F}\) and \(\mathcal{G}\) are respectively changed to self-attention and fully connected layers, corresponding to the transformer's reversible structure.
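With the sketch above, a Reformer-style layer would look roughly like `y1, y2 = reversible_forward(x1, x2, f=self_attention, g=feed_forward)`, with the inverse used during the backward pass; this is only a schematic, not the paper's actual implementation.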
While the reversible structure eliminates layer-count impact on space complexity, the feed-forward network (FFN) in transformers, which consumes the most memory, is still influenced by sequence length. To reduce FFN memory usage, the author again employs chunking, as FFN lacks sequence dependencies and can be computed in segments. Correspondingly, reversible structure inputs and outputs are also computed in segments. For scenarios with large vocabularies, loss log-probabilities are also computed segmentally.
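A minimal NumPy sketch of the chunked FFN idea, assuming some position-wise `ffn` callable (names are mine, not the paper's):

```python
import numpy as np

def chunked_ffn(x, ffn, n_chunks):
    """Apply a position-wise FFN to the sequence in pieces.

    The FFN has no cross-position dependencies, so splitting the sequence
    axis gives the same result while only one piece's large intermediate
    activation is alive in memory at a time.
    x : [seq_len, d_model] array, ffn : position-wise callable, n_chunks : int
    """
    pieces = np.array_split(x, n_chunks, axis=0)
    return np.concatenate([ffn(p) for p in pieces], axis=0)
```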
The author additionally notes that what is saved here are the intermediate variables of backpropagation (activations used for gradient computation), not the model parameters. Parameter memory can be saved by offloading parameters to CPU memory, which is normally uneconomical because data transfer between CPU and GPU is slow; however, since each Reformer layer processes far more data at a time (very long sequences), the transfer cost is better amortized and this becomes more feasible.