Sub-Linear Memory: How to Make Performers SLiM

Valerii Likhosherstov 1, Krzysztof Choromanski 2 3, Jared Davis 4 5, Xingyou Song 2, Adrian Weller 1 6

1 University of Cambridge, 2 Google Brain, 3 Columbia University, 4 DeepMind, 5 Stanford University, 6 Alan Turing Institute. Correspondence to: Valerii Likhosherstov.

arXiv:2012.11346v1 [cs.LG] 21 Dec 2020

Abstract

The Transformer architecture has revolutionized deep learning on sequential data, becoming ubiquitous in state-of-the-art solutions for a wide variety of applications. Yet vanilla Transformers are notoriously resource-expensive, requiring O(L^2) serial time and memory as functions of the input length L. Recent works proposed various linear self-attention mechanisms, scaling only as O(L) for serial computation. We perform a thorough analysis of recent Transformer mechanisms with linear self-attention, Performers, in terms of overall computational complexity. We observe a remarkable computational flexibility: forward and backward propagation can be performed with no approximations using sublinear memory as a function of L (in addition to negligible storage for the input sequence), at a cost of greater time complexity in the parallel setting. In the extreme case, a Performer consumes only O(1) memory during training, and still requires O(L) time. This discovered time-memory tradeoff can be used for training or, due to complete backward compatibility, for fine-tuning on a low-memory device, e.g. a smartphone or an earlier-generation GPU, thus contributing towards decentralized and democratized deep learning.

1. Introduction

The Transformer architecture (Vaswani et al., 2017) has changed the landscape of deep learning for sequential data. In contrast to more conventional methods such as recurrent neural networks (Hochreiter & Schmidhuber, 1997; Cho et al., 2014), the self-attention module, responsible for temporal information propagation, is fully parallelizable, meaning that training speed can be increased by simply using more compute resources. However, this parallel-friendly structure of self-attention comes at the cost of quadratic O(L^2) time and memory complexity, where L is the length of the Transformer's input sequence. A recent line of work aimed to address this restriction, using structured sparsity (Child et al., 2019), truncated back-propagation (Dai et al., 2019), clustering (Kitaev et al., 2020; Roy et al., 2020) or linear attention methods (Katharopoulos et al., 2020; Choromanski et al., 2020; Shen et al., 2018; Li et al., 2020). For a detailed overview of efficient Transformers, see (Tay et al., 2020b). We refer to the family of linear attention architectures as Performers, following Choromanski et al. (2020), since their generic kernel formulation covers all the aforementioned linear attention methods. Performers reduce time and memory complexity to linear O(L) and can provably approximate conventional quadratic Transformers (Choromanski et al., 2020), demonstrating strong performance in a systematic comparison of efficient Transformers (Tay et al., 2020a).

This recent trend of feeding longer sequences into Transformers, coupled with the use of deeper models, introduces new challenges for researchers and practitioners. Whereas conventional Transformer setups benefit from large-batch optimization (You et al., 2019), long-sequence modelling necessitates smaller batch sizes in order to fit the model into memory. For instance, Kitaev et al.
(2020) used a batch size of 1 per TPU chip to fit 64K-long sequences into their Reformer model. Katharopoulos et al. (2020) took a batch size of 4 to fit flattened CIFAR-10 images (length 3K) into their Performer analog trained on an NVidia P40 GPU with 24GB memory. Choromanski et al. (2020) could use a batch of at most 8 protein sequences (length 8K, TrEMBL dataset) per TPU chip to train Performer. Aiming to use larger batch sizes, practitioners introduced various tricks. One of them, included in the popular Transformer library Fairseq (Ott et al., 2019) and called gradient accumulation (Ott et al., 2018), splits the batch into smaller chunks that are evaluated sequentially and then the resulting batch gradient is accumulated. As the sequence length increases, even a batch size of 1 is too big for memory rendering training impossible. This problem is especially pronounced for low-memory devices, such as earlier-generation GPUs or smartphones. Heuristics, such as chunking the input into subsegments or truncated back-propagation (Dai et al., 2019), limit gradient propa- Sub-Linear Memory: How to Make Performers SLiM Figure 1. (a) MultiHead-Att block at the rth layer and its decomposition into T(r−1) , Γ(r−1) , U(r−1) . (b) Illustration of the Algorithm 1 when r = n = 2. I-II) forward passes for n = 1, 2 respectively, only the loss value and B(n) are stored. III) backward pass start, forward computation through the slice n = 2 to build symbolic Φ(2) and update B(2) → B(1) . IV) back-propagation through Φ(2) to find ∇θ(2) L and G (1) . V,VI) the same backward iteration for n = 1. gation across the whole input, and, consequently, impair long-context pattern learning. We propose a solution based on the analysis of Performers. We discover a remarkable property: even for batch size of 1, a user can decrease memory consumption at the cost of smaller parallel bandwidth of the model. Notably, no approximations are introduced, so the obtained gradient is correct and backward-compatible. Our proposed long-sequence training algorithm can be used for training or fine-tuning on a low-memory device, thus contributing towards decentralized and democratized deep learning. The algorithm has the following advantages: 1. The parameter C, 1 ≤ C ≤ L, controls a tradeoff between the memory, scaling as O(C) in addition to a negligible input sequence storage, and parallel running time (O((L/C) log C)). When C = 1, the algorithm consumes as much memory as if a single token were fed into Performer, plus a small addition. 2. The algorithm does not introduce many additional computations: for any C, it requires as many floating point operations (FLOPs) as two full-memory forward and one backward passes plus a small addition. 3. We outline conditions when the algorithm can be extended beyond Performers. By doing so, we hope to facilitate exploration of new memory-cheap architectures to benefit deep learning more generally. We evaluate the proposed time-memory tradeoff empirically, and confirm backward-compatibility for language modelling on a copying task, Penn Treebank (Marcus et al., 1993) and Enwik8 (Mahoney, 2009) datasets.1 2. Background 2.1. Exponential and Linear Self-Attention We commence by defining exponential self-attention (Vaswani et al., 2017), a key component of the Transformer. Consider a sequence scale l ∈ {1, . . . , L} and three matrices: queries Q ∈ RL×d , keys K ∈ RL×d and values V ∈ RL×d . 
Then exponential self-attention is defined as a functional producing Y = Attexp (Q, K, V) ∈ RL×d , Pl ⊤ l′ =1 exp(Ql Kl′ )Vl′ ∀l ∈ {1, . . . , L} : Yl = P , (1) l ⊤ ′ l′ =1 exp(Ql Kl ) where by Zl ∈ Rd2 ×... we denote slice Zl,:,...,: of a tensor Z ∈ Rd1 ×d2 ×... . Mapping (1) is designed as a differentiable dictionary, where output at index l is a weighted average over value vectors V:l . For needs of autoregressive generative modelling, when each element depends only on previous elements of the sequence (Vaswani et al., 2017), Yl only depends on inputs at indices {1, . . . , l}. Self-attention 1 Code: https://github.com/google-research/ google-research/tree/master/performer/ models/slim_performer. Sub-Linear Memory: How to Make Performers SLiM of type (1) is a key contributor to state-of-the-art results in many applications. However, its running time and memory scale as O(L2 ). This prevents applicability of exponential self-attention to sequences of big length L ≫ d. Hence, linear self-attention methods were proposed (Katharopoulos et al., 2020; Choromanski et al., 2020; Shen et al., 2018; Li et al., 2020), where the exponent is substituted by a Euclidean inner-product. This is defined as a functional Y = Attlin (Q, K, V) ∈ RL×d , where Pl ⊤ ′ =1 Vl′ · (g(Kl′ ) g(Ql )) ∀l ∈ {1, . . . , L} : Yl = l P l ′ ⊤ l′ =1 g(Kl ) g(Ql ) Pl ( l′ =1 Vl′ × g(Kl′ )⊤ ) × g(Ql ) , (2) = Pl ( l′ =1 g(Kl′ ))⊤ g(Ql ) where “×” denotes a matrix-matrix or matrix-vector product and g : Rd → RM + is a mapping into a vector with positive elements. The positivity of the result is to guarantee that the division in (2) is well-defined and stable. In practice, M is chosen to be much smaller than L. g(·) can be chosen as a simple elementwise mapping (so that d = M ). Choromanski et al. (2020) propose a randomized form of g(·), which is an unbiased approximation to exponential self-attention (1). The second transition in (2), which is due to associativity of matrix multiplication, suggests an algorithm to compute linear self-attention efficiently in subqudratic time. For a series of tensors Z(1) , . . . , Z(n) of the same shape, by Z = (Z(i) )ni=1 we understand a tensor such that for all 1 ≤ i ≤ n Zi,:,...,: = Z(i) . By R ∈ RL×d×M , S ∈ RL×M denote a tensor and a matrix such that L R = PS((Vl × g(Kl )⊤ )L l=1 ), S = PS((g(Kl ))l=1 ), (3) Pi where PS(Z) = ( i′ =1 Zi′ )ni=1 is an operator taking a prefix sum (or a cumulative sum) along the first dimension of the input tensor Z. Next, compute ∀1 ≤ l ≤ L : Yl = (Rl × g(Ql ))/(S⊤ l g(Ql )). (4) Depending on the prefix-sum algorithm used in (3), we can obtain different complexity estimates for linear selfattention. Katharopoulos et al. (2020) propose to iterate through l = 1, . . . , L maintaining only current Rl , Sl , e l . This way, tensors and compute and store the result Y L×d×M R, PS(R) ∈ R are not stored in memory, resulting in O(L) time complexity and O(L(d + M ) + dM ) memory complexity. Katharopoulos et al. (2020) also propose a similar iterative scheme for computing gradients through (3-4); see Appendix B for a detailed discussion. Alternatively, Choromanski et al. (2020) employ a parallel prefix-sum algorithm (Ladner & Fischer, 1980; Vishkin, 2010), which, for a tensor Z ∈ RL×... , finds PS(Z) in O(log L) parallel time and O(L) memory. Applying this algorithm for computing PS(R), PS(S) and then computing (4) results in only O(log L) parallel time complexity and O(LdM ) memory consumption. 2.2. 
Transformer and Performer Architectures In this subsection we outline a Transformer architecture which is used for autoregressive language modelling (Parmar et al., 2018). We focus on language modelling: first, to simplify notation, while our subsequent derivations are applicable in broader setups; second, language models are a crucial class of architectures because they were shown to act as few-shot learners, e.g. the seminal GPT-2 (Radford et al., 2019) and GPT-3 (Brown et al., 2020). Let p ∈ ΣL be an input sequence of length L, where Σ is a finite alphabet. By emb(pl , l) ∈ Rdmodel , 1 ≤ l ≤ L, denote a linear combination of the pl token’s learned embedding and positional embedding of l’s position (sinusoids with different frequencies, as in (Vaswani et al., 2017)). Then Transformer is defined as a parametrized L×dmodel into mapping from X(0) = (emb(pl , l))L l=1 ∈ R (out) L×|Σ| X ∈ R through a sequence of hidden representations X(1) , . . . , X(s) ∈ RL×dmodel . More formally, X(out) = X(s) W(out) + b(out) and for each 1 ≤ r ≤ s: H(r−1) = LN(MultiHead-Att(X(r−1) )) + X(r−1) , (5) X(r) = LN(FFN(H(r−1) )) + H(r−1) , where (6) MultiHead-Att(X) = [H(1) . . . H(k) ], (7) (j) (j) (j) = Att(XWQ , XWK , XWV ), (8) FFN(H) = GeLU(HW(1) + b(1) )W(2) + b(2) . (9) (j) ∀j ≤ k : H Here Att is either Attexp or Attlin and k is the number of attention heads (dmodel = kd). W(out) ∈ Rdmodel ×|Σ| , b(out) ∈ R1×|Σ| , W(1) ∈ Rdmodel ×df f , b(1) ∈ R1×df f , W(2) ∈ Rdf f ×dmodel , b(2) ∈ R1×dmodel , (j) (j) (j) WQ , WK , WV ∈ Rdmodel ×d are trainable parameters (separate for each instance of MultiHead-Att, FFN), “+” is broadcasted rowwise when biases are added and LN is layer normalization (Ba et al., 2016), which is applied rowwise and depends on additional trainable parameters. GeLU denotes Gaussian error Linear Unit (Hendrycks & Gimpel, 2016), which is applied elementwise. We refer to the Transformer (5-9) with linear self-attention Attlin as Performer. (out) For each 1 ≤ l ≤ L − 1, Xl denotes predicted logits of the probability distribution over the next token pl+1 . Let (out) Ll (Xl ) denote a cross-entropy loss with respect to pl+1 , or zero when l = L. The minimized loss is defined as (out) L = (L − 1)−1 · (L1 (X1 (out) ) + · · · + LL (XL )). (10) The Transformer configuration (5-9) can be slightly changed in the literature: different LN(·) placement, GeLU replaced with ReLU, etc. The discussed variant (5-9) corresponds to Sub-Linear Memory: How to Make Performers SLiM GPT-2. We consider this configuration for simplicity and use it in experiments. However, as we further show, our findings can be easily extended to other modifications. 3. Low-Memory Training Algorithm 3.1. Compact Notation for Performer In this section we consider Performer: the Transformer defined by (5-9) with Att = Attlin . In light of the definition (5-9) and the algorithm for linear self-attention evaluation (3-4), the sequence of computations X(0) → X(1) → · · · → X(s) can be rewritten in the following compact form, which is more convenient for our subsequent analysis. For each 1 ≤ r ≤ s, T(r−1) , Γ(r−1) = F (r) (X(r−1) ; θ), U (r−1) (r−1) = PS(T Hence, the only place where the information is propagated across the sequence dimension is the prefix-sum operation (12). The representation (11-13) encapsulates architecture details of the Transformer inside {F (1) , G(1) , . . . , F (s) , G(s) }. In fact, the representation (11-13) holds for various possible modifications of the specification (5-9), proposed in the literature. 
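To make the compact form concrete, the following schematic sketch (our illustration, not the paper's code) shows the skeleton of (11-13): two rowwise maps F and G with a single cumulative sum in between, so that the prefix sum is the only operation mixing information across positions. The toy maps F and G, the dimensions and the causality check below are placeholders rather than the actual attention and feed-forward blocks.

```python
# Schematic skeleton of (11)-(13): rowwise F, prefix sum, rowwise G.
# F, G and all shapes below are toy stand-ins, not the paper's attention/FFN blocks.
import torch

def layer(X, F, G):
    T, Gamma = F(X)               # (11): applied independently to every row of X
    U = torch.cumsum(T, dim=0)    # (12): the only operation that mixes positions
    return G(U, Gamma)            # (13): again applied row by row

F = lambda X: (torch.tanh(X), X)  # toy rowwise map with the right signature
G = lambda U, Gamma: U + Gamma    # toy rowwise map

X = torch.randn(6, 4)
Y = layer(X, F, G)

# Autoregressive property: Y[l] depends only on X[:l+1], because the prefix sum is causal.
X_perturbed = X.clone()
X_perturbed[3:] = 0.0
print(torch.allclose(layer(X_perturbed, F, G)[:3], Y[:3]))   # True
```

The same skeleton survives many architectural modifications proposed in the literature.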
This includes, but is not limited by the different positioning of layer normalization (Xiong et al., 2020; Vaswani et al., 2017), adding a stabilizing gating mechanism (Parisotto et al., 2019), weight sharing across layers (Lan et al., 2020) or reversible Transformer layers (Kitaev et al., 2020). Therefore, we further analyse the generic, compact notation (11-13) together with the autoregressive loss formulation (10). (11) ), X(r) = G(r) (U(r−1) , Γ(r−1) ; θ). (12) 3.2. Forward Computation (13) Suppose the memory budget is not enough to perform a complete forward pass through Performer (Equations 11-13 for r = 1, . . . , s), because the input sequence length L is too big. We show that instead we can emulate the full forward computation under the memory needed for a forward pass through the input of length C ≤ L, plus a small addition. 1 ≤ C ≤ L is arbitrary and user-defined. Here θ ∈ Rnparam is a set of all trainable parameters, T(r−1) , U(r−1) ∈ RL×D1 and Γ(r−1) ∈ RL×D2 are the following matrices (see Figure 1a for an illustration): • T(r−1) is a matrix of intermediate representations which are passed into the prefix-sum operator. That (r−1) is, for each 1 ≤ l ≤ L, Tl is a concatenation of g(Kl ) and flattened Vl × g(Kl )⊤ for all attention heads computed at the rth step (Equations 8 and 3). Consequently, D1 = M (d + 1)k. Split each matrix X(r) , T(r) , Γ(r) , U(r) , into N slices of size at most C along the vertical axis (N = ⌈L/C⌉): for each ∀1 ≤ n ≤ N , (r) Bn ×dmodel n X(r,n) = (XAn +l )B , l=1 ∈ R (r) (r−1) • For each 1 ≤ l ≤ L, Ul is a concatenation of all corresponding Sl and flattened Rl – results of the prefix-sum operation (Equation 3) inside each selfattention head (Equation 8). • Γ(r−1) is a matrix of representations which skip the (r−1) prefix-sum operation. For each 1 ≤ l ≤ L, Γl is (r−1) (j) (j) a concatenation of Xl and g(Ql ) = g(XWQ ) – query vectors for each attention head 1 ≤ j ≤ k (Equations 8 and 4). Therefore, D2 = M k + dmodel . (r) Bn ×D1 (r,n) n n T(r,n) = (TAn +l )B = (UAn +l )B , l=1 , U l=1 ∈ R (r) Bn ×D2 n Γ(r,n) = (ΓAn +l )B , l=1 ∈ R where An = (n − 1)C and by Bn , 1 ≤ n ≤ N , we denote the size of nth slice: Bu = C for u < N , BN ≤ C. Based on (11-13), we conclude that for each 1 ≤ n ≤ N and 1 ≤ r ≤ s the following recurrence holds: T(r−1,n) , Γ(r−1,n) = F (r) (X(r,n) ; θ), U (r−1,n) (r−1,n−1) ⊤ = 1Bn×(UBn−1 ) + PS(T(r−1,n) ), X(r,n) = G(r) (U(r−1,n) , Γ(r−1,n) ; θ). (r) (r) F and G are functionals parametrized by θ. That is, they take subsets of θ corresponding to rth layer weights (Equations 5-9). F (r) is responsible for constructing T(r−1) and Γ(r−1) – representations preceding prefix-sum computation, while G(r) finalizes MultiHead-Att computation (7) and includes the feed-forward block (9). Importantly, F (r) and G(r) are applied rowwise, i.e. (11, 13) can be rewritten as (r−1) ∀1 ≤ l ≤ L : Tl (r−1) , Γl (r−1) = F (r) (Xl ; θ), (14) (r) (r−1) (r−1) ∀1 ≤ l ≤ L : Xl = G(r) (Ul , Γl ; θ). (15) (16) (17) (18) Here 1Bn ∈ RBn is a vector of Bn ones and we denote (r−1,0) = 0D1 (a vector of D1 zeros). U B0 Now, instead of iterating over r = 1, . . . s and computing (11-13) for the whole sequence at once, we first iterate over n = 1, . . . , N and then iterate over r = 1, . . . , s in a nested loop to compute (16-18). As can be deduced from the (16-18), we only need to maintain the current value of (r−1,n−1) s (UBn−1 )r=1 ∈ Rs×D1 in the outer iteration over n. (r−1,n) )sr=1 ∈ Rs×D1 , 0 ≤ n ≤ N . 
Denote B (n) = (UBn The memory-efficient algorithm for the forward pass is Sub-Linear Memory: How to Make Performers SLiM Algorithm 1 Low-memory emulation of the forwardbackward pass. See Algorithm 2 for updateProc. Compared to notation from the text, redundant indices are dropped and tensor names are reused here and in the Algorithm 2. Input: p ∈ ΣL , θ ∈ Rnparam , C ∈ N . Output: loss L, gradient ∇θ L. Initialize L := 0, B := 0r×D1 ; for n = 1 to N do updateProc(n, False); end for Initialize ∇θ L := 0nparam , G := 0r×D1 ; for n = N to 1 do updateProc(n, True); end for Return L, ∇θ L . as follows. First, initialize L = 0 and B (0) = 0r×D1 . Then, iterate over n = 1, . . . , N and maintain the current value of B (n−1) . During each iteration, compute X(0,n) = n (emb(pAn +l , An + l))B l=1 . Then iterate over r = 1, . . . , s, (n) (r−1,n) . Fiwhere compute (16-18) and update Br = UBn (out,n) (s,n) (out) (out) nally, compute X =X W +b and update L += L(n) (X(out,n) ), where we denote L(n) (X(out,n) ) = (L − 1)−1 Bn X (out,n) LAn +l (Xl Algorithm 2 updateProc procedure. Input: n ∈ N, binary flag onBackprop . if onBackprop then Initialize Φ := 0; end if n X := (emb(pAn +l , An + l))B l=1 ; for r = 1 to s do Compute T, Γ := F (r) (X; θ); P Bn if onBackprop then Update Br −= l=1 Tl ; end if Set U := 1Bn Br⊤ + PS(T), X := G(r) (U, Γ; θ); if onBackprop then Update Φ+= Gr⊤ UBn ; else Update Br := UBn ; end if end for Set L(upd) := L(n) (XW(out) + b(out) ); if onBackprop then Update Φ+= L(upd) ; Compute ∇θ Φ, ∇B Φ through auto-differentiation; Update ∇θ L+= ∇θ Φ, G := ∇B Φ; else Set L+= L(upd) ; end if ). l=1 By the end of the iteration over n, the correct loss value (10) is computed. As a result, the forward pass takes O(L) serial time or O((L/C) log C) parallel time and consumes only O(C) memory. This is in addition to the input sequence p ∈ ΣL storage, which is O(L) in principle, however the constant is negligibly small. For instance, if p is a flattened image or an ASCII text string, then it occupies precisely L bytes in memory. The log C term in the parallel time complexity is due to the parallel prefix-sum algorithm taking logarithmic time, as discussed in Subsection 2.1. 3.3. Back-Propagation and the Final Algorithm The goal of a backward pass is to compute gradient ∇θ L of the loss function with respect to parameters θ. One can just perform automatic differentiation (Griewank & Walther, 2008) (implemented in Tensorflow (Abadi et al., 2015) and Pytorch (Paszke et al., 2017)) through the computation graph induced by the memory-efficient forward pass algorithm from Subsection 3.2. However, such backward pass would need to store all intermediate tensors produced during the forward pass, resulting in O(L) memory complexity as a function of L and C. Instead, we propose a back-propagation algorithm which has the same time and memory complexity as the efficient forward pass. Let θ(1) = · · · = θ(N ) = θ be results of a symbolic “identity operation” performed on θ, so that for all 1 ≤ n ≤ N , θ(n) is used instead of θ in (16-18). Then the total gra- dient of θ has the form ∇θ L = ∇θ(1) L + · · · + ∇θ(N ) L. In Appendix A we derive an expression for ∇θ(n) L, 1 ≤ n ≤ N . Namely, denote G (n) = ∇B(n) L, then ∇θ(n) L = ∇θ(n) Φ(n) (θ(n) , B (n−1) , G (n) ), where Φ(n) : Rnparam × Rs×D1 × Rs×D1 → R, Φ(n) (θ(n) , B (n−1) , Z) = L(n) (X(out,n) ) + s X (n) Z⊤ r Br . r=1 In Φ(n) ’s definition, X(out,n) = X(s,n) W(out) + b(out) (r−1,n) s )r=1 are results of (16-18) iteraand B (n) = (UBn tion over r = 1, . . . 
, s with parameters θ = θ(n) and (r−1,n−1) s )r=1 equal to Φ(n) ’s second argument B (n−1) . (UBn−1 Gradient ∇θ(n) Φ(n) can be computed by automatic differentiation through the computation graph induced by Φ(n) . An efficient way to compute and sum up all ∇θ(n) L is to iterate in a backward direction n = N, . . . , 1 and to maintain current values of B (n) , G (n) . B (N ) is known after the end of the forward pass, and for each 1 ≤ n ≤ N , B (n−1) = B (n) − Bn X (r−1,n) s )r=1 . (Tl (19) l=1 Further, in Appendix A we show that G (N ) = 0r×D1 and, for each 1 ≤ n ≤ N , G (n−1) = ∇B(n−1) Φ(n) (θ(n) , B (n−1) , G (n) ). (20) By a single auto-differentation through Φ(n) we can compute ∇θ(n) L = ∇θ(n) Φ(n) and the update (20). Sub-Linear Memory: How to Make Performers SLiM Observe that, if w is some vector of length Bn and h is some scalar function of v = PS(w), then for all 1 ≤ l ≤ Bn P Bn ′ : ∇h w l = l′ =t ∇h vl . In other words, the gradient through PS(·) is another prefix sum computed backwards. Hence, auto-differentiation through Φ(n) takes the same parallel time O(log C), serial time O(L) and memory O(C), as the forward computation of Φ(n) . Since during the whole back-propagation algorithm, we only store and update tensors B (n) , G (n) , whose size doesn’t depend on L and C, this results in total O((L/C) log C) parallel time, O(L) serial time and O(C) memory in addition to p storage. A full description of the forward-backward pass is presented in Algorithm 1. Figure 1b is an illustration of the algorithm. 3.4. Analysis of the Running Time and Memory As we have shown, Performer can be trained in parallel time O((L/C) log C) and O(C) memory in addition to the input sequence p storage. Hence, C is a tradeoff parameter: when C is maximal (C = L), the model is fully-parallelized along the sequence dimension, therefore resulting in the fastest execution. Whereas minimal C = 1 corresponds to step-by-step processing, i.e. a fully-sequential regime which doesn’t benefit from parallelized computations on GPU or TPU, but consumes O(1) memory as a function of L. It can be seen that during the forward pass, Algorithm 1 requires as many total FLOPs as the naive forward pass through (16-18). As for the backward pass, for each 1 ≤ n ≤ N , the forward pass through n’s slice is repeated for symbolical construction of Φ(n) (see Algorithm 2), and then back-propagation is run through Φ(n) . In addition, a backward update of B (n) (19) is computed, taking precisely Bn sM (d + 1)k “add” operations. Hence, we conclude that Algorithm 1 requires as many FLOPs as two forward and one backward pass through (16-18) for the whole sequence p plus LsM (d + 1)k = LsM dmodel + LsM k FLOPs. To characterize this addition, assuming that typically df f = 4dmodel in practice, observe that applying linear operators in (5-9) alone requires 3Lsd2model + 2Lsdmodel df f = 11Lsd2model FLOPs. This is much bigger than LsM dmodel + LsM k, since M is much smaller than dmodel in practice (Choromanski et al., 2020; Katharopoulos et al., 2020). were applied to a sequence of length 1 plus exactly 2sdmodel (M + 1) floats for storing B, G. For comparison, the subset of θ corresponding to matrix parameters in selfattention and feed-forward blocks (5-9), occupies 3sd2model + 2sdmodel df f = 11sd2model floats. Again, this is much bigger than 2sdmodel (M + 1), since M is much smaller than dmodel in practice. 
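As a concrete illustration of the time-memory tradeoff, the sketch below (ours, not the released SLiM Performer code) emulates the sliced forward pass (16-18) for a single linear-attention head: the sequence is processed in slices of at most C tokens, only the last prefix-sum row is carried between slices, and the output matches the full computation up to round-off. The feature map g(x) = x^2 and M = d mirror the experimental setup; the function names are ours, and the loss, the multiple heads and layers, and the backward half of Algorithm 1 are omitted.

```python
# A minimal sketch of the sliced forward pass (16)-(18) for one linear-attention head.
# Only the last prefix-sum row B (size independent of L and C) is carried across slices.
import torch

def g(x):
    # elementwise-positive feature map; the experiments use g(x) = x**2
    return x ** 2

def head_slice(X_n, B, Wq, Wk, Wv):
    """Process one slice given the carried prefix-sum state B; return outputs and the new state."""
    Q, K, V = X_n @ Wq, X_n @ Wk, X_n @ Wv
    d, M = Wv.shape[1], Wk.shape[1]
    kv = torch.einsum('ld,lm->ldm', V, g(K)).flatten(1)   # flattened outer products V_l g(K_l)^T
    T = torch.cat([g(K), kv], dim=-1)                     # (16): rows fed into the prefix sum
    U = B + torch.cumsum(T, dim=0)                        # (17): local prefix sum shifted by the carried state
    S, R = U[:, :M], U[:, M:].reshape(-1, d, M)           # running sums of g(K_l) and V_l g(K_l)^T
    num = torch.einsum('ldm,lm->ld', R, g(Q))             # (18): finalize attention rowwise, as in Eq. (4)
    den = torch.einsum('lm,lm->l', S, g(Q)).unsqueeze(-1)
    return num / den, U[-1]

def sliced_forward(X, Wq, Wk, Wv, C):
    d, M = Wv.shape[1], Wk.shape[1]
    B = torch.zeros(M + d * M)        # carried state: O(1) memory as a function of L
    out = []
    for start in range(0, X.shape[0], C):
        Y_n, B = head_slice(X[start:start + C], B, Wq, Wk, Wv)
        out.append(Y_n)
    return torch.cat(out)

torch.manual_seed(0)
L, d = 16, 4
X = torch.randn(L, d)
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))
Y_full = sliced_forward(X, Wq, Wk, Wv, C=L)        # one slice: fully parallel along the sequence
Y_slim = sliced_forward(X, Wq, Wk, Wv, C=4)        # four slices: O(C) activation memory
print(torch.allclose(Y_full, Y_slim, atol=1e-5))   # True: identical up to round-off
```

With C = L the loop degenerates into a single fully-parallel slice; with C = 1 it becomes the step-by-step recurrence whose memory footprint was just analysed.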
To understand these fruitful properties, we perform a conceptual comparison of Performer, recurrent neural networks (RNNs, Hochreiter & Schmidhuber (1997); Cho et al. (2014)) and residual architectures (e.g. Neural ODEs, Chen et al. (2018)), which are also used for sequence processing. The rth layer of all models has the following form for 1 ≤ l ≤ L: RNN : Residual : Performer : (r) Xl (r) (r−1) = f (r) (Xl−1 , Xl ), (21) (r) (r) (r) (r−1) Xl = Xl−1 + f (r) (Xl−1 , Xl ), (r) (r) (r−1) Xl = Xl−1 + f (r) (Xl ). (22) (23) Here f (r) is some nonlinear map. Observe that Performer is (r) (r) the only architecture where Xl depends linearly on Xl−1 . It’s not hard to see that Algorithm 1 can be applied to any architecture of type (23). Despite the update’s simplicity, Performer appears to work very well in challenging real-life setups, and, as shown by Choromanski et al. (2020), can approximate any conventional Transformer with exponential self-attention. See Table 1 for a complexity comparison of all discussed architectures and the proposed algorithm. Table 1. Complexity for the exact forward-backward pass as functions of sequence length L and the tradeoff parameter C ≤ L (for Performer). The indicated memory complexity is in addition to the input sequence p storage. The serial time complexity for Performer is reported for the version with iterative PS(·) computation (as in Katharopoulos et al., 2020), while the parallel time is reported for the parallel prefix sum (as in Choromanski et al., 2020). For both methods, memory complexity is the same, though the constant is smaller for the iterative version. M ODEL RNN R ESIDUAL NN Attexp T RANSF. P ERFORMER O UR ALGORITHM O UR ALG ., C = 1 S ERIAL PARALLEL TIME TIME O(L) O(L) O(L2 ) O(L) O(L) O(L) O(L) O(L) O(log L) O(log L) L O( C log C) O(L) M EMORY O(L) O(L) O(L2 ) O(L) O(C) O(1) Since the back-propagation takes roughly 5 times more FLOPs than the forward pass (Griewank & Walther, 2008), we conclude that memory efficiency of Algorithm 1 results in a small constant-time increase in FLOPs. The FLOPs count has a direct effect on energy consumption (Wu* et al., 2020), a crucial factor for on-device applications. 4. Experiments Further analysis of Algorithm 1 reveals that the C = 1 regime requires as much memory as if Transformer Our main contribution is a new low-memory gradient computation algorithm for the existing Performer architecture. Sub-Linear Memory: How to Make Performers SLiM Time, sec. 103 102 101 102 10 1 10 0 4.0 Config. II, iter. L/B for Config. II, iter. Config. II, PS L/B for Config. II, PS Config. IV, iter. L/B for Config. IV, iter. Config. IV, PS L/B for Config. IV, PS ∝C Config. II, iter., full Config. II, PS, full Config. IV, iter., full 1e−6 Config. II Config. III Config. IV 3.5 3.0 Relative discrepancy 104 GPU memory, GB Config. II, iter. Config. II, PS Config. III, iter. Config. III, PS Config. IV, iter. Config. IV, PS ∝ C −1 Config. II, iter., full Config. II, PS, full Config. III, iter., full Config. III, PS, full Config. IV, iter., full 2.5 2.0 1.5 100 1.0 10 10 −1 0 2 4 6 log2(C) 8 10 12 −1 0.5 0 2 4 6 log2(C) 8 10 12 0 5 log2(C) 10 Figure 2. Benchmarks of Algorithm 1. All plots are averaged over 10 seeds. “iter.” stands for iterative computation of (3-4), while “PS” is for explicit prefix sum computation in (3). We don’t report time and memory for big values of C in “Config. IV, PS” setup and for “Config. IV, full” setup, because these runs resulted in memory overflow. (Left) Time dependence on C. 
Crosses indicate horizontal time levels for corresponding full memory-inefficient methods. The dotted line indicates ∝ C −1 tangent in logarithmic scale. (Middle) Memory dependence on C. Again, crosses are for horizontal levels of full-sequence methods and the dotted line indicates ∝ C tangent. We do not report curves for config. III, because they completely match curves for config. IV, which is natural, since dmodel is the same for both configurations. “L/B” stands for a memory lower bound computed by processing input of length C. (Right) Relative gradient discrepancy as a function of C, also reporting standard errors. Performers have very competitive performance among other methods for long sequence modelling (Choromanski et al., 2020; Katharopoulos et al., 2020; Tay et al., 2020a). Hence, in the experimental section, we aim to answer the following questions about using this algorithm in practice: 1. Does the theoretical time-memory tradeoff, controlled by C, agree with empirical benchmarks of time and memory for C variation? 2. In precise arithmetic, different values of C lead to the same correct gradient ∇θ L. Does this hold in practice, when finite-precision arithmetic is employed? 3. Can a model, pre-trained with a bigger value of C (e.g. on a server), be fine-tuned with a smaller C (e.g. on a smartphone)? Does the parameter C affect the performance of training from scratch? We address each question in detail in the subsections below. In our experiments, we analyse 4 model configurations (L, dmodel ): I = (512, 256), II = (1024, 512), III = (4096, 1024), IV = (16384, 1024). In all configurations, we set df f = 4dmodel , k = dmodel /64 (number of heads), s = 3 (number of layers). We set M = d and employ g(x) = (x2i )di=1 elementwise-quadratic feature mapping in (2), which we find to work well in practice. In all experiments Σ = {0, . . . , 255} and batch size is set to 1, i.e. we analyse a setup where gradient accumulation cannot be used to decrease memory, and therefore our algorithm is crucial. Our code is in PyTorch 1.7. To ensure that reproduction of experiments is accessible for a wider audience, we use a single NVIDIA Tesla P100 GPU with 16 GB memory for each experiment. 4.1. Empirical Benchmarking of the Tradeoff We run Algorithm 1 for configurations II-IV and different powers of 2 as C. We use input strings sampled randomly from ΣL . In order to characterize the timememory tradeoff, we measure wall-clock time and peak GPU memory for a single gradient evaluation. We use the torch.cuda.max memory allocated function to report peak GPU memory. As discussed in Section 2.1, there are two methods to compute (3-4): the first (iterative) method doesn’t compute and store tensors (4) explicitly, resulting in smaller memory consumption at a cost of less parallelization, while the second one computes tensors (4) using the parallel prefix sum algorithm, therefore operating faster, but using more memory. The same methods can be applied for the memoryefficient algorithm when computing (17-18) updates. We implement and benchmark both methods as part of the algo- Sub-Linear Memory: How to Make Performers SLiM Bits per character, val. Accuracy, % 3.00 60 2.75 F/T start Full C = 64 C = 64 F/T C = 128 C = 128 F/T 40 20 0 2500 5000 7500 10000 12500 15000 iteration # Enwik8 F/T start Full C = 256 C = 256 F/T C = 512 C = 512 F/T 3.25 80 0 Penn Treebank 3.50 2.50 F/T start Full C = 1366 C = 1366 F/T C = 2048 C = 2048 F/T 4.0 Bits per character, val. 
Copying task 100 3.5 3.0 2.5 2.25 2.00 2.0 1.75 0 5000 10000 15000 20000 25000 30000 iteration # 0 20000 40000 60000 iteration # 80000 100000 Figure 3. Learning curves for three language modelling setups. We report accuracy on a newly generated data samples for Copying task, and bits-per-character metric on validation examples for Penn Treebank and Enwik8. F/T stands for “fine-tuning”. All curves are almost indistinguishable, confirming correctness and backward-compatibility of gradients computed via memory-efficient Algorithm 1. rithm. For the explicit prefix-sum method, we find that the torch.cumsum function works faster and consumes less memory than our custom implementation of the parallel prefix sum algorithm. We attribute this to hardware-optimized low-level implementation of the native function, and use this function in experiments. As for the iterative algorithm, we implement its “block” version, when, instead of iterating l one-by-one, we iterate through blocks of small size (see details in Appendix B). This way, the algorithm has a smaller constant in O(L) time complexity and bigger constant in a “small” O(dM ) term of the memory complexity (assuming that d, M ≪ L). For a fixed value of C, in addition to benchmarking memory of Algorithm 1, we also report memory of the naive gradient computation run on a string of length C, sampled uniformly from ΣC . This is to confirm that memory consumption of Algorithm 1 is just slightly above the full computation on the input of length C. Results are reported in Figure 2 (left, middle). We observe significant improvements in memory consumption compared to the full computation, as C decreases. As C converges to 20 = 1, the remaining memory consumption can be attributed to storage of the model’s parameters θ. Time follows two regimes: declining fast as C grows (meaning that prefix sums are parallelized) and declining slower for big values of C (meaning that the practical limit of parallelization is reached). Memory scales slower than O(C), as C increases. We attribute this effect to details of PyTorch internal implementation. Interestingly, we find that iterative version of (3-4) computation works only slightly slower than prefix-sum version, while consuming much less memory. Finally, Algorithm 1 consumes slightly more memory in practice than the full method run on the input of length C. 4.2. Effects of Finite-Precision Arithmetic Since the iterative version of (3-4) computation results in a good balance between time and memory of Algorithm 1, we use it in our subsequent experiments. To quantify finite-precision effects, we plot relative discrepancy (C) (f ull) (f ull) k∇θ L − ∇θ Lk2 /k∇θ Lk2 between the gradient (C) (f ull) ∇θ produced by Algorithm 1, and the gradient ∇θ L produced by full-input computation. Figure 2 illustrates results for randomly initialized models. We observe a very small discrepancy (of order 10−6 –10−5 ), confirming the correctness of Algorithm 1. The discrepancy is slightly increasing as C decreases, which can be attributed to effects of finite-precision arithmetic. 4.3. Training from Scratch and Fine-tuning To confirm backward compatibility of Algorithm 1 during training, we consider three language modelling setups: Copying task, symbol-level Penn Treebank and Enwik8. For the Copying task, we follow the setup from (Kitaev et al., 2020; Katharopoulos et al., 2020), sampling inputs as 0ω0ω, where ω is drawn uniformly from (Σ \ {0})L/2−1 . 
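A minimal sketch of such a sampler (an illustration of the 0ω0ω format under Σ = {0, . . . , 255} and L = 512; the function name and batching are our own, not the paper's data pipeline):

```python
# Sample copying-task inputs of the form 0 w 0 w, with w uniform over (Sigma \ {0})^(L/2 - 1).
# Illustrative only: the function name and batching are ours, not the paper's data pipeline.
import torch

def sample_copying_batch(batch_size, L=512, vocab_size=256):
    half = L // 2
    w = torch.randint(1, vocab_size, (batch_size, half - 1))     # symbols from Sigma \ {0}
    zeros = torch.zeros(batch_size, 1, dtype=torch.long)
    return torch.cat([zeros, w, zeros, w], dim=1)                # shape (batch_size, L)

x = sample_copying_batch(batch_size=1)   # batch size 1, matching the experiments
```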
In this setup, we only aggregate cross-entropy loss from the second half of the input, so the task is to reproduce the first half. We include the Copying task as an example setup where long-range information propagation is crucial, and the heuristic of “chunking” the input into smaller segments would fail to solve the task. We use model configurations I, II, III for the Copying task, Penn Treebank and Enwik8 respectively, resulting in sequence lengths L = 512, 1024, 4096 respectively. For each setup, we compare training with full gradient computation, and training equipped with memory-efficient gradient computation via Algorithm 1 using various values of C. In addition, we consider a “fine-tuning” regime, when the first Sub-Linear Memory: How to Make Performers SLiM half of iterations is run using the full algorithm, and the second half is run using Algorithm 1. Figure 3 demonstrates results: all methods result in the same, indistinguishable performance. This confirms that memory-efficient gradient computation can be used both for training from scratch, and for fine-tuning, e.g. on a low-memory device. Table 2 quantifies the memory savings and time tradeoff in all setups. Additional experimental details and results (bigger version of Figure 3, bits-per-character for the Copying task and train set performance for Penn treebank and Enwik8) can be found in Appendix C. 5. Related Work and Extensions Compatibility with other memory-optimization techniques. Observe that the specification (11-13) is compatible with the reversible layer design from (Kitaev et al., 2020), when the sparse self-attention is replaced with the linear self-attention2 . This can bring more memory savings, since one doesn’t need to store the whole symbolic Φ(n) during the backward pass. Checkpointing techniques (Griewank, 1992; Chen et al., 2016) can also be used to reduce the memory consumption for storing Φ(n) ’s graph, though at the cost of a longer execution time. The gradient accumulation technique (Ott et al., 2018) is also compatible with Algorithm 1, i.e. one can combine both methods to “collapse” batch and sequence dimensions simultaneously. Moreover, our algorithm is compatible with distillation (Sanh et al., 2020), since it can be run on a distilled model. Table 2. Time per iteration (averaged over 1000 iterations) and peak GPU memory. CT – Copying task, PTB – Penn Treebank. S ETUP, L, C T IME PER ITER . ( SEC .) GPU ME MORY (GB) CT 512, FULL CT 512, 128 CT 512, 64 0.0474 0.0921 0.1228 0.0449 0.0425 0.0374 PTB 1024, FULL PTB 1024, 512 PTB 1024, 256 0.1377 0.2526 0.3060 0.300 0.257 0.231 E NWIK 8 4096, FULL E NWIK 8 4096, 2048 E NWIK 8 4096, 1366 0.4598 0.7922 0.8654 1.513 1.085 0.909 6. Conclusion We proposed an algorithm for memory-efficient backpropagation through a Performer. The algorithm reduces memory consumption along the sequence dimension, and can, therefore, be used for long-sequence training. The algorithm: (1) is completely backward-compatible, since it computes precise gradients and does not involve approximation, (2) does not require many additional computations, and (3) enables user control over the tradeoff between time and memory consumption. 7. Acknowledgments Comparison with (Katharopoulos et al., 2020). Katharopoulos et al. (2020) mention that a single self-attention block can be evaluated in O(1) additional memory. However, one still needs to store L intermediate states, e.g. in the feedforward block. Hence, the full memory complexity is still O(L). 
In contrast, our method optimizes memory consumption along the sequence dimension for the whole multilayer model. We thank Tom Weingarten and Tamas Sarlos for many fruitful discussions. Extension to Transformers with dropout. Dropout (Srivastava et al., 2014) is a popular regularization technique. It is used with Transformers when the train dataset is small enough to cause overfitting (e.g. it wasn’t used with GPT-2, trained on a massive dataset). Our algorithm can be extended to stochastic computation graphs with dropout. For that, use separate random seeds to generate dropout masks for each slice 1 ≤ n ≤ N , and reuse these seeds two times during the forward and backward pass through the nth slice. References 2 See e.g. CausalFavor class in https://github.com/ google/trax/blob/master/trax/layers/research/sparsity.py, which is compatible with the official Reformer code. Valerii Likhosherstov acknowledges support from the Cambridge Trust and DeepMind. Adrian Weller acknowledges support from The Alan Turing Institute under EPSRC grant EP/N510129/1 and U/B/000074, and the Leverhulme Trust via CFI. Abadi, M., Agarwal, A., Barham, P., Brevdo, E., Chen, Z., Citro, C., Corrado, G. S., Davis, A., Dean, J., Devin, M., Ghemawat, S., Goodfellow, I., Harp, A., Irving, G., Isard, M., Jia, Y., Jozefowicz, R., Kaiser, L., Kudlur, M., Levenberg, J., Mané, D., Monga, R., Moore, S., Murray, D., Olah, C., Schuster, M., Shlens, J., Steiner, B., Sutskever, I., Talwar, K., Tucker, P., Vanhoucke, V., Vasudevan, V., Viégas, F., Vinyals, O., Warden, P., Wattenberg, M., Wicke, M., Yu, Y., and Zheng, X. TensorFlow: Largescale machine learning on heterogeneous systems, 2015. URL https://www.tensorflow.org/. Software available from tensorflow.org. Ba, J. L., Kiros, J. R., and Hinton, G. E. Layer normalization. arXiv preprint arXiv:1607.06450, 2016. Sub-Linear Memory: How to Make Performers SLiM Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., Agarwal, S., Herbert-Voss, A., Krueger, G., Henighan, T., Child, R., Ramesh, A., Ziegler, D. M., Wu, J., Winter, C., Hesse, C., Chen, M., Sigler, E., Litwin, M., Gray, S., Chess, B., Clark, J., Berner, C., McCandlish, S., Radford, A., Sutskever, I., and Amodei, D. Language models are few-shot learners, 2020. doi: 10.1080/10556789208805505. URL https:// doi.org/10.1080/10556789208805505. Griewank, A. and Walther, A. Evaluating Derivatives: Principles and Techniques of Algorithmic Differentiation, Second Edition. Other Titles in Applied Mathematics. Society for Industrial and Applied Mathematics (SIAM, 3600 Market Street, Floor 6, Philadelphia, PA 19104), 2008. ISBN 9780898717761. URL https://books. google.co.uk/books?id=xoiiLaRxcbEC. Chen, R. T. Q., Rubanova, Y., Bettencourt, J., and Duvenaud, D. K. Neural ordinary differential equaHendrycks, D. and Gimpel, K. Bridging nonlinearities and tions. In Bengio, S., Wallach, H., Larochelle, H., stochastic regularizers with gaussian error linear units. Grauman, K., Cesa-Bianchi, N., and Garnett, R. CoRR, abs/1606.08415, 2016. URL http://arxiv. (eds.), Advances in Neural Information Processing org/abs/1606.08415. Systems, volume 31, pp. 6571–6583. Curran Associates, Inc., 2018. URL https://proceedings. Hochreiter, S. and Schmidhuber, J. Long short-term neurips.cc/paper/2018/file/ memory. Neural Comput., 9(8):1735–1780, Novem69386f6bb1dfed68692a24c8686939b9-Paper. ber 1997. ISSN 0899-7667. doi: 10.1162/neco.1997. pdf. 9.8.1735. 
URL http://dx.doi.org/10.1162/ neco.1997.9.8.1735. Chen, T., Xu, B., Zhang, C., and Guestrin, C. Training deep nets with sublinear memory cost. CoRR, Katharopoulos, A., Vyas, A., Pappas, N., and Fleuret, F. abs/1604.06174, 2016. URL http://arxiv.org/ Transformers are RNNs: Fast autoregressive transformers abs/1604.06174. with linear attention. arXiv preprint arXiv:2006.16236, 2020. Child, R., Gray, S., Radford, A., and Sutskever, I. GenerKingma, D. P. and Ba, J. Adam: A method for stochasating long sequences with sparse transformers. CoRR, tic optimization. In Bengio, Y. and LeCun, Y. (eds.), abs/1904.10509, 2019. URL http://arxiv.org/ 3rd International Conference on Learning Represenabs/1904.10509. tations, ICLR 2015, San Diego, CA, USA, May 7-9, Cho, K., van Merriënboer, B., Gulcehre, C., Bahdanau, 2015, Conference Track Proceedings, 2015. URL http: D., Bougares, F., Schwenk, H., and Bengio, Y. Learn//arxiv.org/abs/1412.6980. ing phrase representations using RNN encoder–decoder Kitaev, N., Kaiser, L., and Levskaya, A. Reformer: for statistical machine translation. In Proceedings of The efficient transformer. In 8th International Conferthe 2014 Conference on Empirical Methods in Natural ence on Learning Representations, ICLR 2020, Addis Language Processing (EMNLP), pp. 1724–1734, Doha, Ababa, Ethiopia, April 26-30, 2020. OpenReview.net, Qatar, October 2014. Association for Computational Lin2020. URL https://openreview.net/forum? guistics. doi: 10.3115/v1/D14-1179. URL https: id=rkgNKkHtvB. //www.aclweb.org/anthology/D14-1179. Choromanski, K., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J., Mohiuddin, A., Kaiser, L., Belanger, D., Colwell, L., and Weller, A. Rethinking attention with Performers. CoRR, arXiv:2009.14794, 2020. URL https: //arxiv.org/abs/2009.14794. Dai, Z., Yang, Z., Yang, Y., Cohen, W. W., Carbonell, J., Le, Q. V., and Salakhutdinov, R. TransformerXL: Language modeling with longer-term dependency, 2019. URL https://openreview.net/forum? id=HJePno0cYm. Griewank, A. Achieving logarithmic growth of temporal and spatial complexity in reverse automatic differentiation. Optimization Methods and Software, 1(1):35–54, 1992. Ladner, R. E. and Fischer, M. J. Parallel prefix computation. J. ACM, 27(4):831–838, October 1980. ISSN 0004-5411. doi: 10.1145/322217.322232. URL https://doi. org/10.1145/322217.322232. Lan, Z., Chen, M., Goodman, S., Gimpel, K., Sharma, P., and Soricut, R. ALBERT: A lite BERT for selfsupervised learning of language representations. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum? id=H1eA7AEtvS. Li, R., Duan, C., and Zheng, S. Linear attention mechanism: An efficient attention for semantic segmentation. arXiv preprint arXiv:2007.14902, 2020. Mahoney, M. Large text compression benchmark, 2009. Sub-Linear Memory: How to Make Performers SLiM Marcus, M. P., Santorini, B., and Marcinkiewicz, M. A. Building a large annotated corpus of English: The Penn Treebank. Computational Linguistics, 19(2): 313–330, 1993. URL https://www.aclweb.org/ anthology/J93-2004. Ott, M., Edunov, S., Grangier, D., and Auli, M. Scaling neural machine translation. In Proceedings of the Third Conference on Machine Translation: Research Papers, pp. 1–9, Brussels, Belgium, October 2018. Association for Computational Linguistics. doi: 10.18653/ v1/W18-6301. URL https://www.aclweb.org/ anthology/W18-6301. Ott, M., Edunov, S., Baevski, A., Fan, A., Gross, S., Ng, N., Grangier, D., and Auli, M. 
fairseq: A fast, extensible toolkit for sequence modeling, 2019. Parisotto, E., Song, H. F., Rae, J. W., Pascanu, R., Gulcehre, C., Jayakumar, S. M., Jaderberg, M., Kaufman, R. L., Clark, A., Noury, S., et al. Stabilizing transformers for reinforcement learning. arXiv preprint arXiv:1910.06764, 2019. Parmar, N., Vaswani, A., Uszkoreit, J., Kaiser, L., Shazeer, N., and Ku, A. Image transformer. CoRR, abs/1802.05751, 2018. URL http://arxiv.org/ abs/1802.05751. Paszke, A., Gross, S., Chintala, S., Chanan, G., Yang, E., DeVito, Z., Lin, Z., Desmaison, A., Antiga, L., and Lerer, A. Automatic differentiation in pytorch. 2017. Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., and Sutskever, I. Language models are unsupervised multitask learners. OpenAI Blog, 1(8):9, 2019. Roy, A., Saffar, M., Vaswani, A., and Grangier, D. Efficient content-based sparse attention with routing transformers. arXiv, 2003.05997, 2020. Sanh, V., Debut, L., Chaumond, J., and Wolf, T. DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. arXiv, 1910.01108, 2020. Shen, Z., Zhang, M., Yi, S., Yan, J., and Zhao, H. Factorized attention: Self-attention with linear complexities. CoRR, abs/1812.01243, 2018. URL http://arxiv.org/ abs/1812.01243. Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., and Salakhutdinov, R. Dropout: A simple way to prevent neural networks from overfitting. Journal of Machine Learning Research, 15(56):1929–1958, 2014. URL http://jmlr.org/papers/v15/ srivastava14a.html. Tay, Y., Dehghani, M., Abnar, S., Shen, Y., Bahri, D., Pham, P., Rao, J., Yang, L., Ruder, S., and Metzler, D. Long range arena: A benchmark for efficient transformers. arXiv, 2011.04006, 2020a. Tay, Y., Dehghani, M., Bahri, D., and Metzler, D. Efficient transformers: A survey. arXiv, 9.20006732, 2020b. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L. u., and Polosukhin, I. Attention is all you need. In Guyon, I., Luxburg, U. V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 30, pp. 5998–6008. Curran Associates, Inc., 2017. URL http://papers.nips.cc/paper/ 7181-attention-is-all-you-need.pdf. Vishkin, U. Thinking in parallel: Some basic data-parallel algorithms and techniques. 2010. Wu*, Z., Liu*, Z., Lin, J., Lin, Y., and Han, S. Lite transformer with long-short range attention. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum? id=ByeMPlHKPH. Xiong, R., Yang, Y., He, D., Zheng, K., Zheng, S., Xing, C., Zhang, H., Lan, Y., Wang, L., and Liu, T.-Y. On layer normalization in the transformer architecture, 2020. You, Y., Li, J., Hseu, J., Song, X., Demmel, J., and Hsieh, C. Reducing BERT pre-training time from 3 days to 76 minutes. CoRR, abs/1904.00962, 2019. URL http: //arxiv.org/abs/1904.00962. Sub-Linear Memory: How to Make Performers SLiM A. Derivation of gradient expressions θ(n) doesn’t affect terms L(1) (X(out,1) ), . . . , L(n−1) (X(out,n) ), so corresponding gradients are zero: ∇θ(n) L = ∇θ(n) N X ′ ′ L(n ) (X(out,n ) ). n′ =n Similarly, B (n) does not affect L(1) , . . . , L(n) , so N X G (n) = ∇B(n) L = ∇B(n) ′ L(n) (X(out,n ) ). n′ =n+1 In particular, G (N ) = ∇B(N ) L = 0r×D1 . 
′ For all 1 ≤ n < n′ ≤ N , θ(n) and B (n−1) affect L(n ) only through B (n) , so according to the chain rule ∇θ(n) N X ′ s (n) X ∂Br ′ L(n ) (X (out,n ) ) = n′ =n+1 r=1 ′ ∀1 ≤ r ≤ s : ∇B(n−1) r′ N X L ∂θ(n) (n′ ) n′ =n+1 (X ⊤ N X × ∇B(n) r ′ ′ L(n ) (X (out,n ) ) = n′ =n+1 (out,n′ ) )= s (n) X ∂Br s (n) X ∂Br r=1 ⊤ (n−1) r=1 ∂Br ′ s (n) ⊤ X ∂Br = × ∇B(n) L, (n−1) r r=1 ∂Br ′ × ∇B(n) r N X ∂θ(n) ⊤ × ∇B(n) L, r ′ ′ L(n ) (X (out,n ) ) n′ =n+1 ∂ denotes Jacobian matrices. Further, for all 1 ≤ r ≤ s: where ∂ (n) ⊤ ∂Br ∂ × ∇B(n) L = ∇ r  [Br(n) ]⊤ hh∇B(n) Lii r  , (n−1) where  ∈ {θ(n) } ∪ {Br′ }1≤r′ ≤s . hh·ii denotes a stop-gradient operator, i.e. gradients are not propagated inside brackets and the argument is considered as constant. We conclude that ∇θ(n) L = ∇θ(n) L (n) (X (out,n) (n′ ) (out,n′ ) (n) (out,n) s (n) X ∂Br ⊤ × ∇B(n) L L (X ) = ∇θ(n) L (X )+ ∂θ(n) r=1 n′ =n+1   s X [Br(n) ]⊤ hh∇B(n) Lii = ∇θ(n) Φ(n) (θ(u) , B (n−1) , ∇B(n) L) = ∇θ(n) L(n) (X (out,n) ) + r r=1 = ∇θ(n) Φ(n) (θ(u) , B (n−1) , G (n) ), (n−1) ∀1 ≤ r′ ≤ s : Gr′ ) + ∇θ(n) N X = ∇B(n−1) L = ∇B(n−1) L(n) (X (out,n) ) + ∇B(n−1) r′ r′ r′ N X ′ r ′ L(n ) (X (out,n ) ) n′ =n+1 s (n) ⊤ X ∂Br (n) (out,n) )+ = ∇B(n−1) L (X × ∇B(n) L (n−1) r r′ r=1 ∂Br ′   s X = ∇B(n−1) L(n) (X (out,n) ) + ∇ [Br(n) ]⊤ hh∇B(n) Lii = ∇B(n−1) Φ(n) (θ(n) , B (n−1) , ∇B(n) L) r r′ r′ r=1 = ∇B(n−1) Φ(n) (θ(n) , B (n−1) , G (n) ), r′ where the second chain of equalities is equivalent to (20). Sub-Linear Memory: How to Make Performers SLiM B. Efficient “Block” Computation of (3-4) e l )L . Katharopoulos et al. (2020) propose e l )L , D = (S⊤ Q e = (g(Ql ))L , K e = (g(Kl ))L , N = (Rl × Q Denote Q l=1 l l=1 l=1 l=1 the following algorithm for computation of (3-4). Initialize buffers curR = 0d×M , curS = 0M , iterate over l = 1, . . . , L and compute e ⊤; curR := curR + Vl × K l e l; curS := curS + K e l; Nl := curR × Q e l; Dl := curS⊤ × Q Yl := Nl /Dl . This way, 3d tensor R ∈ RL×d×M is not stored in memory explicitly, resulting in O(L) time and O(L(d + M ) + dM ) memory complexity. In order to have the same memory consumption during back-propagation, Katharopoulos et al. (2020) propose the following routine. Keep buffers curR, curS as the result of forward pass, and initialize gradient buffers gradR = 0d×M , gradS = 0M . Assuming that ∇N L ∈ RL×d , ∇D L ∈ RL are computed using automatic differentiation, iterate in a backward direction l = L, . . . , 1 and compute ⊤ ∇Q e l L := (∇Dl L) · curS + curR × ∇Nl L; e ⊤; curR := curR − Vl × K l e curS := curS − Kl ; e ⊤; gradR := gradR + (∇Nl L) × Q l e gradS := gradS + (∇D L) · Ql ; l e l; ∇Vl L := gradR × K ⊤ ∇K e l L := gradR × Vl . In practice, the described algorithm works slow when implemented in pure PyTorch, because l is iterated one-by-one: Katharopoulos et al. (2020) use low-level CUDA extensions to make the algorithm practical. Instead, we propose a “block” version, when we iterate through blocks of l of a small size C (we use C = 64). In each block use explicit prefix sums on inputs of length C to find Yl:l+C−1 , using the maintained front curR, curS. The formal algorithm is as follows. Initialize buffers curR = 0d×M , curS = 0M . For simplicity assuming that C divides L (extension for an opposite case is straightforward), iterate over l = 1, C + 1, . . . 
, L − C + 1 and compute e ⊤ ′ )C′ ); blockR := PS((Vl+l′ −1 × K l+l −1 l =1 (24) e l+l′ −1 )C′ ); blockS := PS((K l =1 (25) blockR := (curR + blockRl′ )Cl′ =1 ; blockS := (curS + blockSl′ )Cl′ =1 ; curR := blockRC ; curS := blockSC ; e l+l′ −1 )C′ ; Nl:l+C−1 := (blockRl′ × Q l =1 C ⊤ e ′ Dl:l+C−1 := (blockS ′ × Ql+l −1 ) ′ ; l =1 C Yl:l+C−1 := (Nl+l′ −1 /Dl+l′ −1 )l′ =1 . l In the “block” version, the number of outer sequential iterations is reduced to L/C, resulting in O((L/C) log C) parallel time complexity, when the logarithmic parallel algorithm is used to compute prefix sums (24,25). In our experiments, we use torch.cumsum to compute (24,25), which works fast in practice. The memory complexity of the algorithm is O(L(d + M ) + CdM ), where the second term is for storing blockR. Assuming that C is a small constant (C = O(1)), we conclude that the “block” version has O(L(d + M ) + dM ) memory and O(L) time complexity – same as the algorithm of Sub-Linear Memory: How to Make Performers SLiM Katharopoulos et al. (2020). As for hidden constants in complexity estimates, the constant inside O(L) time complexity is reduced at the cost of increasing constant of the “small” dM term in the memory complexity (when d, M ≪ L), making the “block” iterative algorithm a practical choice for computing (3-4). We further show how to back-propagate through (3-4) in O((L/C) log C) time and O(L(d + M ) + CdM ) memory. Again, keep buffers curR, curS as the result of forward pass, and initialize gradient buffers gradR = 0d×M , gradS = 0M . Assuming that ∇N L ∈ RL×d , ∇D L ∈ RL are computed using automatic differentiation, iterate in a backward direction l = L − C + 1, L − 2C + 1, . . . , 1 and compute curR := curR − l+C−1 X l′ =l curS := curS − l+C−1 X l′ =l e ⊤′ ; Vl ′ × K l e l′ ; K e ⊤ ′ )C′ ); blockR := PS((Vl+l′ −1 × K l+l −1 l =1 blockR := (curR + blockRl′ )Cl′ =1 ; e l+l′ −1 )C′ ); blockS := PS((K l =1 blockS := (curS + blockSl′ )Cl′ =1 ; C ⊤ ∇Q e l:l+C−1 L := ((∇Dl+l′ −1 L) · blockSl′ + curRl′ × ∇Nl+l′ −1 L)l′ =1 ; gradR := gradR + l+C−1 X l′ =l gradS := gradS + l+C−1 X l′ =l e ⊤′ ; (∇Nl′ L) × Q l e l′ ; (∇Dl′ L) · Q e ⊤ ′ )C′ ); blockgradR := PS(((∇Nl+l′ −1 L) × Q l+l −1 l =1 blockgradR := (gradR − blockgradRl′ )Cl′ =1 ; e l+l′ −1 )C′ ); blockgradS := PS(((∇D ′ L) · Q l+l −1 l =1 blockgradS := (gradS − gradSl′ )Cl′ =1 ; e l+l′ −1 )C′ ∇V L := (blockgradRl′ × K l =1 ; C ⊤ ∇K e l:l+C−1 L := (blockgradRl′ × Vl+l′ −1 )l′ =1 . l:l+C−1 Finally, it’s easy to see how to use both one-to-one and “block” iterative computation as part of Algorithm 1 to compute the update (17-18). For that, when doing a forward computation for some n, r, initialize curR, curS from corresponding (r−1,n−1) subvectors of UBn −1 , with the rest of the algorithm unchanged. Similarly, during a backward pass for some n, r, initialize gradR, gradS from corresponding subvectors of G (n) and leave the rest of the iterative back-propagation algorithm unchanged. C. Additional experimental details We use 15K, 30K, 100K SGD iterations in the Copying task, Penn Treebank, Enwik8 setups respectively. We use Adam optimizer (Kingma & Ba, 2015) with β1 = 0.9, β2 = 0.999 (default configuration used in PyTorch). For the Copying task, we train with a learning rate 10−2 for 10K iterations and then decrease the learning rate to 10−3 . We use a fixed learning rate of 10−4 and 2 × 10−4 in Penn Treebank and Enwik8 experiments, respectively. Figure 4 is a bigger version of Figure 3 from the main text. 
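For reference, the Copying-task optimization schedule above can be sketched as follows (our reconstruction from the stated hyperparameters, not the released training script; the model and loss are placeholders):

```python
# Sketch of the Appendix C optimizer configuration: Adam with PyTorch-default betas and,
# for the Copying task, a learning-rate drop from 1e-2 to 1e-3 after 10K of 15K iterations.
# The model and loss below are placeholders, not the Performer used in the experiments.
import torch

model = torch.nn.Linear(8, 8)   # placeholder for the Performer model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10_000], gamma=0.1)

for iteration in range(15_000):
    optimizer.zero_grad()
    loss = model(torch.randn(1, 8)).pow(2).mean()   # placeholder loss
    loss.backward()
    optimizer.step()
    scheduler.step()
# For Penn Treebank and Enwik8, the scheduler is dropped and a fixed
# learning rate of 1e-4 or 2e-4 is used instead.
```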
Figure 5 reports additional experimental results: bits-per-character for the Copying task and train-set learning curves for Penn Treebank and Enwik8.

Figure 4. Bigger version of Figure 3. (Panels: Copying task, Penn Treebank, Enwik8; y-axes: accuracy (%) and validation bits-per-character; x-axis: iteration #; curves: Full, Algorithm 1 with different C, and fine-tuning (F/T) variants.)

Figure 5. Bits-per-character learning curve for the Copying task and train-set learning curves for language modelling on Penn Treebank and Enwik8, respectively.