KDEformer: Accelerating Transformers via Kernel Density Estimation

Amir Zandieh^1, Insu Han^{*2}, Majid Daliri^{*3}, Amin Karbasi^2

* Equal contribution. ^1 Max-Planck-Institut für Informatik, Germany. ^2 Yale University, USA. ^3 New York University, USA. Correspondence to: Amir Zandieh.

Proceedings of the 40th International Conference on Machine Learning, Honolulu, Hawaii, USA. PMLR 202, 2023. Copyright 2023 by the author(s).

Abstract

The dot-product attention mechanism plays a crucial role in modern deep architectures (e.g., Transformer) for sequence modeling; however, naïve exact computation of this model incurs quadratic time and memory complexities in the sequence length, hindering the training of long-sequence models. The critical bottlenecks are the computation of the partition functions in the denominator of the softmax function as well as the multiplication of the softmax matrix with the matrix of values. Our key observation is that the former can be reduced to a variant of the kernel density estimation (KDE) problem, and an efficient KDE solver can be further utilized to accelerate the latter via subsampling-based fast matrix products. Our proposed KDEformer can approximate the attention in sub-quadratic time with provable spectral norm bounds, while all prior results merely provide entry-wise error bounds. Empirically, we verify that KDEformer outperforms other attention approximations in terms of accuracy, memory, and runtime on various pre-trained models. On BigGAN image generation, we achieve better generative scores than the exact computation with over 4× speedup. For ImageNet classification with T2T-ViT, KDEformer shows over 18× speedup while the accuracy drop is less than 0.5%.

1. Introduction

Transformers (Vaswani et al., 2017) have been successfully applied to a wide variety of learning tasks in areas such as natural language processing (Devlin et al., 2018; Yang et al., 2019; Brown et al., 2020; Raffel et al., 2020), computer vision (Carion et al., 2020; Dosovitskiy et al., 2021), and time series forecasting (Zhou et al., 2021). Although popular, these models face serious scalability limitations because naïve exact computation of their attention layers incurs quadratic (in sequence length) runtime and memory complexities. This can inhibit the training of large-scale long-sequence models.

Several algorithms have been proposed to improve Transformers' efficiency by approximating the softmax matrices in their attention layers with either sparse matrices (Kitaev et al., 2020; Daras et al., 2020; Roy et al., 2021; Sun et al., 2021), low-rank matrices (Choromanski et al., 2021; Katharopoulos et al., 2020), or a combination of both (Chen et al., 2021b; Zaheer et al., 2020; Chen et al., 2021a; Dass et al., 2022). However, all prior advances solely focused on point-wise approximation of the entries of the softmax matrix and fail to provide rigorous approximation guarantees on the final output of the attention mechanism. In this work, we design algorithms to approximate the output matrix of attention layers with provable spectral norm guarantees.

1.1. Problem Formulation and Setting

Let n be the number of tokens in the input sequence and d be the dimension of the latent representations. The dot-product attention (Vaswani et al., 2017) is a mapping which takes inputs Q, K, V ∈ R^{n×d} (interpreted as queries, keys, and values of a dictionary) and outputs the following matrix:

Att(Q, K, V) := D^{-1} A V,   A := exp(QK^⊤ / √d),   D := diag(A 1_n),

where exp(·) is applied in an element-wise manner, 1_n is the all-ones vector in R^n, and diag(·) maps its input vector to a diagonal matrix. We refer to A ∈ R^{n×n} as the attention matrix and to D^{-1}A as the softmax matrix. Exact computation of the attention matrix A takes Θ(n²d) operations and storing it requires Θ(n²) memory. Thus, naïve computation of Att(Q, K, V) requires Ω(n²d) runtime and Ω(n²) memory. Our aim is to approximate the output matrix Att(Q, K, V) efficiently while preserving its spectral structure.
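For concreteness, the following is a minimal NumPy sketch (not part of the original paper; array names and sizes are illustrative) of the exact computation defined above. It materializes the full n × n attention matrix and therefore exhibits exactly the Θ(n²d) time and Θ(n²) memory cost that KDEformer is designed to avoid.

import numpy as np

def exact_attention(Q, K, V):
    """Naive dot-product attention: Att(Q, K, V) = D^{-1} A V.

    Materializes the full n x n attention matrix A = exp(Q K^T / sqrt(d)),
    so it costs Theta(n^2 d) time and Theta(n^2) memory.
    """
    n, d = Q.shape
    A = np.exp(Q @ K.T / np.sqrt(d))      # attention matrix, n x n
    D_inv = 1.0 / A.sum(axis=1)           # diagonal of D^{-1}, i.e. 1 / (A 1_n)
    return (D_inv[:, None] * A) @ V       # softmax matrix times values

# toy usage (sizes are arbitrary)
rng = np.random.default_rng(0)
n, d = 256, 16
Q, K, V = rng.standard_normal((3, n, d)) / d ** 0.25
out = exact_attention(Q, K, V)
print(out.shape)  # (256, 16)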
Our approach is based on reducing the number of columns of the matrix A using importance sampling. We also devise an efficient estimator for the diagonal scaling matrix D, which bypasses exact and explicit computation of the matrix A. Formally, for any given ε > 0 and any Q, K, V ∈ R^{n×d}, we want to quickly find a sampling matrix Π ∈ R^{m×n} with a small number m = n^{1−Ω(1)} of rows, along with a diagonal matrix D̃ ∈ R^{n×n}, such that the following bound on the operator norm of the error is satisfied:

∥Att(Q, K, V) − D̃^{-1}AΠ^⊤ · ΠV∥_op ≤ ε · ∥V∥_op.   (1)

Note that D^{-1}A is a row-stochastic (transition) matrix, so its operator norm is ∥D^{-1}A∥_op = 1; thus the r.h.s. in Eq. (1) is in fact equal to ε · ∥D^{-1}A∥_op ∥V∥_op.

Given a sampling matrix Π with m rows, we can compute the matrix product AΠ^⊤ · ΠV in O(nmd) total runtime and O(nm) memory, because we only need to compute the m sampled columns of A. Therefore, our main goal is to generate a sampling matrix Π with a small number of samples, along with a diagonal matrix D̃, which satisfy Eq. (1) using a sub-quadratic runtime in n.

All prior approximate attention methods have solely focused on finding an approximate attention matrix Ã such that ∥A − Ã∥_F is small, even though A is not the ultimate output of attention and the output depends on V in addition to A. In contrast, we propose the first efficient algorithm for approximating the output matrix Att(Q, K, V) with spectral bounds as per Eq. (1) (see Section 3.3).

1.2. Our Techniques and Results

We leverage the line of work on efficient Kernel Density Estimation (KDE) (Schölkopf et al., 2002; Joshi et al., 2011; Charikar & Siminelakis, 2017; Backurs et al., 2018; 2019; Siminelakis et al., 2019). In the KDE problem, we are given a dataset X = {x_1, x_2, . . . x_n} and a kernel function k(·,·), and we aim to compute the kernel density µ_X(q) = (1/n) Σ_{i=1}^n k(q, x_i) for an arbitrary query point q. The goal of existing methods in the literature is to estimate this value to (1 + ε) relative error in time O(ε^{-2} d/µ̃^τ) for some τ > 0, where µ̃ is a lower bound on µ_X(q). In particular, the best-known algorithm for the Gaussian kernel, due to Charikar et al. (2020), achieves τ = 0.173 + o(1).

We show that finding the sampling matrix Π and diagonal scaling D̃ which satisfy Eq. (1) can be reduced to a generalization of the KDE problem. First note that the i-th diagonal entry of the scaling matrix D is D_{i,i} = Σ_{j=1}^n exp(⟨q_i, k_j⟩/√d), which is indeed the kernel density corresponding to the exponential kernel function k(x, y) = exp(⟨x, y⟩) and dataset (1/d^{1/4}) · K at query point (1/d^{1/4}) · q_i. Thus, if we had an efficient KDE procedure for estimating the exponential kernel density up to a multiplicative (1 ± ε) factor, we could compute a scaling D̃ that satisfies the spectral guarantee of Eq. (1).

Additionally, to design an efficient sampling matrix Π that satisfies Eq. (1) with a small number of rows, the sampling probabilities need to be proportional to the column norms of the softmax matrix D^{-1}A (Zouzias, 2013). One can see that the squared norm of the i-th column of D^{-1}A is Σ_{j∈[n]} D_{j,j}^{-2} exp(2⟨q_j, k_i⟩/√d), which is a weighted exponential kernel density with weights {D_{j,j}^{-2}}_{j∈[n]} and dataset (√2/d^{1/4}) · Q at query point (√2/d^{1/4}) · k_i. Therefore, if we could estimate this weighted exponential kernel density up to some constant multiplicative factor, we could generate a sampling matrix Π with a small number of samples that satisfies Eq. (1). Thus, a generalized KDE procedure for efficiently evaluating the weighted exponential kernel density enables us to approximate Att(Q, K, V) as per Eq. (1). While there is no prior solution for this problem, we show how to translate it to the Gaussian KDE problem, which has witnessed significant recent progress, by applying appropriate transformations on K and Q (see Algorithm 2 and Theorem 3.4).

Our Theoretical Results. We give an algorithm that outputs a diagonal D̃ ∈ R^{n×n} and a sampling matrix Π ∈ R^{m×n} with m = O(ε^{-2} log n · srank(D^{-1}A)) samples which satisfy the spectral bound of Eq. (1) with high probability in n, where srank(D^{-1}A) denotes the stable rank of the softmax matrix. Our method reduces the memory of attention layers to mn = O(ε^{-2} n log n · srank(D^{-1}A)). Furthermore, if the Gaussian KDE is supported by an algorithm with runtime O(ε^{-2} d/µ̃^τ) for relative error 1 + ε and density lower bound µ̃, then our algorithm's runtime is bounded by O(ε^{-2} d · n^{1+τ}) for any datasets of queries Q and keys K with diameter max_{i,j∈[n]} ∥k_i − q_j∥²₂ = o(√d · log n), which is strongly sub-quadratic in n. The current best value for τ is τ = 0.173 + o(1) due to (Charikar et al., 2020), and any future progress on Gaussian density evaluation immediately improves our method's runtime. This result applies to a wide range of practical scenarios where the dimension d is not too large. To see why, note that the entries of K, Q are typically constant; thus, the diameter is max_{i,j∈[n]} ∥k_i − q_j∥²₂ = O(d). Therefore, for any dimension d = o(log² n), e.g., d ≈ log² n / log log n, our method needs only O(m + ε^{-2} d · n^{1+τ}) operations, which is significantly faster than exact computation of Att(Q, K, V).

Our Practical Results. Our required number m of samples depends on the stable rank of the softmax matrix. To reduce m, we employ Locality Sensitive Hashing (LSH) to extract the heavy elements of D^{-1}A and then show that, in practice, the residual has a significantly smaller stable rank than the original matrix (see Section 3.4). With this heuristic improvement, we verify that our proposed algorithm outperforms popular attention approximations. In particular, it can save memory space by up to 19.06× when the sequence length n is 16,384. We apply our method to image generation with BigGAN (Brock et al., 2019) and observe that our images, shown in Fig. 1, look more natural than those of other methods, and our generative score is even better than that of the exact attention. Furthermore, for ImageNet classification with Vision Transformer (Yuan et al., 2021), KDEformer shows 18× speedup and 82.08% accuracy, which is only 0.5% lower than the exact attention (see Section 4.3). Finally, we demonstrate our method on end-to-end training under the Long Range Arena benchmark (Tay et al., 2021) and observe up to 8× speedup in wall-clock time over the exact attention (see Section 4.4).

Figure 1. Image generations by the pre-trained BigGAN using exact and approximate attention without fine-tuning (panels: Exact, KDEformer, Performer, Reformer).

1.3. Prior Work

Several popular methods try to approximate the heavy entries of the attention matrix A by restricting the attention to local neighbors of queries using Locality Sensitive Hashing (LSH) (Kitaev et al., 2020; Chen et al., 2020; Sun et al., 2021) or k-means clustering (Daras et al., 2020; Roy et al., 2021). Such approaches, however, only provide error bounds on the attention matrix, e.g., guarantees of the form ∥A − Ã∥_F < εn, and cannot provide any provable guarantees for the final output matrix Att(Q, K, V). Remarkably, at the core of our algorithm there are invocations of the Gaussian KDE primitive from Charikar et al. (2020), which heavily employs LSH to estimate kernel densities. In contrast to previous works, our algorithm uses LSH in a more subtle way, namely for estimating the right sampling probabilities in order to generate Π and also to approximate the scaling D. This difference of approach allows us to approximate Att(Q, K, V) with spectral norm guarantees.

Another recent line of work is based on approximating the attention matrix A via random feature maps of the Gaussian or exponential kernels (Choromanski et al., 2021; Katharopoulos et al., 2020). Chen et al. (2021b) recently showed that a combination of both LSH-based and random-feature-based methods works better at approximating the attention matrix A. See (Tay et al., 2022) for a survey.

2. Preliminaries and Notations

For any matrix A, we let a_i be its i-th row vector, and its stable rank is defined as srank(A) := ∥A∥²_F / ∥A∥²_op, which is always upper bounded by the algebraic rank. We denote by e_1, e_2, . . . e_n the standard basis vectors in R^n and by 1_n and 0_n the all-ones and all-zeros vectors in R^n. For vectors x, y their direct sum is denoted by x ⊕ y := [x^⊤, y^⊤]^⊤.

Gaussian KDE. Our main algorithm is tightly related to the Gaussian KDE, where one is given a dataset X ∈ R^{n×d} and wants to build a data structure (DS) such that, given this DS, one can estimate the following kernel density value up to (1 + ε) relative error for any query point q ∈ R^d:

µ_X(q) := (1/n) Σ_{i∈[n]} exp(−∥q − x_i∥²₂ / 2).   (2)

The naïve method without any DS requires Θ(nd) time and memory per query. The aim is to minimize the memory needed to store the DS and the query time, ultimately making them sublinear in n. The pre-processing time needed to construct the DS is also desired to be small. There have been significant advances on this problem, and the current best result was given by Charikar et al. (2020) as follows:

Theorem 2.1 (Fast Gaussian KDE, Theorem 2 in (Charikar et al., 2020)). Let τ = 0.173 + o(1). For any dataset X ∈ R^{n×d} and any ε, µ̃ ∈ (0, 1), there exist:
1. a procedure PreprocessKDE(X, ε, µ̃) that constructs a data structure named DS_kde in time O(ε^{-2} dn/µ̃^τ);
2. given DS_kde, any query q ∈ R^d, and µ_X(q) defined as in Eq. (2), a procedure QueryKDE(DS_kde, q) that approximates the quantity µ_X(q) · 1{µ̃ ≤ µ_X(q)} up to (1 + ε) relative error in time O(ε^{-2} d (µ̃ + µ_X(q))^{-τ}).

The density lower bound µ̃ required by Theorem 2.1 is unknown to us in advance, and we learn this quantity adaptively in Algorithm 2. We show in Section 3.3 that for datasets with bounded diameter, µ̃ = n^{-1-o(1)}.
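As a point of reference for Theorem 2.1, here is a brute-force evaluation of the Gaussian kernel density of Eq. (2); this is only the Θ(nd)-per-query baseline, not the data structure of Charikar et al. (2020), and all names and sizes are illustrative.

import numpy as np

def gaussian_kde(X, q):
    """Brute-force evaluation of Eq. (2): mu_X(q) = (1/n) * sum_i exp(-||q - x_i||^2 / 2).

    This is the Theta(nd) baseline; the data structure of Theorem 2.1 answers
    such queries in time sublinear in n after preprocessing.
    """
    sq_dists = np.sum((X - q) ** 2, axis=1)
    return np.mean(np.exp(-sq_dists / 2.0))

rng = np.random.default_rng(1)
X = rng.standard_normal((1000, 8))
q = rng.standard_normal(8)
print(gaussian_kde(X, q))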
3. Efficient Attention with Spectral Bounds

In this section, we design KDEformer, which can efficiently compute a sampling matrix Π and a diagonal scaling D̃ satisfying Eq. (1). We start by showing that this can be done very efficiently given access to a primitive for estimating the row norms of the attention matrix A as well as the column norms of the softmax matrix D^{-1}A. Next, in Section 3.2, we present a reduction from norm estimators for A and D^{-1}A to the Gaussian KDE problem, which has an efficient solution. Finally, we prove our main result in Section 3.3.

3.1. High-level Architecture of the Algorithm

Here, we assume that we have access to an oracle which can estimate a weighted linear combination of n exponential kernels at arbitrary query points, and given this oracle, we design an algorithm that outputs Π and D̃ satisfying Eq. (1). In other words, we translate and reduce the problem of spectrally approximating Att(Q, K, V) to a weighted KDE problem corresponding to the exponential dot-product kernel. The precise interface and desired properties of this oracle are presented in the following definition.

Definition 3.1 (Weighted Exponential KDE). Let X, Y ∈ R^{n×d} be arbitrary datasets and let v ∈ R^n_+ be an arbitrary vector with positive coordinates. For any ε > 0, the primitive WExpKDE(X, Y, v, ε) outputs a non-negative vector α ∈ R^n_+ such that:

α_j ∈ (1 ± ε) · Σ_{i∈[n]} v_i exp(⟨x_i, y_j⟩)   for all j ∈ [n].   (3)

Algorithm 1 KDEformer
1: input: matrices Q, K, V ∈ R^{n×d}, integer m, and ε > 0
2: γ ← ∥V∥_op^{-2} via the power method
3: α ← WExpKDE(K/d^{1/4}, Q/d^{1/4}, 1_n, ε/3)
4: β ← WExpKDE(√2·Q/d^{1/4}, √2·K/d^{1/4}, u, 1/3), where u_i ← 1/α_i² for every i ∈ [n]
5: p_i ← β_i + γ·∥v_i∥²₂ for every i ∈ [n], then normalize p_ℓ ← p_ℓ / Σ_{j∈[n]} p_j for every ℓ ∈ [n]
6: generate i.i.d. samples ℓ_1, ℓ_2, . . . ℓ_m ∈ [n] from the distribution {p_ℓ}_{ℓ∈[n]}
7: let the r-th row of Π be (1/√(m·p_{ℓ_r}))·e^⊤_{ℓ_r} for every r ∈ [m]
8: return D̃ = diag(α) and Π

Now we show how to generate Π and D̃ that satisfy Eq. (1), given access to WExpKDE as per Definition 3.1.

Estimating D = diag(exp(QK^⊤/√d)·1_n). One can easily see that the j-th diagonal entry of D equals:

D_{j,j} = Σ_{i∈[n]} exp(⟨k_i, q_j⟩/√d)   for all j ∈ [n].   (4)

Therefore, if we let α = WExpKDE(K/d^{1/4}, Q/d^{1/4}, 1_n, ε/3) and define D̃ = diag(α), then by Definition 3.1 and using the fact that the entries of D are positive, we have (1 − ε/3)D ⪯ D̃ ⪯ (1 + ε/3)D, where ⪯ is the Loewner order. So, using the fact that ∥D^{-1}A∥_op = 1,

∥Att(Q, K, V) − D̃^{-1}AV∥_op ≤ (ε/2) · ∥V∥_op.   (5)

Hence, we can estimate D to sufficient precision by invoking WExpKDE(K/d^{1/4}, Q/d^{1/4}, 1_n, ε/3).

Generating the Sampling Matrix Π. Given a diagonal matrix D̃ which satisfies Eq. (5), by the triangle inequality it suffices, in order to satisfy the spectral bound of Eq. (1), to find a sampling matrix for which the following holds:

∥D̃^{-1}AΠ^⊤ · ΠV − D̃^{-1}AV∥_op ≤ (ε/2) · ∥V∥_op.   (6)

So, our goal is to design a sampling matrix Π ∈ R^{m×n} with a small number m of rows that satisfies Eq. (6). This problem is in fact well studied in the randomized numerical linear algebra literature and is known as Approximate Matrix Multiplication (AMM) with respect to the spectral norm. It is known how to achieve the above guarantee using a sampling matrix with m = O(ε^{-2} log n · (srank(D^{-1}A) + srank(V))) i.i.d. rows. More formally, we have the following result, which is a slight modification of Theorem 2.1 from (Zouzias, 2013) and is proved in Appendix B.1.

Lemma 3.2 (AMM). Consider any matrices X ∈ R^{n×q}, Y ∈ R^{n×d} and any probability distribution {p_i}_{i∈[n]} satisfying p_i ≥ (1/4) · (∥x_i∥²₂ + γ·∥y_i∥²₂) / (∥X∥²_F + γ·∥Y∥²_F) for all i ∈ [n], where γ = ∥X∥²_op / ∥Y∥²_op. Let Π ∈ R^{m×n} be a sampling matrix constructed by first generating m i.i.d. samples ℓ_1, . . . ℓ_m ∈ [n] according to {p_ℓ}_{ℓ∈[n]} and then letting the r-th row of Π be (1/√(m·p_{ℓ_r}))·e^⊤_{ℓ_r}. If m = Ω(ε^{-2} log n · (srank(X) + srank(Y))) for some ε > 0, then the following holds:

Pr[ ∥X^⊤Π^⊤ΠY − X^⊤Y∥_op > ε ∥X∥_op ∥Y∥_op ] ≤ 1/poly(n).

So, by invoking Lemma 3.2 with X^⊤ = D̃^{-1}A, Y = V, and error parameter ε/2, we can find a random sampling matrix Π which satisfies Eq. (6) with high probability in n, as long as the number of samples is at least m = Ω(ε^{-2} log n · (srank(D̃^{-1}A) + srank(V))). The only catch is that, to apply Lemma 3.2, we need to compute the distribution {p_i}_{i∈[n]} as per this lemma. In other words, we need to compute the row norms of V as well as the column norms of D̃^{-1}A. All row norms of V can be computed in O(nd) time. However, naïvely computing the column norms of D̃^{-1}A would require Θ(n²d) operations. Fortunately, the column norms of D̃^{-1}A can be approximated via the primitive WExpKDE from Definition 3.1.

The procedure for computing D̃ and the sampler Π is presented in Algorithm 1. We state the correctness of Algorithm 1 in the following theorem and prove it in Appendix B.2.

Theorem 3.3 (Correctness of Algorithm 1). For any matrices Q, K, V ∈ R^{n×d}, any ε > 0, and number of samples m = Ω(ε^{-2} log n · (srank(D^{-1}A) + srank(V))), given access to a primitive WExpKDE as per Definition 3.1, Algorithm 1 outputs a diagonal matrix D̃ ∈ R^{n×n} and a sampling matrix Π ∈ R^{m×n} which satisfy Eq. (1) with probability at least 1 − 1/poly(n).

3.2. Weighted Exponential KDE

So, to spectrally approximate Att(Q, K, V), it is enough to run Algorithm 1. This algorithm relies on the existence of the primitive WExpKDE as per Definition 3.1; therefore, we focus on an efficient implementation of WExpKDE. Here, we devise an efficient algorithm that satisfies the desired properties of WExpKDE as per Definition 3.1. We show that this procedure is tightly related to, and can be translated to, an instance of the Gaussian KDE. First note that if all data points in the dataset X were on a sphere, i.e., ∥x_i∥_2 = r for all i ∈ [n] and some r > 0, then the weighted exponential kernel density corresponding to the weights v = (1/n)·1_n would be equal to e^{(∥q∥²₂ + r²)/2} · µ_X(q), where µ_X(q) is defined as in Eq. (2).

Our proposed WExpKDE primitive employs a fast Gaussian KDE method as per Theorem 2.1. The weighted exponential kernel density for a query point q and weight vector v ∈ R^n_+ can be written as

Σ_{i∈[n]} v_i e^{⟨x_i, q⟩} = e^{∥q∥²₂/2} Σ_{i∈[n]} v_i e^{∥x_i∥²₂/2} · e^{−∥x_i − q∥²₂/2}.   (7)

Let us define w_i := sqrt( 2 log( Σ_{j∈[n]} v_j exp(∥x_j∥²₂/2) / (v_i · exp(∥x_i∥²₂/2)) ) ) for every i ∈ [n], define the augmented dataset X′ ∈ R^{n×(d+1)} by x′_i := x_i ⊕ [w_i] for every i ∈ [n], and let the augmented query point be q′ := q ⊕ [0]. Then the r.h.s. in Eq. (7) can be written as

e^{∥q∥²₂/2} Σ_{i∈[n]} v_i e^{∥x_i∥²₂/2} · exp( −∥x′_i − q′∥²₂/2 + w_i²/2 ) = n · e^{∥q∥²₂/2} · ( Σ_{j∈[n]} v_j e^{∥x_j∥²₂/2} ) · µ_{X′}(q′).   (8)

Therefore, the weighted exponential kernel density can be obtained from the Gaussian kernel density corresponding to the augmented dataset X′ and augmented query q′, i.e., µ_{X′}(q′). The augmented dataset can be constructed very efficiently in time O(nd), so given a fast Gaussian KDE as per Theorem 2.1, Eq. (8) gives us an efficient way to implement the WExpKDE procedure. Our proposed procedure is presented in Algorithm 2. Note that the fast Gaussian KDE requires a lower bound µ̃ on the kernel density value µ_{X′}(q′); we show how to adaptively learn µ̃ in Algorithm 2, using the fact that if QueryKDE(DS_kde, q′) outputs zero, we can infer that our lower bound was too high.

Algorithm 2 Weighted Exponential KDE (WExpKDE)
1: input: matrices X, Y ∈ R^{n×d}, vector v ∈ R^n_+, error parameter ε > 0, and τ > 0
2: µ ← 1/n, S ← [n], and α ← 0_n
3: N ← Σ_{j∈[n]} v_j e^{∥x_j∥²₂/2}
4: w_i ← sqrt(2 log(N / (v_i · exp(∥x_i∥²₂/2)))) for every i ∈ [n]
5: X′ ← [X; w] ∈ R^{n×(d+1)}, Y′ ← [Y; 0_n] ∈ R^{n×(d+1)}
6: while µ^{−τ} ≤ ε² · |S| do
7:   DS_kde ← PreprocessKDE(X′, ε, µ)
8:   α_i ← n · N · e^{∥y_i∥²₂/2} · QueryKDE(DS_kde, y′_i) for every i ∈ S
9:   µ ← µ/2 and S ← {i ∈ [n] : α_i = 0}
10: end while
11: α_j ← Σ_{i∈[n]} v_i · exp(⟨x_i, y_j⟩) for every j ∈ S
12: return α
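The augmentation in lines 3–5 of Algorithm 2 can be checked numerically. The sketch below is not from the paper: it substitutes a brute-force Gaussian KDE for QueryKDE, uses illustrative sizes, and simply verifies that the reduction of Eqs. (7)–(8) agrees with the directly computed weighted exponential density up to floating-point error.

import numpy as np

def weighted_exp_kde_direct(X, v, q):
    # sum_i v_i * exp(<x_i, q>)
    return float(v @ np.exp(X @ q))

def weighted_exp_kde_via_gaussian(X, v, q):
    """Reduction of Eqs. (7)-(8): augment each x_i with an extra coordinate w_i so that
    the weighted exponential density equals n * N * e^{||q||^2/2} * mu_{X'}(q'), where
    mu_{X'} is the plain Gaussian KDE of Eq. (2) on the augmented data."""
    n = X.shape[0]
    sq = np.sum(X ** 2, axis=1)
    N = float(np.sum(v * np.exp(sq / 2.0)))
    w = np.sqrt(2.0 * np.log(N / (v * np.exp(sq / 2.0))))   # augmentation, Algorithm 2 line 4
    X_aug = np.hstack([X, w[:, None]])
    q_aug = np.append(q, 0.0)
    mu = np.mean(np.exp(-np.sum((X_aug - q_aug) ** 2, axis=1) / 2.0))  # stand-in for QueryKDE
    return n * N * np.exp(q @ q / 2.0) * mu

rng = np.random.default_rng(2)
X = rng.standard_normal((500, 6)) * 0.3
v = rng.random(500) + 0.1
q = rng.standard_normal(6) * 0.3
print(weighted_exp_kde_direct(X, v, q), weighted_exp_kde_via_gaussian(X, v, q))  # the two agree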
We analyze Algorithm 2 in the following theorem.

Theorem 3.4 (Analysis of Algorithm 2). For any matrices X, Y ∈ R^{n×d}, any non-negative vector v ∈ R^n_+, any ε ∈ (0, 1), and given a fast Gaussian KDE as per Theorem 2.1, Algorithm 2 outputs a vector α ∈ R^n which satisfies the desired conditions of Definition 3.1 (i.e., Eq. (3)). Furthermore, this procedure's runtime is O(nd · C_{X,Y,v,ε,τ}), where

C_{X,Y,v,ε,τ} := min_{µ>0} ( 1/(ε² µ^τ) + |{ i ∈ [n] : Σ_{j=1}^n v_j e^{⟨x_j, y_i⟩} < nµ · Σ_{j=1}^n v_j e^{(∥x_j∥²₂ + ∥y_i∥²₂)/2} }| ).   (9)

Proof. First, we prove correctness. Let us index the iterations of the algorithm's while loop by t = 0, 1, 2, . . . and let µ_t, α_t, and S_t denote the value of µ, the vector α, and the set S at the t-th iteration. We have |S_t| ≤ n and µ_t = 1/(n·2^t) for every t; thus, the algorithm must terminate in T = O(log n) iterations. Also, by Theorem 2.1, the set S_{t+1} computed in line 9 equals S_{t+1} = {i ∈ [n] : µ_{X′}(y′_i) < µ_t}, because the fast Gaussian KDE procedure outputs zero if and only if µ_{X′}(y′_i) < µ_t.

Next, we show by induction that at every iteration t, α_t(i) is within a (1 ± ε) factor of n·N·e^{∥y_i∥²₂/2} · µ_{X′}(y′_i) for all i ∈ [n] \ S_t. The base of the induction is trivial because S_0 = [n]. For the inductive step, note that in lines 7–8, α_{t+1}(i) is updated for every i ∈ S_t by invoking the fast Gaussian KDE procedure, and α_{t+1}(i) = α_t(i) for i ∈ [n] \ S_t. Thus, by the inductive hypothesis and Theorem 2.1, as well as the definition of S_{t+1} in line 9, α_{t+1}(i) is within a (1 ± ε) factor of n·N·e^{∥y_i∥²₂/2} · µ_{X′}(y′_i) for all i ∈ [n] \ S_{t+1}, which completes the inductive proof. Using the definition of N in line 3 and the definition of X′, Y′ in line 5 along with Eq. (8), the invariant that we proved implies that for every t = 0, 1, . . . T, α_t(i) is within a (1 ± ε) factor of Σ_{j∈[n]} v_j · exp(⟨x_j, y_i⟩) for all i ∈ [n] \ S_t. After exiting the while loop, α(i) is updated at all i ∈ S_{T+1} in line 11 as α(i) = Σ_{j∈[n]} v_j · exp(⟨x_j, y_i⟩), and α(i) = α_T(i) for every i ∈ [n] \ S_{T+1}. This proves that the output vector α satisfies Eq. (3), which completes the correctness proof.

Runtime Analysis. The runtime has three components. 1) Time to run PreprocessKDE in line 7: the total time of running this primitive over all iterations t = 0, 1, . . . T is O( Σ_{t=0}^T d·n/(ε² µ_t^τ) ) by Theorem 2.1. Since µ_t = 1/(n·2^t), this runtime is bounded by O( d·n/(ε² µ_T^τ) ). 2) Time to run QueryKDE in line 8: by Theorem 2.1, the total time to run this procedure over all iterations is O( (d/ε²) · Σ_{t=0}^T Σ_{i∈S_t} (µ_t + µ_{X′}(y′_i))^{−τ} ). Because |S_t| ≤ n, this runtime complexity is completely dominated by (1). 3) Time to exactly compute the weighted exponential densities of the points with very small µ_{X′}(y′_i) value in line 11: this runtime is bounded by O(nd · |S_{T+1}|).

Now we combine these bounds. Using the assumption that the algorithm terminated at iteration t = T, the while-loop condition at iteration T + 1 must fail. Therefore, |S_{T+1}| < µ_{T+1}^{−τ}/ε² < 2µ_T^{−τ}/ε². This shows that the first component of the runtime dominates the third component. So the total time is bounded by O( d·n/(ε² µ_T^τ) ).

Recall that the while loop terminates at iteration T, meaning that ε^{−2} µ_t^{−τ} ≤ |S_t| for every t = 0, 1, . . . T and ε^{−2} µ_{T+1}^{−τ} > |S_{T+1}|. So T is the largest integer that satisfies ε^{−2} µ_T^{−τ} ≤ |S_T|. Also recall that S_t = {i ∈ [n] : µ_{X′}(y′_i) < µ_{t−1}} and µ_t = 1/(n·2^t). Thus, the runtime of the procedure can be expressed as

O(nd) · min_{µ>0} ( ε^{−2} µ^{−τ} + |{i ∈ [n] : µ_{X′}(y′_i) < µ}| ).

The definition of X′, Y′ in line 5 along with Eq. (8) gives the claimed runtime bound in Eq. (9).

To get a better understanding of the runtime bound in Theorem 3.4, suppose that the datasets X, Y are such that the cardinality of the set { i ∈ [n] : Σ_{j∈[n]} v_j exp(⟨x_j, y_i⟩) / Σ_{j∈[n]} v_j exp((∥x_j∥²₂ + ∥y_i∥²₂)/2) ≤ n^{−o(1)} } is upper bounded by O(ε^{−2} · n^τ). For such datasets, the runtime of Theorem 3.4 is bounded by O(ε^{−2} d · n^{1+τ+o(1)}), which is strongly sub-quadratic in n.

3.3. Main Result

Now we are in a position to prove our main result, i.e., an efficient algorithm that can approximate the attention mechanism with spectral guarantees as per Eq. (1).

Theorem 3.5 (Approximate Attention with Spectral Norm Bound). For any matrices Q, K, V ∈ R^{n×d}, any ε > 0, and given a fast Gaussian KDE as per Theorem 2.1, there exists an algorithm that outputs a diagonal matrix D̃ ∈ R^{n×n} and a sampling matrix Π ∈ R^{m×n} with m = O(ε^{−2} log n · (srank(D^{−1}A) + srank(V))) samples which satisfy Eq. (1) with probability at least 1 − 1/poly(n). The runtime of this algorithm is O( m + nd · ( C_{K/d^{1/4}, Q/d^{1/4}, 1_n, ε, τ} + C_{√2·Q/d^{1/4}, √2·K/d^{1/4}, v, 1, τ} ) ), where v_j = ( Σ_{ℓ∈[n]} exp(⟨q_j, k_ℓ⟩/√d) )^{−2} for j ∈ [n] and the quantities C_{·,·,·,·,·} are defined as in Eq. (9).

We prove this theorem in Appendix B.3. The runtime bound in Theorem 3.5 can be simplified for datasets Q, K with bounded diameter as follows.

Corollary 3.6 (Simplified Runtime for Bounded-Diameter Datasets). For any datasets Q, K with diameter max_{i,j∈[n]} ∥k_i − q_j∥²₂ = γ·√d·log n for some γ > 0, the runtime of Theorem 3.5 is upper bounded by O( m + nd · ( n^{τ(1+γ)} + ε^{−2} n^{τ(1+γ/2)} ) ), which is strongly sub-quadratic in n. In particular, if γ = o(1), the runtime is bounded by O(m + ε^{−2} d · n^{1+τ+o(1)}).

We prove Corollary 3.6 in Appendix B.4. The current best value for τ is τ = 0.173 + o(1) due to Charikar et al. (2020); thus, for any datasets of queries Q and keys K with diameter max_{i,j∈[n]} ∥k_i − q_j∥²₂ = o(√d·log n), our algorithm's runtime is O(m + ε^{−2} d · n^{1.173+o(1)}).
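To illustrate the kind of sampling that underlies Lemma 3.2 and Theorem 3.5, here is a small self-contained sketch of spectral-norm approximate matrix multiplication; unlike KDEformer itself, it computes the sampling probabilities exactly rather than through WExpKDE, and the matrix sizes and sample count are arbitrary.

import numpy as np

def amm_sample(X, Y, m, rng):
    """Sampling-based approximate matrix multiplication in the spirit of Lemma 3.2.

    Samples m rows of X and Y with probabilities proportional to
    ||x_i||^2 + gamma * ||y_i||^2, gamma = ||X||_op^2 / ||Y||_op^2, and returns the
    estimator X^T Pi^T Pi Y of X^T Y. This is a sketch, not the full KDEformer
    pipeline: here the row norms are computed exactly rather than via WExpKDE."""
    gamma = np.linalg.norm(X, 2) ** 2 / np.linalg.norm(Y, 2) ** 2
    p = np.sum(X ** 2, axis=1) + gamma * np.sum(Y ** 2, axis=1)
    p = p / p.sum()
    idx = rng.choice(len(p), size=m, p=p)
    scale = 1.0 / np.sqrt(m * p[idx])                    # rows of Pi are e_l / sqrt(m p_l)
    return (X[idx] * scale[:, None]).T @ (Y[idx] * scale[:, None])

rng = np.random.default_rng(3)
n, q, d = 4000, 32, 32
X = rng.standard_normal((n, q))
Y = rng.standard_normal((n, d))
approx = amm_sample(X, Y, m=400, rng=rng)
exact = X.T @ Y
err = np.linalg.norm(approx - exact, 2) / (np.linalg.norm(X, 2) * np.linalg.norm(Y, 2))
print(f"relative spectral error: {err:.3f}")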
3.4. Practical Improvements by Exploiting Sparsity

Our method relies on a sampling-based AMM (Lemma 3.2), and the number of samples m is proportional to srank(D^{−1}A) by Theorem 3.5. Here, we propose a practical technique for reducing the stable rank of D^{−1}A by finding and subtracting off its "heavy" elements. Specifically, recall that srank(D^{−1}A) = ∥D^{−1}A∥²_F / ∥D^{−1}A∥²_op and that the softmax matrix D^{−1}A is dominated by its largest elements, which correspond to the nearest pairs of queries q_i and keys k_j. Therefore, subtracting off the heavy elements of D^{−1}A reduces ∥D^{−1}A∥²_F, which in turn can reduce srank(D^{−1}A).

Similar to Reformer (Kitaev et al., 2020), we employ a Locality Sensitive Hashing (LSH) scheme to find the dominant entries of the attention matrix A. Specifically, let H : R^d → [B] be an LSH function with B buckets such that the collision probability Pr[H(q_i) = H(k_j)] is "roughly" proportional to ⟨q_i, k_j⟩. Given such an LSH function, we define the sparse approximation to A as well as the residual attention matrix as:

[A_spar]_{i,j} := e^{⟨q_i, k_j⟩/√d} · 1{H(q_i) = H(k_j)}   for all i, j ∈ [n],   A_res := A − A_spar.   (10)

Intuitively, the stable rank of D^{−1}A_res is expected to be smaller than that of D^{−1}A because the former has a considerably smaller Frobenius norm. We verify this intuition by plotting the singular value distributions of the softmax matrix D^{−1}A and the residual D^{−1}A_res for two real-world instances in Fig. 2. Fig. 2(a) corresponds to keys and queries taken as the first n = 2,048 vectors from the GloVe word embedding dataset (Pennington et al., 2014). In Fig. 2(b), we focus on the first attention layer of the Tokens-to-Token Vision Transformer (T2T-ViT) (Yuan et al., 2021) and an arbitrary batch of images from the ImageNet dataset. In both instances, the singular values of the residual D^{−1}A_res decay faster than those of D^{−1}A, while the largest singular value (spectral norm) of both matrices equals one. Thus, as shown in Fig. 2, subtracting off the sparse component D^{−1}A_spar reduces the stable rank significantly.

Figure 2. Singular value distributions and stable ranks of the softmax matrix D^{−1}A versus those of the residual D^{−1}A_res (panels: (a) GloVe dataset, (b) T2T-ViT on ImageNet; axes: singular values σ_i versus index i). The stable rank of the residual matrix is significantly smaller.

Building upon this observation, we propose a new version of Algorithm 1 with improved practical performance. We start by using Eq. (10) to write:

Att(Q, K, V) = D^{−1}A_spar · V + D^{−1}A_res · V.   (11)

Given D, the first term above can be computed in time O(d · nnz(A_spar)), where nnz(·) denotes the number of nonzero entries of a matrix. By choosing an appropriate LSH, we can ensure that nnz(A_spar) is almost linear in n.

The second term in Eq. (11) can be approximated via AMM, similar to what was done in Algorithm 1; however, we need to be able to estimate the column norms of D^{−1}A_res. Fortunately, by Eq. (10), we have ∥D^{−1}A^j_res∥²₂ = ∥D^{−1}A^j∥²₂ − ∥D^{−1}A^j_spar∥²₂, where A^j_res, A^j, A^j_spar denote the j-th columns of A_res, A, A_spar, respectively. Since we can estimate the column norms of D^{−1}A efficiently using WExpKDE and all column norms of D^{−1}A_spar can be computed in total nnz(A_spar) time, the AMM sampling matrix Π_res for the residual can be generated quickly. The resulting estimator is:

Ãtt = D̃^{−1}A_spar · V + D̃^{−1}A_res · Π^⊤_res · Π_res V.

Putting everything together, we first choose an appropriate LSH function H and compute the sparse approximation to the attention matrix as per Eq. (10). We show how to design a GPU-friendly LSH whose collision probability Pr[H(q_i) = H(k_j)] is roughly proportional to ⟨q_i, k_j⟩ in Appendix A. Next, we compute a spectral proxy D̃ for D, as was done efficiently in Algorithm 1. Finally, we perform AMM on the matrices D̃^{−1}A_res and V via a sampling matrix Π_res. We illustrate this procedure in Fig. 3 and present the pseudocode for computing D̃ and Π_res in Algorithm 3. By an analysis similar to Corollary 3.6, we find that the runtime of Algorithm 3 is O(m + ε^{−2} d n^{1+τ+o(1)} + nnz(A_spar)) with some m = O(ε^{−2} log n · srank(D^{−1}A_res)).

Figure 3. The softmax matrix D^{−1}A decomposes into its sparse approximation D^{−1}A_spar, which captures large entries (coded with darker colors), and the residual D^{−1}A_res, where black cells represent entries captured by D^{−1}A_spar. Blank columns in the matrix on the right side represent columns of D^{−1}A_res not sampled by the AMM sampling matrix Π_res, and the union of colored and blank columns together represents the whole residual matrix D^{−1}A_res.

Algorithm 3 Practical Improvement of KDEformer
1: input: matrices Q, K, V ∈ R^{n×d}, integer m, ε > 0, and LSH function H : R^d → [B]
2: compute α, β, γ as per lines 2–4 of Algorithm 1
3: p_j ← β_j − Σ_{i=1}^n α_i^{−2} e^{2⟨q_i, k_j⟩/√d} 1{H(q_i) = H(k_j)} + γ·∥v_j∥²₂ for every j ∈ [n], then normalize p_ℓ ← p_ℓ / Σ_{j∈[n]} p_j for every ℓ ∈ [n]
4: generate the sampling matrix Π_res as per lines 6–7 of Algorithm 1 using the distribution {p_j}_{j∈[n]} computed above
5: return D̃ = diag(α) and Π_res
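The decomposition of Section 3.4 can be prototyped as follows. This sketch is illustrative only: it uses a plain random-hyperplane hash as a stand-in for the equal-sized LSH of Appendix A, forms the n × n matrices explicitly for clarity, and uses a simplified surrogate for the sampling probabilities of Algorithm 3.

import numpy as np

def kdeformer_style_estimate(Q, K, V, m, n_bits=4, seed=0):
    """Sketch of Eq. (11): Att = D^{-1} A_spar V + D^{-1} A_res V, with the sparse part
    taken over hash collisions and the residual estimated by column sampling (AMM).
    A real implementation only touches colliding pairs and the m sampled columns."""
    rng = np.random.default_rng(seed)
    n, d = Q.shape
    W = rng.standard_normal((d, n_bits))
    bucket_q = ((Q @ W) > 0).astype(int) @ (2 ** np.arange(n_bits))   # sign-pattern hash of queries
    bucket_k = ((K @ W) > 0).astype(int) @ (2 ** np.arange(n_bits))   # sign-pattern hash of keys
    A = np.exp(Q @ K.T / np.sqrt(d))
    D_inv = 1.0 / A.sum(axis=1)
    A_spar = A * (bucket_q[:, None] == bucket_k[None, :])             # Eq. (10)
    A_res = A - A_spar
    # simplified surrogate for the probabilities of Algorithm 3 (no gamma balancing)
    p = np.sum((D_inv[:, None] * A_res) ** 2, axis=0) + np.sum(V ** 2, axis=1)
    p = p / p.sum()
    idx = rng.choice(n, size=m, p=p)
    s = 1.0 / (m * p[idx])
    res_part = (D_inv[:, None] * A_res[:, idx]) @ (V[idx] * s[:, None])
    return D_inv[:, None] * (A_spar @ V) + res_part

rng = np.random.default_rng(4)
n, d = 512, 32
Q, K, V = rng.standard_normal((3, n, d)) / d ** 0.25
est = kdeformer_style_estimate(Q, K, V, m=128)
A = np.exp(Q @ K.T / np.sqrt(d))
exact = (A / A.sum(axis=1, keepdims=True)) @ V
print(np.linalg.norm(est - exact, 2) / np.linalg.norm(V, 2))  # error relative to ||V||_op, cf. Eq. (1)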
4. Experiments

4.1. Single Self-attention Layer Approximation

We first benchmark our algorithm on approximating a single self-attention layer, i.e., Att(Q, K, V). We randomly select a pair of matrices Q, V ∈ R^{n×d} from the GloVe word embeddings (Pennington et al., 2014) with sequence length n = 8,192 and dimension d = 100, and set K = Q. We compare our KDEformer to other attention approximations, including Reformer (Kitaev et al., 2020), Performer (Choromanski et al., 2021), and ScatterBrain (Chen et al., 2021b). We compute the relative error with respect to the operator norm, i.e., ∥Att(Q, K, V) − Ãtt∥_op / ∥Att(Q, K, V)∥_op, where Ãtt ∈ R^{n×d} is an approximate attention. Additionally, we measure the peak memory usage, FLOP count, and CPU-clock time of the approximation methods while varying the hyperparameters of the algorithms that affect runtime and memory.

Figure 4. Performance evaluation of various self-attention approximations under the GloVe word embeddings (panels: relative spectral error vs. GFLOPS, relative spectral error vs. memory usage (GiB), CPU-clock time (sec) vs. feature dimension, and memory usage (GiB) vs. sequence length n; methods: Reformer, Performer, ScatterBrain, KDEformer, Exact).

In Fig. 4, we observe that our proposed algorithm achieves the lowest error with minimal FLOP count and memory usage. In particular, our approximation error can be about 9% with a 3.06× memory reduction and a 5.11× FLOPS reduction. In addition, we plot CPU-clock time for various choices of the hyperparameters that determine peak memory usage. Specifically, if an approximation method requires at most nk memory for computing Ãtt, we refer to k as the feature dimension. Having equal feature dimensions, our algorithm and Performer are the fastest methods, but Performer has significantly larger errors than the others. We fix the feature dimension k = 128 and measure the peak memory usage while the sequence length n changes from 256 to 16,384. For n = 16,384, our method can save up to 19.62× memory space compared to the exact computation.

4.2. Image Generation with Pre-trained BigGAN

We next apply the above-mentioned attention approximations to the pre-trained BigGAN model (Brock et al., 2019) to generate synthetic images. The model contains a single attention layer where the corresponding inputs have different dimensions: Q ∈ R^{4,096×64}, K ∈ R^{1,024×64}, and V ∈ R^{1,024×256}. Following the experiments in (Chen et al., 2021b), we use the BigGAN pre-trained on ImageNet at 512 × 512 resolution and replace the attention layer with its approximations. Note that we do not perform training or fine-tuning at all. We generate 5,000 fake images and compute the Fréchet Inception Distance (FID), using the ImageNet validation set as ground truth, as well as the Inception Score (IS) (Salimans et al., 2016). Lower FID and higher IS values imply better generation quality. We also measure the number of FLOPS in the attention layer. We set the hyperparameters (i.e., feature dimensions) so that all approximation methods have the same peak memory usage. The results are reported in Table 1. Interestingly, our algorithm shows a lower FID value than the exact attention with 4.14× fewer FLOPs. Although Performer is the fastest algorithm, its generated images look unnatural, while our attention can generate more realistic images. A number of generated images by various methods can be found in Appendix C.1.

Table 1. Results on image generation using BigGAN with the exact attention and its approximations. Bold values (in the original) indicate the best within the standard deviation.

Method        FID (↓)   IS (↑)          GFLOPS
Exact         32.17     58.38 ± 4.23    10.738  (−)
Reformer      72.39     19.04 ± 2.32    10.872  (0.99×)
Performer     33.39     37.32 ± 2.91     1.682  (6.38×)
ScatterBrain  38.55     36.43 ± 3.34     2.891  (3.71×)
KDEformer     31.41     58.16 ± 4.04     2.596  (4.14×)

4.3. ImageNet Classification with Vision Transformer

Next, we evaluate the attention approximations in the context of image classification with the pre-trained Tokens-to-Token Vision Transformer (T2T-ViT) (Yuan et al., 2021). The model consists of a Tokens-to-Token (T2T) module and the Vision Transformer (ViT) backbone. The T2T module contains 2 attention layers with sequence lengths 3,136 and 784, respectively, and the ViT backbone consists of 24 attention layers, all of which have token length 197 (in addition, the embedding dimension of the attention layers in the ViT is 512, larger than the sequence length). Hence, the two attention layers in the T2T module dominate the overall computational and memory complexities because their sequence lengths are significantly larger. So, we exactly compute the pre-trained model with 24 layers in the ViT backbone and only apply the approximation methods to the 2 attention layers in the T2T module as a drop-in replacement. The dimensions of Q, K, V are the same within each layer: n = 3,136, d = 64 in the first layer and n = 784, d = 64 in the second layer.

We compute top-1 accuracy on the ImageNet validation dataset and measure FLOPS in the first attention layer, which requires the most resources. The results are shown in Table 2. Observe that our method is the best among all approximate methods with 82.08% test accuracy. In particular, it leads to less than 0.5% accuracy drop compared to the exact computation, while the required operations are 18.3× fewer. Such performance gains would increase when token sequence lengths are larger.

Table 2. Results on ImageNet classification using T2T-ViT with the exact attention and its approximations.

Method        Top-1 Accuracy (%)   GFLOPS
Exact         82.55                161.10  (−)
Reformer      81.44                 11.71  (13.75×)
Performer     80.50                  5.06  (31.87×)
ScatterBrain  81.95                  7.18  (22.43×)
KDEformer     82.08                  8.80  (18.30×)

4.4. End-to-end Training with the Long Range Arena Benchmark

Finally, to demonstrate the power of our method in reducing the training time of transformer models, we run end-to-end training on the Long Range Arena benchmark (Tay et al., 2021), which contains 5 classification datasets: ListOps, Text, Image, Retrieval, and Pathfinder. The maximum sequence lengths of these datasets are 2,048, 4,096, 1,024, 4,096, and 1,024, respectively. We follow the same settings as (Chen et al., 2021c); the model is a 2-layer transformer with 64 embedding dimensions, 128 hidden dimensions, 2 attention heads, and mean pooling for the classification task. The learning rate is set to 10^{-4} for Text, ListOps, and Image, and 2 × 10^{-4} for the rest. All models are trained for 50,000 steps. Similar to Section 4.1, we choose the hyperparameters of all methods so that they have equal feature dimension 128.

In Table 3, we provide results on (a) test accuracy, (b) peak memory, and (c) wall-clock time per batch of a single training step (including forward and backward propagation). We observe that the proposed KDEformer achieves the second-best average test accuracy, just behind Reformer, while requiring much less memory and a faster wall-clock time than the other competitors. For example, KDEformer runs about 8× faster than the exact attention on the Text dataset.

Table 3. Results of end-to-end training on 5 Long Range Arena (LRA) benchmark datasets.

(a) Test accuracy (%)
Method      ListOps  Text   Image  Retrieval  Pathfinder  Average
Exact       33.32    60.22  37.41  81.07      70.25       56.45
Reformer    36.74    61.39  43.59  78.15      66.25       57.22
Performer   37.75    58.81  35.74  80.39      62.84       55.11
KDEformer   36.64    62.00  45.45  73.52      68.13       57.15

(b) Peak memory (GB)
Method      ListOps  Text   Image  Retrieval  Pathfinder  Average
Exact       6.53     16.71  9.41   8.72       4.70        9.21
Reformer    1.59     3.18   6.36   2.94       3.18        3.45
Performer   1.07     2.13   4.28   2.15       2.14        2.35
KDEformer   1.02     2.03   4.08   2.38       1.87        2.28

(c) Wall-clock time (sec) per batch
Method      ListOps  Text   Image  Retrieval  Pathfinder  Average
Exact       0.133    0.479  0.276  0.478      0.141       0.301
Reformer    0.041    0.081  0.155  0.092      0.082       0.090
Performer   0.036    0.067  0.127  0.074      0.068       0.074
KDEformer   0.034    0.058  0.110  0.073      0.063       0.068
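For reference, the relative error with respect to the operator norm reported throughout this section can be computed as follows; this is a minimal sketch of ours (the function name is not from the paper), shown alongside the Frobenius-norm error for contrast with the entry-wise guarantees of prior work.

import numpy as np

def relative_errors(att_exact, att_approx):
    """Relative error in the operator norm (the metric of Section 4.1) and, for
    contrast, in the Frobenius norm; both inputs are n x d attention outputs."""
    spec = np.linalg.norm(att_exact - att_approx, 2) / np.linalg.norm(att_exact, 2)
    frob = np.linalg.norm(att_exact - att_approx) / np.linalg.norm(att_exact)
    return spec, frob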
5. Conclusion

We propose a fast attention approximation based on recent advances in KDE solvers. The proposed algorithm runs in strongly sub-quadratic time in the sequence length and provides an error bound under the spectral norm. It shows promising performance in various practical applications involving long-sequence attention. We believe this can have a significant impact on other practical problems as well.

Acknowledgements

Insu Han and Amin Karbasi acknowledge funding in direct support of this work from NSF (IIS-1845032), ONR (N00014-19-1-2406), and the AI Institute for Learning-Enabled Optimization at Scale (TILOS).

References

Backurs, A., Charikar, M., Indyk, P., and Siminelakis, P. Efficient density evaluation for smooth kernels. In Foundations of Computer Science (FOCS), 2018.
Backurs, A., Indyk, P., and Wagner, T. Space and time efficient kernel density estimation in high dimensions. In Neural Information Processing Systems (NeurIPS), 2019.
Brock, A., Donahue, J., and Simonyan, K. Large scale GAN training for high fidelity natural image synthesis. In International Conference on Learning Representations (ICLR), 2019.
Brown, T., Mann, B., Ryder, N., Subbiah, M., Kaplan, J. D., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., et al. Language models are few-shot learners. In Neural Information Processing Systems (NeurIPS), 2020.
Carion, N., Massa, F., Synnaeve, G., Usunier, N., Kirillov, A., and Zagoruyko, S. End-to-end object detection with transformers. In European Conference on Computer Vision (ECCV), 2020.
Charikar, M. and Siminelakis, P. Hashing-based-estimators for kernel density in high dimensions. In Foundations of Computer Science (FOCS), 2017.
Charikar, M., Kapralov, M., Nouri, N., and Siminelakis, P. Kernel density estimation through density constrained near neighbor search. In Foundations of Computer Science (FOCS), 2020.
Chen, B., Liu, Z., Peng, B., Xu, Z., Li, J. L., Dao, T., Song, Z., Shrivastava, A., and Re, C. MONGOOSE: A learnable LSH framework for efficient neural network training. In International Conference on Learning Representations (ICLR), 2020.
Chen, B., Dao, T., Liang, K., Yang, J., Song, Z., Rudra, A., and Re, C. Pixelated Butterfly: Simple and efficient sparse training for neural network models. In International Conference on Learning Representations (ICLR), 2021a.
Chen, B., Dao, T., Winsor, E., Song, Z., Rudra, A., and Re, C. Scatterbrain: Unifying sparse and low-rank attention. In Neural Information Processing Systems (NeurIPS), 2021b.
Chen, Y., Zeng, Q., Ji, H., and Yang, Y. Skyformer: Remodel self-attention with Gaussian kernel and Nyström method. In Neural Information Processing Systems (NeurIPS), 2021c.
Choromanski, K. M., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J. Q., Mohiuddin, A., Kaiser, L., et al. Rethinking attention with Performers. In International Conference on Learning Representations (ICLR), 2021.
Daras, G., Kitaev, N., Odena, A., and Dimakis, A. G. SMYRF: Efficient attention using asymmetric clustering. In Neural Information Processing Systems (NeurIPS), 2020.
Dass, J., Wu, S., Shi, H., Li, C., Ye, Z., Wang, Z., and Lin, Y. ViTALiTy: Unifying low-rank and sparse approximation for vision transformer acceleration with a linear Taylor attention. arXiv preprint arXiv:2211.05109, 2022.
Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. In Conference of the North American Association for Computational Linguistics (NAACL), 2018.
Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., and Houlsby, N. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations (ICLR), 2021.
Hagerup, T., Mehlhorn, K., and Munro, J. I. Maintaining discrete probability distributions optimally. In International Colloquium on Automata, Languages, and Programming (ICALP), 1993.
Joshi, S., Kommaraji, R. V., Phillips, J. M., and Venkatasubramanian, S. Comparing distributions and shapes using the kernel distance. In Symposium on Computational Geometry (SoCG), 2011.
Katharopoulos, A., Vyas, A., Pappas, N., and Fleuret, F. Transformers are RNNs: Fast autoregressive transformers with linear attention. In International Conference on Machine Learning (ICML), 2020.
Kitaev, N., Kaiser, L., and Levskaya, A. Reformer: The efficient Transformer. In International Conference on Learning Representations (ICLR), 2020.
Pennington, J., Socher, R., and Manning, C. D. GloVe: Global vectors for word representation. In Empirical Methods in Natural Language Processing (EMNLP), 2014.
Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W., and Liu, P. J. Exploring the limits of transfer learning with a unified text-to-text transformer. Journal of Machine Learning Research (JMLR), 2020.
Roy, A., Saffar, M., Vaswani, A., and Grangier, D. Efficient content-based sparse attention with routing transformers. Transactions of the Association for Computational Linguistics, 2021.
Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., and Chen, X. Improved techniques for training GANs. In Neural Information Processing Systems (NeurIPS), 2016.
Schölkopf, B., Smola, A. J., Bach, F., et al. Learning with Kernels: Support Vector Machines, Regularization, Optimization, and Beyond. MIT Press, 2002.
Siminelakis, P., Rong, K., Bailis, P., Charikar, M., and Levis, P. Rehashing kernel evaluation in high dimensions. In International Conference on Machine Learning (ICML), 2019.
Sun, Z., Yang, Y., and Yoo, S. Sparse attention with learning to hash. In International Conference on Learning Representations (ICLR), 2021.
Tay, Y., Dehghani, M., Abnar, S., Shen, Y., Bahri, D., Pham, P., Rao, J., Yang, L., Ruder, S., and Metzler, D. Long Range Arena: A benchmark for efficient transformers. In International Conference on Learning Representations (ICLR), 2021.
Tay, Y., Dehghani, M., Bahri, D., and Metzler, D. Efficient transformers: A survey. ACM Computing Surveys, 2022.
Tropp, J. A. An introduction to matrix concentration inequalities. Foundations and Trends in Machine Learning, 2015.
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., and Polosukhin, I. Attention is all you need. In Neural Information Processing Systems (NeurIPS), 2017.
Yang, Z., Dai, Z., Yang, Y., Carbonell, J., Salakhutdinov, R. R., and Le, Q. V. XLNet: Generalized autoregressive pretraining for language understanding. In Neural Information Processing Systems (NeurIPS), 2019.
Yuan, L., Chen, Y., Wang, T., Yu, W., Shi, Y., Jiang, Z.-H., Tay, F. E., Feng, J., and Yan, S. Tokens-to-Token ViT: Training vision transformers from scratch on ImageNet. In International Conference on Computer Vision (ICCV), 2021.
Zaheer, M., Guruganesh, G., Dubey, K. A., Ainslie, J., Alberti, C., Ontanon, S., Pham, P., Ravula, A., Wang, Q., Yang, L., et al. Big Bird: Transformers for longer sequences. In Neural Information Processing Systems (NeurIPS), 2020.
Zhou, H., Zhang, S., Peng, J., Zhang, S., Li, J., Xiong, H., and Zhang, W. Informer: Beyond efficient transformer for long sequence time-series forecasting. In Conference on Artificial Intelligence (AAAI), 2021.
Zouzias, A. Randomized Primitives for Linear Algebra and Applications. University of Toronto, 2013.

A. Practical Angular LSH with Fixed Bucket Sizes

The practical version of our algorithm presented in Section 3.4 requires a locality sensitive hashing H : R^d → [B] for identifying the dominant entries of the attention matrix A, which correspond to pairs of keys and queries whose "angular distances" are small. In this section, we develop a simple yet effective and practical LSH function whose collision probability is related to the angular distance between the hashed points. While the LSH allows computing a very sparse approximation to the attention matrix, uneven bucket sizes hinder batching of the computations across LSH buckets. In fact, if we parallelize the computation across buckets, the largest bucket determines the runtime (Kitaev et al., 2020). Our proposed LSH function has equal-sized buckets; thus, it aligns with modern hardware's block-memory access and can be efficiently parallelized by batching across buckets.

We start by defining a simple LSH function whose collision probability is roughly proportional to the angle between the hashed points.

Definition A.1 (Angular LSH). For positive integers d, r, let w_1, w_2, . . . w_r be i.i.d. random samples from the standard Gaussian distribution N(0, I_d). We define the rank-r angular LSH h : R^d → {0, 1}^r as follows:

h(x) := ( 1{w_1^⊤x > 0}, 1{w_2^⊤x > 0}, . . . 1{w_r^⊤x > 0} )   for any x ∈ R^d.

Note that the buckets are labeled by r-bit binary numbers and, if r ≤ d, then almost surely the total number of buckets is 2^r. It is easy to calculate the collision probability of the angular LSH defined in Definition A.1.

Claim 1. For positive integers r, d, let h(·) be an instance of the rank-r angular LSH as per Definition A.1. For any x, y ∈ R^d, the collision probability of h(x) and h(y) is

Pr[h(x) = h(y)] = (1 − θ_{x,y}/π)^r,

where θ_{x,y} = cos^{-1}( x^⊤y / (∥x∥·∥y∥) ) denotes the angle between x and y.

Therefore, points with small angular distances are likely to be hashed to the same bucket, while points with large angular distances are unlikely to be. So, if we hash keys k_j and queries q_i using the angular LSH from Definition A.1, then the entries of the attention matrix A which correspond to colliding pairs of keys and queries will likely have very large values.

As we mentioned earlier, the main efficiency bottleneck in this LSH-based approach for computing the dominant entries of the attention matrix is the unevenness of the hash bucket sizes. If we try to compute the sparse approximation to A, as defined in Eq. (10), using the LSH function from Definition A.1 by parallelizing the computation across buckets, the runtime will be dominated by the time to compute the entries in the largest bucket. One solution for increasing efficiency, proposed in (Kitaev et al., 2020), is to truncate the LSH buckets and force them to contain equal numbers of keys and queries. However, truncation can degrade the quality of the approximation drastically, because there will be spillover from one bucket to another and some points can be forced into far-away buckets.
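A minimal sketch of the rank-r angular LSH of Definition A.1, together with an empirical check of the collision probability in Claim 1; the helper names are ours and the parameters are arbitrary.

import numpy as np

def angular_lsh(points, r, rng):
    """Rank-r angular LSH of Definition A.1: r random-hyperplane sign bits per point."""
    d = points.shape[1]
    W = rng.standard_normal((d, r))              # hyperplane normals w_1, ..., w_r
    return (points @ W > 0).astype(np.uint8)     # each row is an r-bit bucket label

rng = np.random.default_rng(7)
r, d, trials = 3, 16, 20000
x, y = rng.standard_normal(d), rng.standard_normal(d)
theta = np.arccos(x @ y / (np.linalg.norm(x) * np.linalg.norm(y)))
hits = 0
for _ in range(trials):
    hx, hy = angular_lsh(np.vstack([x, y]), r, rng)
    hits += np.array_equal(hx, hy)
print(hits / trials, (1 - theta / np.pi) ** r)   # empirical vs. Claim 1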
The reason for this spillover effect is that consecutive buckets in a hash table do not necessarily represent areas of R^d which are geometrically close to each other. We show that it is in fact possible to sort the buckets of the angular LSH from Definition A.1 such that the order of the buckets reflects their geometric position; thus, consecutive buckets actually represent neighboring partitions of R^d. It turns out that the geometric distance between two buckets of this LSH function translates into the Hamming distance between their binary labels. To be precise, for any binary numbers b_1, b_2 ∈ {0, 1}^r, let d_H(b_1, b_2) ∈ {0, 1, . . . r} represent the Hamming distance between the two, i.e., the number of bits where b_1 and b_2 differ. Now note that the LSH buckets in Definition A.1 are labeled with r-bit binary numbers. Each bit in the binary representation of a bucket corresponds to a partitioning of R^d into the two sides of a random hyperplane whose normal vector is sampled from a standard Gaussian. Therefore, if two buckets b_1 and b_2 have Hamming distance d_H(b_1, b_2) = 1, then these buckets are positioned on the same sides of all random hyperplanes except for one; thus, they represent neighboring regions in R^d, and the hyperplane corresponding to the differing bit of b_1 and b_2 is the boundary between the two regions.

Figure 5. Rank-2 angular LSH in action (in dimension d = 2); panels: (a) space partitions by angular LSH, (b) hashing an example dataset, (c) bucket truncation in Hamming distance order. The space partitions corresponding to buckets with unit Hamming distance are neighbors in R^d. In Fig. 5(b) we hash an example dataset and obtain uneven buckets. Fig. 5(c) shows that if we order the dataset according to the Hamming distance ordering of their buckets and then truncate the buckets, we get new equal-sized buckets with minimal spillover effect.

We show this fact in Fig. 5(a), which illustrates the space partitions corresponding to the buckets of a rank-2 angular LSH in dimension d = 2. It is clearly visible that the bucket labels of neighboring partitions have unit Hamming distance. In Fig. 5(b) we hash an example dataset using this LSH function and, as can be seen, the buckets have uneven sizes. Because of the relationship between the Hamming distance of bucket labels and the distance between space partitions, if we order the dataset according to the Hamming ordering of their buckets and then truncate them, we get new buckets with even sizes and a minimal spillover effect. In particular, in Fig. 5(c) we order the dataset such that the points from buckets 00, 01, 11, 10 come in this specific order and then we bin the data points by partitioning the ordered dataset into equal-sized parts. The resulting bins show no spillover effect. In the following lemma we show how to order the r-bit binary numbers {0, 1}^r such that all consecutive numbers have unit Hamming distance.

Lemma A.2 (Ordering of binary numbers according to their Hamming distance). For any positive integer r, it is possible to order the set of binary numbers {0, 1}^r as a sequence b_1, b_2, . . . b_{2^r} such that for any j ∈ [2^r − 1]: d_H(b_j, b_{j+1}) = 1.

Proof. The proof is by induction. For r = 1 the base of the induction follows trivially. Now suppose that we have the sequence of (r − 1)-bit numbers b′_1, b′_2, . . . b′_{2^{r−1}} such that d_H(b′_j, b′_{j+1}) = 1 for any j ∈ [2^{r−1} − 1]. Then the sequence of r-bit numbers is defined as follows:

b_j := (b′_j, 0) if j ≤ 2^{r−1},   and   b_j := (b′_{2^r + 1 − j}, 1) if j > 2^{r−1},   for j ∈ [2^r].

One can verify that this sequence satisfies the desired property, which completes the proof.
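The recursive construction in the proof of Lemma A.2 is exactly a reflected Gray code. Here is a short sketch of ours that produces one valid such ordering (the bit convention may differ from the 00, 01, 11, 10 order shown in Fig. 5).

def hamming_ordering(r):
    """Ordering of {0,1}^r from Lemma A.2 (reflected construction): consecutive
    labels differ in exactly one bit. Returns a list of r-character bit strings."""
    if r == 1:
        return ["0", "1"]
    prev = hamming_ordering(r - 1)
    # first half: append 0; second half: reversed previous sequence with 1 appended
    return [b + "0" for b in prev] + [b + "1" for b in reversed(prev)]

order = hamming_ordering(3)
print(order)
# every consecutive pair differs in exactly one bit
assert all(sum(a != b for a, b in zip(s, t)) == 1 for s, t in zip(order, order[1:]))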
Therefore, we can use the angular LSH together with the ordering of binary numbers from Lemma A.2 to construct an effective hash function with equal-sized buckets.

Definition A.3 (Equal-sized LSH with Minimal Spillover). Suppose that we want to hash a dataset x_1, x_2, . . . x_n ∈ R^d.

1. Hash these points using a rank-r angular LSH h(·) as per Definition A.1.
2. Then, using Lemma A.2, produce an ordering of the r-bit binary numbers such that consecutive numbers have unit Hamming distance; let b_1, b_2, . . . b_{2^r} be this ordering.
3. Next, define a permutation P ∈ Sym(n) which orders the dataset according to the Hamming ordering of their buckets. More specifically, P satisfies: P(i) < P(j) iff h(x_i) ≤* h(x_j), where the inequality ≤* acts with respect to the ordering b_1, b_2, . . . b_{2^r}.
4. Permute x_1, x_2, . . . x_n according to P and then partition the sequence into equal-sized chunks. These chunks are the buckets.

Now we explain how we can use the LSH procedure given in Definition A.3 to compute A_spar as per Eq. (10), through the example shown in Fig. 6. We first hash the keys k_j and the queries q_i via the angular LSH; we represent the buckets of this hashing via different shades of violet in Fig. 6. Clearly, the bucket sizes are uneven. Then we permute the keys and queries via P, which orders the points such that their buckets are sorted according to the ordering b_1, b_2, b_3, b_4 obtained from Lemma A.2. Then we truncate the sorted points, which is in fact equivalent to selecting blocks along the diagonal of the permuted attention matrix; the selected diagonal blocks in Fig. 6 illustrate this. Finally, we reverse the permutation on the rows and columns of the block-diagonal attention, which gives us the final A_spar.

Figure 6. An example of how A_spar can be computed efficiently. (Left) keys and queries are hashed using the angular LSH function; buckets are represented by shades of violet. (Middle) keys and queries are permuted such that their buckets are sorted according to the Hamming distance ordering; large entries of the permuted attention matrix A_P are concentrated around the diagonal blocks, so we compute the diagonal blocks. (Right) the block-diagonal approximation to A_P is reverse-permuted to obtain A_spar.
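A compact sketch of Definition A.3 (ours, with hypothetical helper names): hash with the angular LSH, sort the points by the Gray-code rank of their labels (which realizes the Hamming-distance ordering of Lemma A.2), and split the sorted sequence into equal-sized chunks.

import numpy as np

def gray_rank(bits):
    """Rank of an r-bit label in the reflected Gray-code ordering of Lemma A.2
    (e.g. 00, 01, 11, 10 for r = 2), obtained by Gray-to-binary decoding."""
    n = 0
    for g in bits:
        n = (n << 1) | ((n & 1) ^ int(g))
    return n

def equal_sized_buckets(points, r, num_chunks, rng):
    """Sketch of Definition A.3: angular LSH, sort by Gray-code rank of the labels,
    then partition the sorted sequence into equal-sized chunks of point indices."""
    W = rng.standard_normal((points.shape[1], r))
    labels = (points @ W > 0).astype(np.uint8)                            # Definition A.1 labels
    perm = np.argsort([gray_rank(row) for row in labels], kind="stable")  # the permutation P
    return np.array_split(perm, num_chunks)                               # equal-sized buckets

rng = np.random.default_rng(11)
pts = rng.standard_normal((64, 8))
print([len(c) for c in equal_sized_buckets(pts, r=4, num_chunks=8, rng=rng)])  # eight buckets of 8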
B. Omitted Proofs

B.1. Proof of Lemma 3.2: Approximate Matrix Multiplication via Sampling

In this section, we analyze the random sampling method for approximately computing the product of two rectangular matrices, presented in Lemma 3.2. The proof of this lemma is based on the following version of the matrix Bernstein inequality.

Lemma B.1 (Matrix Approximation by Random Sampling, Corollary 6.2.1 from Tropp (2015)). Let $B$ be a fixed $q \times d$ matrix. Construct a $q \times d$ random matrix $R$ that satisfies $\mathbb{E}[R] = B$ and $\|R\|_{op} \le L$. Compute the per-sample second moment $m_2(R) := \max\{\|\mathbb{E}[R^* R]\|_{op}, \|\mathbb{E}[R R^*]\|_{op}\}$, and form the matrix sampling estimator
$$\bar{R}_m = \frac{1}{m} \sum_{i=1}^m R_i,$$
where each $R_i$ is an independent copy of $R$. Then for every $t > 0$, the estimator satisfies
$$\Pr\left[\left\|\bar{R}_m - B\right\|_{op} \ge t\right] \le (q + d) \cdot \exp\left(\frac{-m t^2 / 2}{m_2(R) + 2Lt/3}\right).$$

Now we prove Lemma 3.2 by invoking the above matrix Bernstein inequality.

Lemma 3.2 (Approximate Matrix Multiplication (AMM)). Consider any matrices $X \in \mathbb{R}^{n \times q}$, $Y \in \mathbb{R}^{n \times d}$ and any probability distribution $\{p_i\}_{i \in [n]}$ which satisfies
$$p_i \ge \frac{1}{4} \cdot \frac{\|x_i\|_2^2 + \gamma \cdot \|y_i\|_2^2}{\|X\|_F^2 + \gamma \cdot \|Y\|_F^2} \quad \text{for all } i \in [n], \qquad \text{where } \gamma = \|X\|_{op}^2 / \|Y\|_{op}^2.$$
Let $\Pi \in \mathbb{R}^{m \times n}$ be the sampling matrix constructed by first generating $m$ i.i.d. samples $\ell_1, \ell_2, \ldots, \ell_m \in [n]$ according to $\{p_\ell\}_{\ell \in [n]}$ and then letting the $r$-th row of $\Pi$ be $\frac{1}{\sqrt{m \cdot p_{\ell_r}}} \cdot e_{\ell_r}^\top$. If $m = \Omega\left(\varepsilon^{-2} \log n \cdot (\mathrm{srank}(X) + \mathrm{srank}(Y))\right)$ for some $\varepsilon > 0$, then the following holds:
$$\Pr\left[\left\|X^\top \Pi^\top \Pi Y - X^\top Y\right\|_{op} > \varepsilon \|X\|_{op} \|Y\|_{op}\right] \le \frac{1}{\mathrm{poly}(n)}.$$

Proof. First, let $B := X^\top Y$ and let the random matrix $R$ have the following distribution:
$$\Pr\left[R = \frac{x_i^\top \cdot y_i}{p_i}\right] = p_i \quad \text{for } i \in [n],$$
where $x_i$ and $y_i$ are the $i$-th rows of $X$ and $Y$, respectively. With this definition we have
$$\mathbb{E}[R] = \sum_{i \in [n]} \frac{x_i^\top \cdot y_i}{p_i} \cdot p_i = \sum_{i \in [n]} x_i^\top \cdot y_i = X^\top Y = B.$$
Furthermore, we can bound the operator norm of $R$ as follows:
$$\begin{aligned}
\|R\|_{op} &\le \max_{i \in [n]} \frac{\left\|x_i^\top \cdot y_i\right\|_{op}}{p_i} = \max_{i \in [n]} \frac{\|x_i\|_2 \|y_i\|_2}{p_i} \\
&\le 4 \cdot \max_{i \in [n]} \frac{\|x_i\|_2 \|y_i\|_2}{\|x_i\|_2^2 + \gamma \cdot \|y_i\|_2^2} \cdot \left(\|X\|_F^2 + \gamma \cdot \|Y\|_F^2\right) \\
&\le \frac{2}{\sqrt{\gamma}} \cdot \left(\|X\|_F^2 + \gamma \cdot \|Y\|_F^2\right) \\
&= 2\,\|X\|_{op} \|Y\|_{op} \cdot \left(\mathrm{srank}(X) + \mathrm{srank}(Y)\right) \equiv L,
\end{aligned}$$
where the second line follows from the precondition of Lemma 3.2 on the distribution $\{p_i\}_{i \in [n]}$, the third line follows from the AM–GM inequality, and the last line follows from the definition of $\gamma$ and the definition of the stable rank.

Next, we compute the per-sample second moment as follows:
$$\begin{aligned}
\mathbb{E}[R^* R] &= \sum_{i \in [n]} \|x_i\|_2^2 \cdot \frac{y_i^\top \cdot y_i}{p_i^2} \cdot p_i = \sum_{i \in [n]} \|x_i\|_2^2 \cdot \frac{y_i^\top \cdot y_i}{p_i} \\
&\preceq 4 \cdot \left(\|X\|_F^2 + \gamma \cdot \|Y\|_F^2\right) \cdot \sum_{i \in [n]} \frac{\|x_i\|_2^2}{\|x_i\|_2^2 + \gamma \cdot \|y_i\|_2^2} \cdot y_i^\top y_i \\
&\preceq 4 \cdot \left(\|X\|_F^2 + \gamma \cdot \|Y\|_F^2\right) \cdot \sum_{i \in [n]} y_i^\top y_i = 4 \cdot \left(\|X\|_F^2 + \gamma \cdot \|Y\|_F^2\right) \cdot Y^\top Y.
\end{aligned}$$
Similarly,
$$\mathbb{E}[R R^*] \preceq 4 \cdot \left(\|X\|_F^2 / \gamma + \|Y\|_F^2\right) \cdot X^\top X.$$
In summary,
$$\begin{aligned}
m_2(R) &= \max\left\{\|\mathbb{E}[R^* R]\|_{op},\ \|\mathbb{E}[R R^*]\|_{op}\right\} \\
&\le 4 \cdot \max\left\{\left(\|X\|_F^2 + \gamma \|Y\|_F^2\right) \cdot \left\|Y^\top Y\right\|_{op},\ \left(\|X\|_F^2 / \gamma + \|Y\|_F^2\right) \cdot \left\|X^\top X\right\|_{op}\right\} \\
&= 4 \cdot \|X\|_{op}^2 \|Y\|_{op}^2 \cdot \left(\mathrm{srank}(X) + \mathrm{srank}(Y)\right).
\end{aligned}$$

Finally, note that from the way the sampling matrix $\Pi$ was constructed, we have
$$X^\top \Pi^\top \Pi Y = \frac{1}{m} \sum_{r \in [m]} \frac{x_{\ell_r}^\top \cdot y_{\ell_r}}{p_{\ell_r}} = \bar{R}_m.$$
Thus, by invoking Lemma B.1 with $t = \varepsilon \cdot \|X\|_{op} \|Y\|_{op}$ we find that
$$\Pr\left[\left\|\bar{R}_m - B\right\|_{op} \ge \varepsilon \cdot \|X\|_{op} \|Y\|_{op}\right] \le (q + d) \cdot \exp\left(\frac{-m t^2 / 2}{m_2(R) + 2Lt/3}\right) \le \frac{1}{\mathrm{poly}(n)}.$$
This completes the proof of Lemma 3.2.
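As a sanity check on Lemma 3.2, the following NumPy sketch samples rows with probabilities exactly proportional to $\|x_i\|_2^2 + \gamma \|y_i\|_2^2$ (so the precondition on $\{p_i\}_{i\in[n]}$ holds with room to spare) and forms the estimator $X^\top \Pi^\top \Pi Y$. The function name amm_sample and the specific $n$, $d$, $m$ in the demo are our own illustrative choices, not values from the paper.

```python
# Illustrative NumPy sketch of the AMM sampler analyzed in Lemma 3.2 (not the
# paper's code): sample m indices with probabilities proportional to
# ||x_i||^2 + gamma * ||y_i||^2, rescale, and average the rank-one products.
import numpy as np


def amm_sample(X, Y, m, seed=0):
    """Return the estimator X^T Pi^T Pi Y of X^T Y from Lemma 3.2."""
    rng = np.random.default_rng(seed)
    gamma = np.linalg.norm(X, 2) ** 2 / np.linalg.norm(Y, 2) ** 2   # ratio of squared spectral norms
    scores = np.sum(X**2, axis=1) + gamma * np.sum(Y**2, axis=1)
    p = scores / scores.sum()                      # exact proportional sampling probabilities
    idx = rng.choice(len(p), size=m, p=p)          # i.i.d. samples l_1, ..., l_m
    scale = 1.0 / np.sqrt(m * p[idx])              # r-th row of Pi is e_{l_r}^T / sqrt(m p_{l_r})
    return (scale[:, None] * X[idx]).T @ (scale[:, None] * Y[idx])


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    X, Y = rng.standard_normal((2000, 32)), rng.standard_normal((2000, 16))
    err = np.linalg.norm(amm_sample(X, Y, m=400) - X.T @ Y, 2)
    print(err / (np.linalg.norm(X, 2) * np.linalg.norm(Y, 2)))  # relative spectral error
```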
B.2. Proof of Theorem 3.3

Theorem 3.3 (Correctness of Algorithm 1). For any matrices $Q, K, V \in \mathbb{R}^{n \times d}$, any $\varepsilon > 0$, and number of samples $m = \Omega\left(\varepsilon^{-2} \log n \cdot (\mathrm{srank}(D^{-1}A) + \mathrm{srank}(V))\right)$, given access to a primitive WExpKDE as per Definition 3.1, Algorithm 1 outputs a diagonal matrix $\widetilde{D} \in \mathbb{R}^{n \times n}$ and a sampling matrix $\Pi \in \mathbb{R}^{m \times n}$ which satisfy Eq. (1) with probability at least $1 - \frac{1}{\mathrm{poly}(n)}$.

Proof. First, note that all entries of $D^{-1}A$ are positive and each row of this matrix sums to $1$, so by the Gershgorin circle theorem $\|D^{-1}A\|_{op} \le 1$. On the other hand, $D^{-1}A \cdot \mathbf{1}_n = \mathbf{1}_n$, so in fact $\|D^{-1}A\|_{op} = 1$. We will use this fact in the rest of the proof.

Now note that Algorithm 1 computes $\alpha = \mathrm{WExpKDE}\left(\frac{Q}{d^{1/4}}, \frac{K}{d^{1/4}}, \mathbf{1}_n, \frac{\varepsilon}{3}\right)$ in line 3 and lets $\widetilde{D} = \mathrm{diag}(\alpha)$. Thus, as we showed earlier, by Definition 3.1 and using the fact that the entries of $D$ are positive, we have $(1 - \varepsilon/3) D \preceq \widetilde{D} \preceq (1 + \varepsilon/3) D$. Using this inequality together with the fact that $\|D^{-1}A\|_{op} = 1$, the diagonal matrix $\widetilde{D}$ satisfies Eq. (5).

Next, let us consider the vector $\beta = \mathrm{WExpKDE}\left(\frac{\sqrt{2} \cdot Q}{d^{1/4}}, \frac{\sqrt{2} \cdot K}{d^{1/4}}, u, \frac{1}{3}\right)$ computed in line 4. For ease of notation, let $X^\top := \widetilde{D}^{-1} A$. By Definition 3.1 and using the definition $u_i = 1/\alpha_i^2$ from line 3, we have
$$\beta_j \in (1 \pm 1/3) \cdot \sum_{i \in [n]} u_i \cdot \exp\left(\frac{2}{\sqrt{d}} \langle q_i, k_j \rangle\right) = (1 \pm 1/3) \cdot \|x_j\|_2^2 \quad \text{for any } j \in [n].$$
Also, note that $\gamma$, computed in line 2 of the algorithm, is equal to $\gamma = \|D^{-1}A\|_{op}^2 / \|V\|_{op}^2$. Because $(1 - \varepsilon/3) D \preceq \widetilde{D} \preceq (1 + \varepsilon/3) D$, we have $\gamma \in (1 \pm \varepsilon/3)^{-1} \cdot \widetilde{\gamma}$, where $\widetilde{\gamma} := \|\widetilde{D}^{-1}A\|_{op}^2 / \|V\|_{op}^2$. Therefore, the distribution $\{p_i\}_{i \in [n]}$ computed in line 5 satisfies
$$p_\ell = \frac{\beta_\ell + \gamma \cdot \|v_\ell\|_2^2}{\sum_{j \in [n]} \beta_j + \gamma \cdot \|V\|_F^2} \ge \frac{1}{4} \cdot \frac{\|x_\ell\|_2^2 + \widetilde{\gamma} \cdot \|v_\ell\|_2^2}{\|X\|_F^2 + \widetilde{\gamma} \cdot \|V\|_F^2}.$$
Furthermore, note that $\mathrm{srank}(\widetilde{D}^{-1}A) \le 2 \cdot \mathrm{srank}(D^{-1}A)$. Therefore, we can invoke the AMM result from Lemma 3.2 with matrices $X = (\widetilde{D}^{-1}A)^\top$ and $Y = V$, and use the precondition of Theorem 3.3 on the number of samples,
$$m = \Omega\left(\varepsilon^{-2} \log n \cdot (\mathrm{srank}(D^{-1}A) + \mathrm{srank}(V))\right) = \Omega\left(\varepsilon^{-2} \log n \cdot (\mathrm{srank}(\widetilde{D}^{-1}A) + \mathrm{srank}(V))\right),$$
to conclude that the sampling matrix $\Pi$ computed in lines 6–7 satisfies the following with high probability in $n$:
$$\left\|\widetilde{D}^{-1} A \Pi^\top \cdot \Pi V - \widetilde{D}^{-1} A V\right\|_{op} \le \frac{\varepsilon}{4} \left\|\widetilde{D}^{-1} A\right\|_{op} \|V\|_{op} \le \frac{\varepsilon}{2} \left\|D^{-1} A\right\|_{op} \|V\|_{op} = \frac{\varepsilon}{2} \|V\|_{op},$$
where the second inequality above follows from the fact that $\|\widetilde{D}^{-1}A\|_{op} \le 2 \cdot \|D^{-1}A\|_{op}$. The above inequality shows that Eq. (6) holds with high probability in $n$. Thus, the theorem follows by combining Eq. (5) and Eq. (6) via the triangle inequality.

B.3. Proof of Theorem 3.5

Theorem 3.5 (Approximate Attention with Spectral Norm Bound). For any matrices $Q, K, V \in \mathbb{R}^{n \times d}$, any $\varepsilon > 0$, and given a fast Gaussian KDE as per Theorem 2.1, there exists an algorithm that outputs a diagonal matrix $\widetilde{D} \in \mathbb{R}^{n \times n}$ and a sampling matrix $\Pi \in \mathbb{R}^{m \times n}$ with $m = O\left(\varepsilon^{-2} \log n \cdot (\mathrm{srank}(D^{-1}A) + \mathrm{srank}(V))\right)$ samples which satisfy Eq. (1) with probability at least $1 - \frac{1}{\mathrm{poly}(n)}$. The runtime of this algorithm is
$$O\left(m + nd \cdot \left(\mathcal{C}_{\frac{K}{d^{1/4}}, \frac{Q}{d^{1/4}}, \mathbf{1}_n, \varepsilon, \tau} + \mathcal{C}_{\frac{\sqrt{2} Q}{d^{1/4}}, \frac{\sqrt{2} K}{d^{1/4}}, v, 1, \tau}\right)\right),$$
where $v_j = \left(\sum_{\ell \in [n]} \exp\left(\frac{1}{\sqrt{d}} \langle q_j, k_\ell \rangle\right)\right)^{-2}$ for $j \in [n]$, and $\mathcal{C}_{\frac{K}{d^{1/4}}, \frac{Q}{d^{1/4}}, \mathbf{1}_n, \varepsilon, \tau}$, $\mathcal{C}_{\frac{\sqrt{2} Q}{d^{1/4}}, \frac{\sqrt{2} K}{d^{1/4}}, v, 1, \tau}$ are defined as in Eq. (9).

Proof. It suffices to run Algorithm 1 with some $m = O\left(\varepsilon^{-2} \log n \cdot (\mathrm{srank}(D^{-1}A) + \mathrm{srank}(V))\right)$ samples and invoke Algorithm 2 for the calls to WExpKDE made in lines 3–4. By Theorem 3.3 and Theorem 3.4, together with a union bound, the outputs $\Pi$ and $\widetilde{D}$ of this procedure satisfy the desired condition of Eq. (1) with probability at least $1 - \frac{1}{\mathrm{poly}(n)}$.

Runtime Analysis. By Theorem 3.4, the time to compute $\widetilde{D}$ by invoking WExpKDE (i.e., Algorithm 2) in line 3 of Algorithm 1 is $O\left(nd \cdot \mathcal{C}_{\frac{K}{d^{1/4}}, \frac{Q}{d^{1/4}}, \mathbf{1}_n, \varepsilon, \tau}\right)$. Furthermore, the time to run WExpKDE in line 4 is $O\left(nd \cdot \mathcal{C}_{\frac{\sqrt{2} Q}{d^{1/4}}, \frac{\sqrt{2} K}{d^{1/4}}, u, 1, \tau}\right)$, where $u$ is the vector computed in lines 3–4 of Algorithm 1. On the other hand, by Theorem 3.4, the vector $u$ satisfies $\frac{1}{2} v_j \le u_j \le \frac{3}{2} v_j$ for all $j \in [n]$ with probability at least $1 - \frac{1}{\mathrm{poly}(n)}$, where $v$ is the vector defined in the theorem statement. Thus, using the definition of $\mathcal{C}_{\frac{\sqrt{2} Q}{d^{1/4}}, \frac{\sqrt{2} K}{d^{1/4}}, u, 1, \tau}$ in Eq. (9), the aforementioned runtime is bounded by $O\left(nd \cdot \mathcal{C}_{\frac{\sqrt{2} Q}{d^{1/4}}, \frac{\sqrt{2} K}{d^{1/4}}, v, 1, \tau}\right)$.

Finally, the time to generate the $m$ samples in line 6 of Algorithm 1 is $O(m + n)$, using the sampling method developed by Hagerup et al. (1993). The total runtime is obtained by summing up these terms.
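To make the structure analyzed in Theorems 3.3 and 3.5 concrete, the sketch below strings the steps together in NumPy, but with the two WExpKDE calls replaced by exact (quadratic-time) kernel sums purely for readability. The function name kde_attention_sketch and the exact-sum substitution are our own, so this is an illustrative sketch of the estimator, not the paper's Algorithm 1.

```python
# Illustrative NumPy sketch (not Algorithm 1 itself) of the estimator analyzed in
# Theorem 3.3: the quantities alpha and beta that WExpKDE would approximate are
# computed exactly here, so the structure is visible but the runtime is quadratic.
import numpy as np


def kde_attention_sketch(Q, K, V, m, seed=0):
    """Return an approximation of D^{-1} A V built from m sampled columns of A."""
    n, d = Q.shape
    rng = np.random.default_rng(seed)
    A = np.exp(Q @ K.T / np.sqrt(d))              # attention matrix (formed exactly here)
    alpha = A.sum(axis=1)                         # exact stand-in for the first WExpKDE call
    u = 1.0 / alpha**2                            # weights u_i = 1 / alpha_i^2
    beta = u @ A**2                               # exact stand-in for the second call: beta_j ~ ||x_j||^2
    gamma = 1.0 / np.linalg.norm(V, 2) ** 2       # ||D^{-1}A||_op^2 / ||V||_op^2, using ||D^{-1}A||_op = 1
    scores = beta + gamma * np.sum(V**2, axis=1)
    p = scores / scores.sum()                     # column-sampling distribution from the proof
    idx = rng.choice(n, size=m, p=p)
    scale = 1.0 / np.sqrt(m * p[idx])             # rows of Pi, as in Lemma 3.2
    # Only the sampled columns A[:, idx] actually enter the estimate below.
    sampled_cols = (A[:, idx] / alpha[:, None]) * scale[None, :]
    return sampled_cols @ (scale[:, None] * V[idx])   # (D~^{-1} A Pi^T)(Pi V)
```

The runtime statement of Theorem 3.5 corresponds to replacing these two exact kernel sums with calls to the fast Gaussian KDE primitive (Algorithm 2), which is what removes the quadratic cost.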
B.4. Proof of Corollary 3.6

Corollary 3.6 (Simplified Runtime for Bounded-Diameter Datasets). For any datasets $Q, K$ with diameter $\max_{i,j \in [n]} \|k_i - q_j\|_2^2 = \gamma \sqrt{d} \log n$ for some $\gamma > 0$, the runtime of Theorem 3.5 is upper bounded by $O\left(m + nd \cdot \left(n^{\tau(1+\gamma)} + \varepsilon^{-2} n^{\tau(1+\gamma/2)}\right)\right)$, which is strongly sub-quadratic in $n$. In particular, if $\gamma = o(1)$, the runtime is bounded by $O\left(m + \varepsilon^{-2} d \cdot n^{1+\tau+o(1)}\right)$.

Proof. First recall that the diameter of the datasets $Q, K$ is $\max_{i,j \in [n]} \|k_i - q_j\|_2^2 = \gamma \sqrt{d} \log n$ for some $\gamma > 0$. For any $i, j \in [n]$, using the fact that $\|k_j - q_i\|_2^2 \le \gamma \sqrt{d} \log n$, we have
$$\exp\left(\frac{1}{\sqrt{d}} \langle k_j, q_i \rangle\right) = \exp\left(\frac{-\|k_j - q_i\|_2^2}{2\sqrt{d}}\right) \cdot \exp\left(\frac{\|k_j\|_2^2 + \|q_i\|_2^2}{2\sqrt{d}}\right) \ge n^{-\gamma/2} \cdot \exp\left(\frac{\|k_j\|_2^2 + \|q_i\|_2^2}{2\sqrt{d}}\right).$$
Therefore, summing the above inequality over all $j \in [n]$ gives
$$\sum_{j \in [n]} \exp\left(\frac{1}{\sqrt{d}} \langle k_j, q_i \rangle\right) \ge n^{-\gamma/2} \cdot \sum_{j \in [n]} \exp\left(\frac{\|k_j\|_2^2 + \|q_i\|_2^2}{2\sqrt{d}}\right).$$
The above inequality holds for every $i \in [n]$, and it implies that the following set is empty for any $\mu \le n^{-1-\gamma/2}$:
$$\left\{ i \in [n] : \frac{\sum_{j \in [n]} \exp\left(\frac{1}{\sqrt{d}} \langle k_j, q_i \rangle\right)}{\sum_{j \in [n]} \exp\left(\frac{\|k_j\|_2^2 + \|q_i\|_2^2}{2\sqrt{d}}\right)} < n \cdot \mu \right\} = \emptyset.$$
Thus, $\mathcal{C}_{\frac{K}{d^{1/4}}, \frac{Q}{d^{1/4}}, \mathbf{1}_n, \varepsilon, \tau}$, defined as per Eq. (9), is bounded as follows:
$$\mathcal{C}_{\frac{K}{d^{1/4}}, \frac{Q}{d^{1/4}}, \mathbf{1}_n, \varepsilon, \tau} = \min_{\mu > 0} \left( \varepsilon^{-2} \mu^{-\tau} + \left| \left\{ i \in [n] : \frac{\sum_{j \in [n]} \exp\left(\frac{1}{\sqrt{d}} \langle k_j, q_i \rangle\right)}{\sum_{j \in [n]} \exp\left(\frac{\|k_j\|_2^2 + \|q_i\|_2^2}{2\sqrt{d}}\right)} < n \mu \right\} \right| \right) \le \varepsilon^{-2} \cdot n^{\tau(1+\gamma/2)}.$$
Similarly, because $v_j > 0$ for every $j \in [n]$, we can show that, for any $i \in [n]$,
$$\sum_{j \in [n]} v_j \exp\left(\frac{2}{\sqrt{d}} \langle q_j, k_i \rangle\right) \ge n^{-\gamma} \cdot \sum_{j \in [n]} v_j \exp\left(\frac{\|q_j\|_2^2 + \|k_i\|_2^2}{\sqrt{d}}\right).$$
As a result, the following set is empty for any $\mu \le n^{-1-\gamma}$:
$$\left\{ i \in [n] : \frac{\sum_{j \in [n]} v_j \cdot \exp\left(\frac{2}{\sqrt{d}} \langle q_j, k_i \rangle\right)}{\sum_{j \in [n]} v_j \exp\left(\frac{\|q_j\|_2^2 + \|k_i\|_2^2}{\sqrt{d}}\right)} < n \cdot \mu \right\} = \emptyset.$$
So, $\mathcal{C}_{\frac{\sqrt{2} Q}{d^{1/4}}, \frac{\sqrt{2} K}{d^{1/4}}, v, 1, \tau}$, defined as per Eq. (9), is bounded as follows:
$$\mathcal{C}_{\frac{\sqrt{2} Q}{d^{1/4}}, \frac{\sqrt{2} K}{d^{1/4}}, v, 1, \tau} = \min_{\mu > 0} \left( \mu^{-\tau} + \left| \left\{ i \in [n] : \frac{\sum_{j \in [n]} v_j \cdot \exp\left(\frac{2}{\sqrt{d}} \langle q_j, k_i \rangle\right)}{\sum_{j \in [n]} v_j \exp\left(\frac{\|q_j\|_2^2 + \|k_i\|_2^2}{\sqrt{d}}\right)} < n \mu \right\} \right| \right) \le n^{\tau(1+\gamma)}.$$
Therefore, the total runtime of Theorem 3.5 is bounded by
$$O\left(m + nd \cdot \left(\mathcal{C}_{\frac{K}{d^{1/4}}, \frac{Q}{d^{1/4}}, \mathbf{1}_n, \varepsilon, \tau} + \mathcal{C}_{\frac{\sqrt{2} Q}{d^{1/4}}, \frac{\sqrt{2} K}{d^{1/4}}, v, 1, \tau}\right)\right) = O\left(m + nd \cdot \left(n^{\tau(1+\gamma)} + n^{\tau(1+\gamma/2)} / \varepsilon^2\right)\right),$$
which completes the proof.

C. Additional Experiments

C.1. BigGAN Image Generations

The images in Fig. 7 are a random subset of 2,000 generations from BigGAN (Yuan et al., 2021)¹ with the exact attention computation and with several drop-in approximations, including our KDEformer, Performer (Choromanski et al., 2021), Reformer (Kitaev et al., 2020), and ScatterBrain (Chen et al., 2021b). One can observe that KDEformer generates noticeably more natural and realistic images than the other approximation methods, and in many cases its outputs are even better than those of the exact computation; that is, it requires much less runtime and memory while still producing higher-quality, more realistic images. Also note that the hyperparameters of our approach were not fine-tuned.

¹ https://github.com/huggingface/pytorch-pretrained-BigGAN

Figure 7. Image generations from the pre-trained BigGAN with the exact attention (top) and with drop-in replacement by its approximations, including our KDEformer (second row), Performer (third row), Reformer (fourth row), and ScatterBrain (bottom).