today irwan and i are going to be giving a talk on scaling transformers through sparsity, and the kind of sparsity we're going to be talking about is the kind where each input can either get a different set of weights or have a different amount of computation applied to it. do you want to start it off? yeah, so i guess the overall motivation for this line of work is that the community has realized that scale is perhaps one of the most important axes to focus on for obtaining strong performance, and there's almost an ongoing arms race right now with different labs and institutions competing to train the largest models. this maybe dates back to early 2020 with a paper from openai called scaling laws for neural language models, where they find that model performance improves predictably with scale, roughly as a power law in model size, whether you measure that in compute or in parameters. this scaling law generalizes over multiple orders of magnitude, and that gives us confidence that if we train very large models we can expect a certain performance just by extrapolating these scaling laws. in that paper they also make the interesting observation that larger models are more sample efficient, so if you have a fixed compute budget you can predict the optimal model size for it, and the overall conclusion is that you'd rather train very large models for fewer steps than train smaller models for more steps. now, that paper focuses on dense models, where you scale by just increasing the model dimensions; they're not looking at sparsity, and sparsity is a new dimension you can use to scale architectures, which is the focus of this talk. the sparsity we're talking about here means sparsely activated weights based on the network inputs: every input goes through roughly the same amount of computation, but different weights are applied to it. this dates back to 1991 with a paper called adaptive mixtures of local experts, and it was recently revisited by noam shazeer and colleagues at google brain with lstms, where they replaced the feed-forward networks in lstms with mixtures of experts. the way this works, roughly, is that you have multiple experts, each implementing a small network, or in that case i think just a dense matrix multiplication, and you have an additional gating network, shown in green here, that outputs a probability distribution over the experts each token should be sent to. this probability distribution is computed as a softmax; once you have it, you select a few experts (there are different selection strategies, maybe we'll talk about them later), and the output is simply the weighted mixture of the selected experts' outputs. these models have been pretty successful, primarily in translation, but there were some complexities that hindered broader use in nlp, and the switch transformer paper addresses some of those. we'll be discussing how to fix training instabilities, reduce communication costs, and reduce computational complexity.
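to make that gating mechanism concrete, here is a minimal numpy sketch of a top-k mixture-of-experts layer: a softmax router picks a few experts per token and the output is the gate-weighted sum of their outputs. the shapes, function names, and the plain relu feed-forward experts are illustrative assumptions, not the actual implementation from these papers.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def moe_layer(tokens, router_w, expert_ws, k=2):
    """tokens: [num_tokens, d_model]; router_w: [d_model, num_experts];
    expert_ws: list of (w_in [d_model, d_ff], w_out [d_ff, d_model])."""
    logits = tokens @ router_w                     # [num_tokens, num_experts]
    probs = softmax(logits)                        # gating distribution per token
    topk = np.argsort(-probs, axis=-1)[:, :k]      # indices of the k best experts
    out = np.zeros_like(tokens)
    for t in range(tokens.shape[0]):
        for e in topk[t]:
            w_in, w_out = expert_ws[e]
            h = np.maximum(tokens[t] @ w_in, 0.0)  # expert FFN (relu)
            out[t] += probs[t, e] * (h @ w_out)    # weight by the gate probability
    return out

# tiny usage example with random weights
rng = np.random.default_rng(0)
d_model, d_ff, n_experts, n_tokens = 8, 16, 4, 5
tokens = rng.normal(size=(n_tokens, d_model))
router_w = rng.normal(size=(d_model, n_experts))
expert_ws = [(rng.normal(size=(d_model, d_ff)), rng.normal(size=(d_ff, d_model)))
             for _ in range(n_experts)]
print(moe_layer(tokens, router_w, expert_ws).shape)  # (5, 8)
```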
all right, barret, do you want to go? sure, yeah. so one approach we're going to look at for sparsity is the switch transformer, which is a simplified mixture-of-experts variant along with some other improved training and fine-tuning techniques that allow it to be stably trained and also to perform better when fine-tuned on a lot of downstream tasks. the switch transformer model works as follows: you have some transformer model that has self-attention and feed-forward layers, and the idea is that we replace maybe one in every two or one in every four feed-forward layers with a switch layer. on the left of the slide is one layer block, which is self-attention, then add-and-normalize, then a feed-forward layer, then add-and-normalize, and in this case we're replacing the normal feed-forward layer with the switch layer. we can see an illustration of this on the right: the layer has two inputs, one is the token "more" and the other is the token "parameters", and these embedding representations get sent to a router, which works exactly as in mixture of experts. the router basically produces a distribution over all of the experts; in this case the highest probability goes to expert number two out of the four experts, while the token on the right actually has most of its probability on the first feed-forward weight, the first expert. what we do in the switch transformer, which is very simple, is just send each token to its highest-probability expert. and here we can see where the adaptive computation lies: we have four sets of weights, and there is some shared computation across all the tokens, for example the self-attention layer is computed exactly the same for the "more" token and the "parameters" token, but in the sparse switch layer the inputs, while having the same number of floating point operations applied to them, actually have different weight matrices applied. next slide. yeah, so that's the high-level idea of the switch transformer: instead of sending a token to multiple different experts, which can also increase the communication costs as i'll go into a little later, it significantly simplifies the algorithm by only sending each token to a single expert.
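here is a minimal numpy sketch of that top-1 switch routing, reusing the shapes from the sketch above; scaling the chosen expert's output by its router probability follows the switch transformer description, but the rest (relu experts, variable names) is illustrative.

```python
import numpy as np

def switch_layer(tokens, router_w, expert_ws):
    """top-1 routing: each token is processed by exactly one expert FFN."""
    logits = tokens @ router_w                       # [num_tokens, num_experts]
    probs = np.exp(logits - logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    chosen = probs.argmax(-1)                        # top-1 expert per token
    out = np.zeros_like(tokens)
    for t, e in enumerate(chosen):
        w_in, w_out = expert_ws[e]
        h = np.maximum(tokens[t] @ w_in, 0.0)        # the one selected expert FFN
        out[t] = probs[t, e] * (h @ w_out)           # scale by the gate value
    return out, chosen, probs
```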
so for the improved training methodology, we focused on three different things to help improve the training of sparse models. the first was selective precision, which allows these sparse models to be trained in lower-precision formats. this is incredibly important: for most of the models we train we really don't want to be using float32, because it's slower to compute, and when you're communicating tensors across different processes it's twice as slow, just because there's twice as much data to send. second, we have some initialization and training tricks for training these models more stably, especially as they grow in size: a new initialization method along with a change to the learning rate schedule. and third, since our models have so many more parameters, we notice different overfitting dynamics, especially once we fine-tune these models, which have been pre-trained on all of the internet, on small tasks with maybe only 50k to 100k examples; they can be much more prone to overfitting, so we also look at some custom regularization to help prevent the overfitting we observe. finally, we also talk about a differentiable load balancing technique, which encourages each expert to receive roughly the same number of tokens. this is very important given that we want this to be efficient on hardware: we want roughly the same number of tokens sent to each expert, so to encourage this we tack on an additional load balancing loss alongside the cross-entropy loss we're training with. next slide. okay, so here i'm going to go into selective precision. again, when we're training large models it's really important to be able to train them in lower-precision formats, so instead of each weight and activation being 32 bits we want to shrink it down to 16 bits, and we use the bfloat16 representation. what we found out of the gate is that these models are just unstable; the sparse models in particular are much more unstable than the dense models, in the sense that you'll train for 10-20k steps and then the loss will just diverge. this was something we frequently encountered. one key thing we found is that you need to cast part of the computation to float32 for these models to train stably, and the key component to cast is the router computation. we can go into the technical details more later, but basically any time there are exponentiation functions, it's very important to use higher precision, because round-off errors can drastically change the output of an exponentiation: if you perturb the input to an exponential by 0.1 or 0.2 or 0.3, the output can change drastically, depending on how large the input is. so this was a very important fix; it basically doesn't change the compute at all and makes the models significantly more stable. next slide. the second thing we looked at is the initialization scale. the standard way we were initializing these models made them much more prone to instability, or to just performing worse, so one very effective thing we did was simply make the initialization scale much smaller, and when we did this the quality drastically improved; it was a very simple fix. next slide. and the third thing i mentioned: since these models are much more prone to overfitting, because they have significantly more parameters, we also use much more dropout, for the expert layers only. here we have t5-base, which is a dense model, and a bunch of different switch variants on top of it, and what we found to be most effective on these four fine-tuning tasks was to significantly increase the dropout rate inside the expert layers; we found this was pretty effective for combating the overfitting. next slide.
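as a concrete sketch of two of the ingredients just described, the auxiliary loss below follows the load-balancing formulation described in the switch transformer paper (a coefficient times the number of experts times the dot product between the fraction of tokens routed to each expert and the mean router probability each expert receives), and the float32 cast mirrors the selective-precision idea; the numpy framing, names, and default coefficient are mine, not the actual implementation.

```python
import numpy as np

def router_and_aux_loss(tokens, router_w, alpha=0.01):
    """compute router probs in float32 and a load-balancing auxiliary loss."""
    num_experts = router_w.shape[1]
    # selective precision: do the softmax (exponentiation) in float32,
    # even if tokens/weights are stored in bfloat16 elsewhere.
    logits = (tokens @ router_w).astype(np.float32)
    probs = np.exp(logits - logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)

    chosen = probs.argmax(-1)                               # top-1 routing decision
    # f_i: fraction of tokens dispatched to expert i
    f = np.bincount(chosen, minlength=num_experts) / len(chosen)
    # p_i: mean router probability assigned to expert i
    p = probs.mean(axis=0)
    aux_loss = alpha * num_experts * np.sum(f * p)          # smallest when balanced
    return probs, chosen, aux_loss
```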
yeah, we have a question. oh awesome, okay, let me take a look. do you want to go ahead? yeah, i can ask: it was just in reference to that previous table where you have throughput and precision. it seemed surprising to me that you could match this 1390 number using selective precision; i would expect it to be something in between. yeah, so it essentially comes down to the fact that there's maybe a little bit of noise in sampling the speed, and the only part we're casting is the router, which is such an insignificant portion of the computation, and involves zero communication, that it's essentially a free operation in the network. so whether you cast it to bfloat16 or float32 doesn't impact the speed at all within the precision at which we can measure the speed. also, these architectures only use a sparse layer once every four layers, so the float32 part is pretty negligible overall. yeah, for example, off the top of my head it's something like 1/40th of the computation it would cost to do the first weight matrix multiply in a dense-relu-dense layer, so it's a very, very small part, and we're not using it very frequently, as irwan mentioned. and then just a quick point on this, i won't go into all the technical details, but since we're training these things on hardware, a big part of the mixture-of-experts paradigm is that these models are designed to map really efficiently to hardware: we want to be doing dense matrix multiplies, and for that to work really well we also want roughly equal numbers of tokens going to each of the different experts. i don't think this is that sensitive to the exact load-balancing formulation, we tried a few things and a lot of them worked, but essentially you definitely want some kind of load-balancing loss added on when using sparsity. next slide. yeah, irwan, go ahead. yeah, so the frameworks and libraries we use rely on static shapes for tensors: xla, the compiler for tensorflow and mesh tensorflow, expects static tensor shapes. however, the computation in switch transformers is dynamic because of the router, right, different inputs get routed to different experts, so we need to specify ahead of time how many tokens will be sent to each expert. we introduce this expert capacity hyperparameter to specify that: it's a static number that says how many tokens each expert can process. in practice we instead parametrize this with a quantity called the capacity factor. we have an example here: the bottom row is a bunch of tokens on one device, and you need to route those tokens to multiple devices, or multiple experts. if too many tokens are routed to a single expert, some tokens will be dropped, because, as we said, each expert has a fixed capacity. that's the example on the left, where the capacity factor is 1.0, which basically means there's no extra buffer for overflowing tokens. instead we can use a capacity factor larger than one; on the right you have an example with 1.5, which means each expert now has three slots and can process three tokens, and that prevents token dropping because we have more capacity, but the issue is that it means more expensive communication across devices.
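here is a small sketch of how that expert capacity might translate into a static buffer, with overflow tokens dropped (dropped tokens are typically just passed along through the residual connection); the helper below and that residual-passthrough detail are illustrative assumptions rather than the exact mesh tensorflow logic.

```python
import numpy as np

def dispatch_with_capacity(chosen, num_experts, capacity_factor=1.0):
    """chosen: [num_tokens] expert index per token (from top-1 routing).
    returns, per token, its slot in that expert's buffer, or -1 if dropped."""
    num_tokens = len(chosen)
    capacity = int(capacity_factor * num_tokens / num_experts)  # static buffer size
    fill = np.zeros(num_experts, dtype=int)                     # slots used so far
    slot = np.full(num_tokens, -1, dtype=int)
    for t, e in enumerate(chosen):
        if fill[e] < capacity:
            slot[t] = fill[e]       # token t occupies this slot of expert e
            fill[e] += 1            # otherwise the token is dropped
    return slot, capacity

# e.g. 8 tokens, 4 experts, capacity factor 1.0 -> each expert holds 2 tokens
slots, cap = dispatch_with_capacity(np.array([0, 0, 0, 1, 2, 2, 3, 1]), 4, 1.0)
print(cap, slots)  # 2 [ 0  1 -1  0  0  1  0  1]
```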
okay, so that was it. yeah, go ahead. oh yeah, so one thing we also experimented with was a method called no-token-left-behind. the idea was the following: since we have to have a fixed batch size for each expert, and there can be token dropping, we were thinking that having tokens dropped, that is, having some tokens get no expert computation applied to them, is probably hurting model performance. so what if we do a multi-stage routing procedure: first you do the normal routing, where you send each token to its highest-probability expert, but then any dropped tokens get sent to their second-highest-probability expert, and so on, and you can repeat this process to guarantee that no tokens are dropped. interestingly, this approach didn't empirically improve model performance; it actually kind of hurt it. we thought that was very interesting, and i think the intuition is that once the model learns it wants to send a token to one expert, it really wants that specific computation applied to it, and applying some other computation doesn't have the same properties at all, and may even be detrimental. so yeah, we thought that was pretty interesting, since we were very optimistic this would improve performance, but it ended up not really making a difference, and we found that quite surprising. we have a question. i think it will actually address literally the last point you brought up. when i think about a mixture of experts, usually the experts specialize in different things, right? so i was wondering, if you send a token to the second-best expert or whatever, what if all of your tokens would be particularly good for one expert, and you then only process, let's say, 20% of your tokens? maybe that ends up being better than rerouting them to anything else. exactly, yeah: even if you're dropping a lot of tokens, it's not beneficial to send them to the second, third, or fourth-best expert. and one interesting property we noticed about these models is that they're surprisingly robust to token dropping, especially during fine-tuning. in the standard paradigm we pre-train with a load balancing loss, which keeps the tokens pretty balanced, but then during fine-tuning, where we really want to adapt to a specific task, we studied this exact question: does it help to have a load balancing loss during fine-tuning or not? if you have the load balancing loss, you're encouraging all the experts to be used for the specific task; if you turn it off, there's definitely some prior specialization, and it's actually much better to just turn the auxiliary loss off: even if 60-70% of the tokens are being dropped, that performs much better than having all the tokens balanced. but doesn't a load balancing loss encourage basically all the experts to learn very similar weights, and then just randomly assign your tokens, because then
it doesn't really matter which expert things get sent to? so, when we use the load balancing loss, the routing mechanism is definitely still learned; the model is encouraged to choose an expert that it wants to send each token to, right? but if all the experts learn the same weights, then the router basically learns that it doesn't matter where you send things. so if you encourage load balancing, you technically encourage any token to fit with any expert, right? i mean, that's maybe the extreme behavior if you have a very high load balancing loss coefficient, but in practice that coefficient is tuned, and we observe that for small enough values the router still learns meaningful routing. yeah, it's a balance between the cross-entropy loss and the load balancing loss: on one hand you want to encourage the model to be balanced, and on the other hand you want good empirical performance, and the model is able to both learn and specialize the experts, so they have different weights and it expects certain tokens to be sent to certain experts, while still staying reasonably balanced so that the models run efficiently on modern hardware. exactly. we also have a question from the classroom. the question i want to ask is: this seems to me like a very experimental talk, we're talking about floating-point precision, we're talking about different approaches and what currently works well, and whenever we're doing this kind of work there's the question of what the research question is, and i feel like i missed that. so what are we trying to answer with all these experiments? yeah, i think the high-level research question is: can we create models that do adaptive computation, in the sense of making models behave more like the dynamics we think models should naturally use, where different inputs have different amounts of computation applied to them and different weights applied to them? basically, we're trying to figure out how to create a new framework for training these models, as opposed to their dense counterparts, which apply the exact same computation to every input. that's interesting, because when you say "the same exact computation applied", to me the immediate thing one might imagine is how long to deliberate about something. what i mean by that is, if we want variable-length computation, you could imagine a short amount of computation or a much longer computation, so why do we instead consider the dimension of different computation, assuming, of course, that these experts do indeed learn different things, which i think you'll get to? so yeah, why did you immediately jump to thinking about specialized experts as opposed to variable-length computation? so yeah, we actually go into some variable-length computation work later in the talk, and i feel like they're both important axes that should both be pushed on. i guess, yeah, i'm not phrasing my question well, but what i'm trying to understand is
why you decided to attack this one first; i want to understand why your team chose to go in this direction first. yeah, absolutely. so i think, one, empirically it seems that sparsity has led to better results in deep learning than adaptive computation so far, and the way we use it maps really well to modern hardware, which is also very promising. the way we were looking at it is that sparsity is a first step towards more interesting and general adaptive computation, because this stuff is complicated, and typically starting from something that works well is better than trying something that's not as proven out and then trying to get it to work really well. so we're starting from sparsity, which noam shazeer and others got to work really well in the context of lstms; we were interested in porting some of this to transformers, getting it working really well, and then slowly expanding towards a lot of the other natural questions you mentioned, like, okay, instead of different weights per example, let's also maybe have different amounts of computation per example, and so on. that's how we were building the natural progression of our research. got it, cool, thank you. what do you think, anything else to add? yeah, i mean, i kind of see adaptive computation and sparsity as related but separate things: sparsity is more like different parameters for each example, and adaptive computation is more like a different number of flops per example. we have some of that with the token dropping, but that's not the main motivation. definitely, as barret mentioned, i would say no one has really figured out adaptive computation for deep learning yet, and one reason is that our accelerators and frameworks expect us to work with batched data parallelism, this spmd paradigm where you're supposed to apply the same computation to every example. if you look at the literature you have works like universal transformers, where they replace the feed-forward layer in the transformer with a recurrent weight, so it's kind of like an lstm on each token, and the lstm can stop at different times based on some criterion, but the way these things are implemented is just through masking, because they need to be implemented in the spmd programming style. so sparsity was definitely easier to get to work first, and there were also some prior results with lstms. in terms of the first question, what's our research question here, it's simply: can we design more efficient models? sparsity is this new axis that hasn't been explored that much, and i'm happy with that being the research question. great, okay, so next slide. yeah, again, putting it all together: the switch transformer layer selects just the top expert, and then incorporates a bunch of general sparse-model improvements to allow it to fine-tune better, be more regularized, and be trained with lower
precision formats, plus a lot of technical details to get these models training and working well. yeah, so one thing we also wanted to do was a comparison between top-1 and top-2 routing, since top-2 routing was the most popular technique. here we have two different dense models trained at different sizes, and we're going to look at the pre-training negative log perplexity, so the bigger the number the better. next slide. what we're going to do is study them at different capacity factors; a capacity factor of 2.0 basically means there's enough buffer for two tokens to be sent to every single expert, and we're going to compare top-1 versus top-2 routing, and also compare their speeds along with their time to reach some quality threshold. okay, so here we can see, in the capacity factor 2.0 case, that the moe models outperform switch transformer, which makes a lot of sense: since switch transformer only sends the top-1 token to each expert while mixture of experts sends two tokens, the extra buffer is disproportionately beneficial for the mixture-of-experts models. next slide. now, the really interesting part for top-1 routing comes when we lower the capacity factor. having a high capacity factor is bad for many reasons: it incurs more communication costs for sending tokens to the correct experts, more compute cost, and a lot of memory overhead, so if you can get it lower that's usually a very good thing. what we see here is that switch transformer actually outperforms mixture of experts when you have a lower capacity factor, and we reach the quality threshold much quicker. so even across the 2.0 and 1.25 capacity factors, the pareto-optimal choice in our setup is to use switch transformer at a lower capacity factor, just because, while the quality is a little bit worse on a step basis, it's much faster to run. next slide. and we can also see that a capacity factor of 1.0 again disproportionately benefits switch transformer, and is even better from a pareto standpoint than the 1.25 capacity factor. interestingly, since moe also does a bit more computation, we can instead increase the amount of compute done elsewhere in the model, and that turns out to be a much more efficient allocation of compute. so overall our takeaway is that lower capacity factors with top-1 routing are more pareto-efficient than top-2 routing at higher capacity factors. next slide. oh, and you can take it over. okay, so next we'll look at how switch transformer scales as a function of the number of experts in the switch layers. on the right side here you see a plot of perplexity versus training steps for different switch architectures, ranging from t5-base, which has basically no experts, or a single expert, up to 128 experts, and you see that as we increase the number of experts, which also increases the number of sparse parameters, you get
increasing speedups over the dense baseline, with diminishing returns as you keep increasing the number of experts. so the previous figure looked at perplexity versus training steps; here we look at perplexity versus training time, which includes all the additional communication costs when you have more experts, compared to the dense baseline. this is switch-base versus t5-base, and we observe up to 7x speedups over t5-base. just to contextualize these numbers, 7x speedups in deep learning are pretty hard to obtain, so i think this is one of the results that can spark a lot of interest in sparse models, even if it's only for pre-training for now; just having that number suggests there's something significant to be obtained here. okay, so sparse scaling laws. here we look at loss versus sparse model parameters, which are increased by increasing the number of experts, and similarly to the neural scaling laws paper, we observe that as you increase the sparse parameters while keeping the flops fixed, you get consistent gains, but with diminishing returns. okay, so now we're going to compare expert parallelism and model parallelism. we introduced sparsity, or expert parallelism, as a new dimension along which to scale models, but of course there's the other one for dense models, which is simply model parallelism, where model weights are partitioned across cores once they're above the maximum size you can fit on a single core. all right, barret, do you want to take the expert parallelism part here? yeah, so essentially what we're doing is comparing a switch-base model against the dense base model, and also against a larger dense model that uses model parallelism, because when we want to scale up model size we basically have two axes we can go through: we can either increase the number of flops by scaling through model parallelism, or increase the number of parameters by scaling through sparsity. and we can see that even compared to a dense model that's been scaled up through model parallelism, sparsity is still, at this scale, a more effective way to scale up the model, still getting 2.5x speedups over the larger dense model that uses model parallelism. cool, so basically here t5-large is the dense model that uses model parallelism. yeah, go ahead. okay, and one thing we also wanted to look at is whether these expert models are effective if you have a really small amount of compute, just a small number of experts. typically when we design these models we have one expert per core, but if you don't have a large cluster to run these things on, say you just have a gpu with two cores or something, is having two experts more effective than just a dense model? and the answer is yes: we see pretty good scaling properties even with a tiny number of experts, which is very promising for these models to be used in much lower compute regimes. next slide. okay, so next we'll look at what things look like when we use different
types of parallelism, namely expert parallelism to add experts, model parallelism to shard model weights across cores, and also data parallelism, which is the dominant paradigm in deep learning at the moment. in the previous slides we mostly talked about expert parallelism, but of course large-scale dense models use model parallelism: gpt-3 and these other large models simply shard model weights across different cores. yeah, we have a question. oh yeah, i just wanted to ask, and i don't know if you're going to address it later, but i think somewhere in the paper it said that the more experts you have, the more sample efficient it gets, and i was hoping you could give some intuition about that, because i don't understand why that would be the case. yeah, so there's all this work showing that larger models are more sample efficient, and "larger" in the context of the scaling law work means more parameters and more flops. as you increase the number of experts there are more parameters but not more flops, yet the model is still larger in a similar sense, so building on the intuition that larger models are more sample efficient, in my mind it's not that surprising that these models with more experts, which have more parameters, are more sample efficient. that's my high-level intuition for it. yeah, i would say it's kind of expected that more experts lead to better sample efficiency, especially if you look at training steps rather than training time. okay, cool, so where were we? okay, so we look at how model weights are split over cores for different scenarios. data parallelism is the first one; that's the typical setup deep learning uses, especially for not-so-large networks which don't require model parallelism. let me explain how to look at this figure; i'll just go to the final figure. we have 16 processors organized in a four-by-four mesh, so each four-by-four dotted square here represents a different core; the first row illustrates how the model weights are split over cores, and the second row illustrates how the data, literally examples and tokens, is split over cores. the final thing required to understand this figure is that each color of the shaded squares identifies a unique weight matrix. okay, so let's start with data parallelism. for data parallelism, the same model weights are replicated across all cores, and the data is simply partitioned over cores; that's what this corresponds to, using the explanation of the caption i just gave. next we have model parallelism. that's kind of a theoretical example, because in practice people always use model parallelism in conjunction with data parallelism, but if you wanted to do only model parallelism, you would have a single model weight matrix partitioned over all cores, and your data would just be replicated over all cores instead. now we have model and data parallelism, which is the typical scenario for large dense networks; in that case, model weights are partitioned among a subset of the cores, the subset of cores that processes
different batches of data. in the example here, the first sub-square means the model weights are partitioned across four cores, and this is replicated four times along the data parallelism dimension. on the data side, for model and data parallelism, the data is replicated across model-parallel cores and partitioned across data-parallel cores. next we have expert and data parallelism. that scenario is similar to data parallelism, but now each core holds a different model weight, illustrated by the different colors, and on the data side the data is partitioned across all cores, just like in the data parallelism scenario. finally we have the rightmost column, which is the setup used in the switch transformer paper for the larger models: for the model partitioning, each expert is partitioned across multiple cores, so in that example we have four experts, each partitioned across four cores, and the data is replicated across model-parallel cores and partitioned across data-parallel cores. that's already a bit complex to understand, but the switch transformer paper has the same figure with a nice caption to explain it. maybe we can add something quickly about how this is implemented in practice. there's this paper called mesh tensorflow, which extends batch or data parallelism to more general-purpose spmd-style programming. different labs have different frameworks, but this paper lays the foundation for general spmd distributed computing, which is required for training large-scale models. under the mesh abstraction you have a mesh of processors, and that mesh has named dimensions; these named dimensions specify how the tensor dimensions will be partitioned or replicated across the mesh dimensions. just that simple abstraction supports data parallelism, model parallelism, and especially expert parallelism all at once, so i invite whoever is interested to check out that paper, because it lays the foundation for understanding these things.
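to illustrate the named-dimensions idea, here is a small numpy sketch of how expert-plus-data parallelism might lay things out over a 4x4 mesh of cores; the axis names, shapes, and loop are made up for illustration and this is not the mesh tensorflow api.

```python
import numpy as np

# a 4x4 mesh of 16 cores; think of the two axes as named ("data", "expert")
mesh = np.arange(16).reshape(4, 4)   # mesh[d, e] = core id
num_experts, d_model, d_ff = 4, 8, 32
batch_tokens = 64

# expert + data parallelism:
# - each column of the mesh holds a different expert's weights
#   (replicated down the column)
# - the batch of tokens is partitioned across all 16 cores
expert_weights = {e: np.zeros((d_model, d_ff)) for e in range(num_experts)}
tokens_per_core = batch_tokens // mesh.size

for d in range(mesh.shape[0]):
    for e in range(mesh.shape[1]):
        core = mesh[d, e]
        print(f"core {core:2d}: holds expert {e} weights {expert_weights[e].shape}, "
              f"processes {tokens_per_core} tokens")
```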
all right, barret, you want to go? cool, yeah, so next we're going to talk about how we take these parallelism strategies and combine them to make a 1.6 trillion parameter sparse model. next slide. so what we ended up doing in this work was training two different very large sparse models and comparing them to the largest t5 model: t5-xxl, which is a dense model, the largest one trained in the t5 paper, with around 13 billion parameters. here we list a lot of the model dimensions, like d_model and d_ff, which are just the various sizes and shapes of the tensors, the number of layers, the number of heads, and importantly we also list the negative log perplexity at step 250k and at 500k. we designed two sparse models to test how scaling sparsity alone versus scaling both sparsity and flops works. first let me talk about switch-xxl: it has the same number of flops per token as t5-xxl but has 64 experts, which leads to around 400 billion parameters, and on a step basis it actually performs quite well and outperforms t5-xxl by a good margin. interestingly, though, the third model we designed, switch-c, has 1.6 trillion parameters but significantly fewer flops, almost 10x fewer flops per token than either of the above two models, so it's really trading reduced flops for way more sparse parameters. on a step basis the switch-c model works well, but not as well as the higher-flop model; but on a pareto axis, where we're looking at tpu hours on the x-axis rather than steps, the switch-c model actually outperforms them both by a pretty large margin. so for pre-training performance, we're seeing that just having a lot of sparsity and fewer flops can actually be quite good. next slide. those two sparse models are really trying to get at a hypothesis that noam shazeer had, which is that parameters are good for knowledge and compute, aka flops, is good for intelligence, and we're going to try to get at that by taking these different sparse models and fine-tuning them on different tasks, some of which require more knowledge and others which require more reasoning, for whatever hand-wavy definition we want to give that. so, for a fixed quality on an upstream pre-training task, do parameters independently matter? we're going to look at two tasks here: one is superglue, which is kind of our reasoning task, and the other is triviaqa, which is a knowledge task where you just give the model a question and have it output an answer. here we take a look at superglue quality: the x-axis is the pre-training performance and the y-axis is the superglue score after fine-tuning, and interestingly we can see that the sparse models, for a fixed pre-training perplexity, do worse on fine-tuning; this is especially noticeable in the upper-right portion of the plot, where the dense models are definitely fine-tuning better than their sparse counterparts. next slide. interestingly, when we study the more knowledge-heavy task, the sparse model, for a fixed pre-training perplexity, does disproportionately well: for a model with roughly the same perplexity we're getting really large boosts on these knowledge-heavy tasks. this is pretty interesting, and it also shows some of the dangers of comparing only on pre-training metrics: these models can have the exact same pre-training metric but very different properties when fine-tuned on different tasks. next slide. and interestingly, all of the switch models here are various models that still have a good amount of flops, but the red model is the 1.6 trillion parameter sparse model that has very few flops but a lot of parameters; we can see it as the red dot here, and it does disproportionately badly compared to other sparse models that also have pretty good perplexities. so it's definitely very interesting, and it shows that models with a lot of sparsity during pre-training definitely suffer on some of these more reasoning-heavy metrics but do
disproportionately well on the more knowledge-heavy tasks. next slide. and here we can see it as a huge outlier: for its pre-training perplexity, it does incredibly well on this downstream question-answering task. next slide. okay, so another thing we wanted to do was look at the fine-tuning properties of sparse models across a few scales and see how they perform. next slide. here we try two different models: one is t5-base, and then we make a flop-matched sparse counterpart; when we say flop-matched, it means each token gets the same number of flops, but now we also have experts. we do this for both base and large, and we see that across almost all tasks, besides the two arc tasks, the sparse models perform quite well, which is definitely promising. so we're seeing that these models are pretty robust: they pre-train well and they also fine-tune well when scaled appropriately, by scaling up both the flops and the sparsity, whereas the negative results we've really seen are when you just have a huge amount of sparsity and not many flops. next slide. another thing we wanted to look at was multilingual training. we were previously studying all of this on english only, and we also wanted to see how sparsity helps in the multilingual setting, because we felt this would be a very natural place for sparsity to work well, where potentially experts could specialize across languages. and we do see strong results: on 91% of the languages, out of around 100 languages, we see at least a 4x speedup over the mt5 dense model. next slide. do you want to go ahead? no, go ahead. okay, yeah, so another thing we wanted to talk about was distillation. one downside of these sparse models is that they have a lot more parameters, which means that if you're serving them you either need high-throughput use cases or you need to distill them back down into a smaller dense model. so here what we do is look at t5-base and switch-base, we look at the pre-training performance, and then we go through some ablations of different distillation techniques, and we find that with the best techniques we can keep around 30% of the quality improvements of sparsity while distilling back down into the dense counterpart. next slide. then we study this across multiple scales, and again we see that around 30 to 40 percent of the gains can be kept when taking a sparse model and distilling it back down into its flop-matched dense model. so you can get rid of up to 99% of the parameters and still keep around 30% of the improvements, which is very promising. next slide. wait, i'm sorry, can you say that last sentence again? you said you can keep 30% of the teacher's benefit? yeah, basically, exactly. so you train a sparse model and then you distill it back into a dense model, versus training a dense model from scratch, and you look at the gap between the sparse and the dense model trained from scratch versus the gap between the dense model and the distilled dense model. what do you mean? maybe let me just do a quick high-level summary again. so what we'll do for
our comparisons is: we'll train a dense model from scratch, we'll train a sparse model from scratch, and then we'll also run a third experiment where we distill that sparse model down into a dense model. what does distilling mean? we're basically trying to match the teacher's logits, the standard thing of matching either the logits or the soft probabilities for each token. okay, if i can jump in with my question: what i'm struggling with is how to interpret the line that says percent of teacher performance. yeah, okay, so it's basically looking at the gap between the dense and the sparse model. the dense model gets some performance, the sparse model gets some performance, and the dense model distilled from the sparse model lands somewhere in between that range, and we're saying it's 30% of the way through that range; so on a zero-to-one interval it's 0.3 of the way from the dense to the sparse model. i see, so percent of teacher performance does not mean that, if we use the teacher's predictions as the ground truth, the distilled model matches the teacher 33% of the time. no, exactly, it's basically saying you get 30% of the quality improvements. okay, cool. and then, if we can back up a slide, i had a different question but didn't want to interrupt. when we were talking about all of these different t5-bases, and also a few slides before this, i don't know that much about t5, and i'm curious: when t5 is trained, is there a weight penalty in the loss function, is there a weight decay term? no, there's no weight decay used when training any of those sparse or dense models. i see. so out of curiosity, how do dense models perform compared to the switch model if you add some sort of weight regularization that incentivizes getting rid of useless weights? oh, so maybe some kind of l1 term or something like that? yeah, so i'm just wondering how much of this benefit from sparsity is due to the fact that, effectively, what the switch model is doing, if i understand correctly, and maybe i don't, is that in the feed-forward layer you're just fixing some weights to be zero; that's what it means to be sparse. well, actually we're really trying to inject more weights. it's maybe a little paradoxical, because we're calling it the switch transformer, but the idea is that we actually want significantly more weights, not fewer; it's kind of like you're zeroing out weights, but within a much larger weight matrix, if that makes sense. i see, yes. and so to me it seems like a relevant baseline to ask what happens if i have the dense matrix but i incentivize it with, say, an l1 or l2 penalty on the weights, and i would be curious to know how that compares. yeah, we didn't run that, but it also kind of gets rid of weights for the dense model. and the last point is, if you just add an l1 penalty to the loss you're not going to have structured sparsity, whereas here it's not random weights in your giant weight matrix that are zero, it's really
blocks corresponding to each expert, right? so that structure is what allows the whole communication setup, and it's what leverages the fact that you have multiple cores and so on. right, so i totally agree with that block structure, and that's what i'm trying to say: the switch model has this very rich structure, it's not just sparse, and what i'm trying to do in my mind is disentangle whether the sparsity is what's offering an advantage, or whether it's this additional structure you built in. that's why i'm asking. so the block structure is what enables us to leverage the fact that you have multiple cores: if you didn't have that block structure you'd still have to route to everything, and so you'd have more communication costs and so on. and then your first question was what, sorry? i'm not actually sure there was a question; i guess i'm trying to disambiguate. yeah, anyways, i agree it's a little bit weird, because there's a spectrum of meanings for sparsity, right? for example, compression and model pruning are a form of sparsity, but switch transformer and moe are also referred to as sparsity, and those are related but definitely aiming at different things. so this is a really interesting idea, that it's sparse but you have more parameters; i'll have to think about it more, thank you. yeah, it's kind of like sparse blocks within this giant weight matrix. yeah, i hadn't appreciated that, so i appreciate you pointing that out, thank you. i have a follow-up question on the distillation part. yeah, of course. okay, so if you distill it back down, technically you're now back to the dense layer architecture, right? and the entire idea of experts is that certain tokens get sent to different experts because those experts are more specialized in figuring something out about those tokens. so now, if you go back to this dense layer, aren't you basically only serving well whichever expert you based the dense layer on, so those tokens will probably perform well and all the other tokens are kind of left behind? sorry, i don't think i'm fully understanding your question; are you getting at the fact that we're distilling this on a specific data set? yeah, so maybe concretely, for superglue, let's say you want to serve a model that does superglue well: the idea is that you distill the sparse model into a dense model on superglue, so you get this compressed dense model that performs better than if you were to just train it from scratch or from a pre-trained dense model. so then, sorry, say that again, don't you have to pick one expert? no, no, you can just distill all of it, because you're just matching the model outputs; you can treat the sparse model as a black-box thing, and all we're doing is trying to have the dense model match the final token predictions. oh, got it, okay, sorry, i was not familiar with the idea of distillation, so i think that was my entire confusion. okay, thanks. yeah, of course.
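for readers who, like the questioner above, haven't seen distillation before, here is a rough numpy sketch of the "match the teacher's soft probabilities" idea; the temperature, the mixing weight between hard labels and teacher targets, and the function names are illustrative assumptions, not the exact recipe used in the switch transformer paper.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def distill_loss(student_logits, teacher_logits, labels, mix=0.5, temperature=1.0):
    """per-token loss: a mix of cross-entropy on hard labels and
    cross-entropy against the teacher's soft probabilities."""
    n = len(labels)
    student_p = softmax(student_logits / temperature)
    teacher_p = softmax(teacher_logits / temperature)
    hard = -np.log(student_p[np.arange(n), labels] + 1e-9).mean()
    soft = -(teacher_p * np.log(student_p + 1e-9)).sum(-1).mean()
    return (1 - mix) * hard + mix * soft

# tiny usage example: 3 tokens, vocabulary of 5
rng = np.random.default_rng(0)
s, t = rng.normal(size=(3, 5)), rng.normal(size=(3, 5))
print(distill_loss(s, t, labels=np.array([1, 0, 4])))
```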
because i guess one motivation here is that having experts can make serving a bit more difficult, because it requires bigger topologies. say you have eight experts; well, i guess you can put multiple experts on fewer cores, but logistically it's a bit harder to serve. so if we can get the benefits from sparsity at pre-training and then use distillation to a dense model for serving, that can be beneficial; i think that was the motivation for that experiment. right, yeah, exactly. okay, where were we? i was just kind of wrapping up. okay, go ahead. no, go ahead, i think there's one more question. yeah, go ahead, feel free to ask. oh yeah, sounds good. thank you guys for the talk so far, just a quick question: i was wondering if you think there are any interesting directions around building models that are explicitly optimized for parallel training. the moe model seems like it does a really good job here, and at inference time it's very useful to have fewer flops per forward pass, but do you think there are interesting directions around distributed training where you might have models that are explicitly architected to have a lot of parallel heads or other features that are embarrassingly parallelizable, or does just scaling up models by adding more layers and getting away with model and data parallelism work well enough? yeah, so let me make sure i'm fully understanding. right now even our models are definitely very co-designed with the hardware, in the shapes and everything, so at a high level, yes, i think there's a ton of interesting research on co-designing the hardware, the partitioning algorithms, and the models. given that we have this spmd mesh-style partitioning, we are already designing our models in ways that fit it really well; for example, when we want to scale up a model, one of the first dimensions we scale up is the internal hidden dimension, because there are some really nice properties of scaling that dimension: it's basically independent of some of the communication costs, and it's really good when you look at the compute-to-memory ratio on these compute devices. so even when we're designing these models we're really setting dimensions such that they map well onto hardware; it's almost like, given that we have this model and data parallelism, we're designing models for it. but i also think there are a ton of new and interesting distributed algorithms that make designing models very interesting; one thing i think is really cool is microsoft's zero partitioning, which also has some nice implications for how to design and scale models. so yeah, i think this is a very fruitful research direction, if that answers your question. yeah, no, that was super helpful and interesting, thanks. yeah, and i'm very optimistic about a future where we design the hardware, the model, and the partitioning strategies all together, because to get this to work well you kind of have to know about all
three and intertwine their development. yeah, that sounds awesome. so just to summarize: switch transformer is a nice simplification over mixture of experts, and we're seeing really strong speedup improvements on pre-training over a lot of the t5 models, which are very strong baselines. we're seeing that we can efficiently distill the sparse models back to dense ones, that we get improvements in both pre-training and fine-tuning through some of the newer techniques we talked about, that the models work on multilingual data, and that we can now successfully train models with up to 1.6 trillion parameters, which is pretty promising. next slide. we also wanted to spend two slides on some newer work on using these kinds of models for computer vision, and also a little bit on how they can be used to do some level of adaptive computation, where not only does each input get different weights, but different inputs can also have different amounts of compute applied to them. there's some really great work on this out of the google zurich team, doing it for image classification, and they see a lot of similar scaling properties, where scaling up the number of experts and using sparsity lets them get good performance on image classification. next slide. interestingly, one of the things they do relates to the capacity factor: we were talking about values like 1.0, 1.25, and 2.0, where a value of 2.0 means there's buffer for two tokens per expert, but they actually study going below one, so at 0.5 there's only room for half the number of tokens. the nice part is that they did this for image classification, and in images there's just a lot of redundancy, and they notice that you can actually get really good performance by only allowing, say, a tenth of the parts of the image to be processed by a sparse layer. so we think this is a really nice direction, too, in terms of combining sparsity with adaptive computation. and yeah, thanks so much for having us, that's the talk. so thank you, barret and irwan, for coming here. i'll just ask a bunch of questions, and then after the class we can have an open question panel for the students. one thing is: have you tried using more linear attention mechanisms, like reformers and other such approaches, to scale the computation? i haven't personally done this, but i guess we can comment on how the cost coming from the attention maps isn't the dominant cost in these large transformers. the motivation for using linear attention, like performers, is that it reduces the quadratic cost of attention, but so far, at least in typical nlp setups like superglue, c4 and so on, as you scale the models most of the memory comes from the model weights as opposed to the attention maps. that's also because using very long contexts or sequence lengths hasn't proven that fruitful, and so just working with the vanilla
self-attention mechanism is a very strong baseline already. got it, okay. so another question: do you think this mechanism is even more scalable, can you go on and build 10 trillion parameter models, things like that? what do you think? yeah, definitely. honestly, one of the biggest constraints, and it isn't even necessarily a constraint, is just that you have to fit the parameters somewhere and there's limited storage on devices, but if you get enough devices that you can partition the weights across, i don't see anything stopping it. got it. so what do you personally think is the direction that scaling of transformers will go in? will it be more work that tries to use transformer-like mechanisms such as mixture of experts, or do you think there are going to be other things the community needs? yeah, i definitely think mixture of experts, or at least sparse layers like the switch transformer's, will find their way into the future of large models. i think they really confer a lot of benefits, and they're also very good in high-throughput applications. the one downside of sparsity is that if you look at the performance per model weight, they're always going to be worse than dense models, so if you're really constrained, if you want to design the best model you can to fit on as small a device as you can, then they're probably not going to be the best solution, because the sparse weights just aren't as good as a dense weight that's being used for everything. so i think it really depends on the application, but i'm very optimistic that when we're pre-training these models with lots of data parallelism and then serving them in medium- to high-throughput settings, they could actually be a pretty big win. that's my thinking on how sparsity will be used. in terms of other things, i think there's a ton of exciting research, everything from a lot of the linear attention work to adaptive computation to new pre-training objectives; it's hard to know what the future will look like, but there are a lot of exciting things to look forward to. sounds good, okay, so we can now have a round of student questions, so we'll just stop the recording.