Hi there. Today we're looking at "Pretrained Transformers as Universal Computation Engines" by Kevin Lu, Aditya Grover, Pieter Abbeel and Igor Mordatch.

On a high level, this paper argues that pre-trained transformers, specifically transformers pre-trained on language modeling, are doing something called universal computation, and the way they show it is by transfer-learning these transformers to completely new domains, so not language modeling: they do things like XOR tasks or CIFAR-10, so computer vision. And they don't just do it in the regular transfer-learning way; they freeze almost all of the parameters of the transformer. Specifically, they freeze all of the attention and all of the feed-forward layers, so they only fine-tune roughly 0.1% of the parameters of the model. They show that on these specific tasks, these frozen pre-trained transformers, as you can see right here, are competitive with, if not outperforming, a transformer that is fully trained from scratch on these tasks, and they also mostly outperform LSTMs that are fully trained from scratch. This is pretty interesting, and it gives rise to a number of questions about what happens inside these transformers. We're going to look at what the claims are and what evidence the paper brings forth about why language-pre-trained transformers are universal computation engines, and I'll have some comments of my own as always. If you do like content like this, share it out, leave a like, and tell me what you think is going on here in the comments.

So the abstract reads: "We investigate the capability of a transformer pretrained on natural language to generalize to other modalities with minimal fine-tuning", and they say, "in particular, without fine-tuning of the self-attention and feed-forward layers of the residual blocks." As you might know, a transformer is built approximately like this: you have the positional embeddings and the input embeddings. If it is a language model, the input is simply one vector for every word or word piece; if it is an image model, like the Vision Transformer (ViT), you take the image, cut it into patches, and unroll each patch into one long vector of pixels, and the sequence of such patches is your input. What follows are the self-attention blocks, and this is the majority of the transformer: L times a self-attention block, where you always have an attention layer (if you don't know what an attention layer is, I'm sure you'll find some video on YouTube that explains it), followed by a layer norm, followed by an element-wise feed-forward layer, again followed by a layer norm, with residual connections, as you can see right here. All of this is followed by an output layer, and the output layer is very task specific: in language modeling it's obviously classifying into the vocabulary, so into one of, whatever, the 30,000 possible continuations; in computer vision it might be classifying into the classes of the data set, so for example in ImageNet you'd have 1,000 classes, or 21,000, depending on which version you use. So what they're saying is: they are not fine-tuning, they are freezing, the multi-head attention, and they're also freezing the feed-forward layers.
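To make that setup concrete, here is a little sketch of the freezing scheme in PyTorch. I'm assuming the Hugging Face GPT-2 parameter naming here (wpe for the positional embeddings, ln_ for the layer norms), and the fresh input projection and output head with their dimensions are placeholders I made up, not the authors' actual code.

```python
import torch
from transformers import GPT2Model

model = GPT2Model.from_pretrained("gpt2")

# Freeze everything except the layer norms and the positional embeddings;
# the attention and feed-forward weights keep their language-pretrained values.
trainable = ("ln_", "wpe")
for name, param in model.named_parameters():
    param.requires_grad = any(key in name for key in trainable)

# A task-specific input projection and output head are trained from scratch
# (input_dim and num_classes are placeholders for whatever the task needs).
input_dim, num_classes = 64, 10
in_proj = torch.nn.Linear(input_dim, model.config.n_embd)
out_head = torch.nn.Linear(model.config.n_embd, num_classes)

# Only a small fraction of GPT-2's own weights stays trainable.
n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(f"trainable fraction of GPT-2 weights: {n_train / n_total:.2%}")
```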
Now, these make up something like 99 percent of the transformer, so what they get is a frozen pretrained transformer, and frozen specifically refers to the parts I marked in blue. They just keep the attention and the feed-forward layers as they come out of the language pre-training, and then they train the thing on different tasks.

These tasks are as follows. There's bit memory: "We consider a bit memory task where the model is shown five bitstrings, each of length 1000. Afterwards, the model is shown a masked version of one of the bitstrings, where each bit is masked with probability 0.5, and the model is tasked with reproducing the original bitstring." So you give it five bitstrings in sequence, and then you give it a sixth one that is corrupted, and the model must figure out which of the five it is and then successfully reproduce that bitstring. It has to look at the overlap between the strings, and wherever there is the most overlap it needs to copy that string over, including the non-overlapping parts. That's a fairly complicated task for a model that is just trained with backprop. There is bit XOR, where you have two bitstrings of length five and you need to compute the element-wise XOR; XOR is the classic, long-standing difficult task for neural networks. There is ListOps, where you get a sequence like this and you must compute the result, so it acts a little bit like a calculator. Now, if you think about it, bit memory is already pretty similar to language; bit XOR maybe not; and on ListOps we're going to see that these models perform fairly poorly. Then there is computer vision, so MNIST and CIFAR-10, the classic Vision Transformer domain, but still: they take the transformer that's pre-trained on language and simply fine-tune the positional embeddings, the input embeddings, the output layer and the layer norm parameters, and that's all they do. There's also CIFAR-10 from the Long Range Arena, where instead of forming patches you feed in every single pixel on its own, so you don't do patches anymore, you unroll the image pixel by pixel. That is a significantly longer sequence for the model to compute over, which makes the task a bit more difficult, because you completely lose all localization information. And the last one is remote homology detection, a task from protein folding.
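Just to make the two bit tasks concrete, here is a rough sketch of how such examples could be generated. This is my own construction from the description in the paper, not their data pipeline; in particular, marking masked bits with -1 is an arbitrary choice of mine.

```python
import numpy as np

rng = np.random.default_rng(0)

def bit_memory_example(n_strings=5, length=1000, mask_prob=0.5):
    strings = rng.integers(0, 2, size=(n_strings, length))  # the five stored bitstrings
    target = strings[rng.integers(n_strings)]                # the one that gets corrupted
    masked = np.where(rng.random(length) < mask_prob, -1, target)  # -1 marks a masked bit
    return strings, masked, target                           # the model must recover `target`

def bit_xor_example(length=5):
    a, b = rng.integers(0, 2, size=(2, length))
    return a, b, a ^ b                                       # element-wise XOR is the target
```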
Okay, so how do these things do? You've already seen it in the overview: if you train these models on the bit tasks, so bit memory or bit XOR, you can see that the frozen transformer reaches a hundred percent, and so does the full transformer. What that shows you is not necessarily which one is better, just that both are able to completely solve the task, while, for example, an LSTM is not. Now, we have no idea what the size of that LSTM is; I don't think they state it anywhere. So while it is cool to see that the LSTM doesn't get this relatively simple task, that might also just be a function of how large the LSTM is and how much rigor goes into training one. Nevertheless, the LSTM can't solve it, and that's because the LSTM takes in the sequence one element at a time and needs to remember in its hidden state what the individual elements are; it can't go back, whereas the transformer can always look back. The LSTM needs to remember everything, and I think that makes these kinds of sequence tasks much harder for it. On ListOps, as I already told you, they all perform badly, but interestingly they perform equally badly: the full transformer is no better than the frozen transformer, which is very interesting. And if you look at MNIST and CIFAR-10, actually at all of the other tasks, you'll see that the frozen transformer is not worse than the full transformer; in fact it's sometimes better, and that is going to be an interesting thing to look at.

So the whole paper is really a set of ablation studies into this phenomenon: why does this happen? And it's very cool. The authors claim that there is something special about language pre-training that already primes the transformer to be receptive to these new tasks. There are two different possibilities if you think about what's happening here, but let's first go through the ablations and do the discussion at the end, because once you see what is happening you'll be able to form your own opinion.

What I would like to remind you of, though, is that they do train the layer norm parameters. When I saw this, and they said, well, we only train the input embeddings, because of course it's a different modality, so adjusting the input embeddings makes sense, and the positional embeddings maybe too, and the output layer, because we have a different task, that makes sense too, and the rest we freeze, but we also adjust the layer norm parameters, yet we don't adjust the attention, my immediate thought was: you probably tried doing it without the layer norm parameters at the beginning, you probably tried just adjusting the input and output embeddings, and that probably didn't work too well. In the ablations you're actually going to see this. I think this hinges on something we've seen with transformers before, in what are called adapter layers: if you have your transformer layers one after another, you can build in these adapter layers that have very few parameters and that compress and then decompress the data, so they go down and back up in dimensionality, and that is a way you can fine-tune the transformer. We know that this is very possible with transformers, that you can have the transformer ready and then adjust only very few parameters to transfer-learn, and I think the same is going on here.

What the authors sort of hint at is that, schematically, the transformer has the attention part, which is the cross-information routing part, then the feed-forward part, which is element-wise, and then the layer norm part. And the layer norm, in terms of learnable parameters, is essentially this: you take the input signal and normalize it, either over the batch or over the layer or something like this, depending on the exact type of norm; in layer norm you do it over the layer. And then you have two parameters per channel that you can learn: one is a scaling and one is an offset.
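Here's a small sketch of those two layer norm parameters, together with a toy check of a point that's coming up next: rescaling them changes where a frozen attention layer looks. The toy dimensions and the random stand-in weights are of course my own illustration, not anything from the paper.

```python
import torch

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each token's feature vector, then apply the learnable scale and offset.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta

torch.manual_seed(0)
d = 16
x = torch.randn(4, d)            # activations of four tokens
gamma = torch.ones(d)            # the trainable scale ("a")
beta = torch.zeros(d)            # the trainable offset ("b")

# Frozen query/key projections of the *next* attention layer (random stand-ins).
W_q, W_k = torch.randn(d, d), torch.randn(d, d)

def attention_of_token0(gamma, beta):
    y = layer_norm(x, gamma, beta)                 # this is what the next layer sees
    q, k = y @ W_q, y @ W_k
    return torch.softmax(q @ k.T / d ** 0.5, dim=-1)[0]

print(attention_of_token0(gamma, beta))            # default scale and offset
print(attention_of_token0(2.0 * gamma, beta + 0.5))
# W_q and W_k were never touched, yet token 0 now attends to the others differently.
```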
And I think by learning these you can adapt quite a bit; these two things have a lot of relation to each other. Even though the authors say we don't learn any of the attention, by influencing this a and this b right here, and therefore the y that then goes into the next layer of attention, I can very much influence how the attention works: from the y, the next layer constructs the keys, queries and values of this particular element, and that decides what information gets routed where, and so on. So I have a lot of influence over the attention in the next layer just by adjusting this a. It's not a direct influence; if I want to change something specific in a key, the effect of having to change the y as a whole is that I also change something else, but backprop will certainly figure out some way to make this happen. So this whole notion of "we don't influence the attention at all" is not as clear-cut as it sounds. It's true that they don't change the attention parameters, but they are very much able to influence how information is routed by changing the signal itself through these layer norm parameters.

They also call it zero-shot: they say it "improves performance and compute efficiency on non-language downstream tasks; in particular, we find that such pretraining enables the frozen pretrained transformers to generalize in zero-shot to these modalities." Zero-shot, I think, is a bit of an overclaim. I get it, you're only fine-tuning roughly 0.1% of the total number of parameters of the transformer model and none of the self-attention parameters, but I don't think it's entirely fair to call this zero-shot, unless I've completely overlooked or misread something in the paper, which is of course possible, because I'm just one person reading a paper.

So again: they fine-tune the output layer, the input layer, the layer norm parameters and the positional embeddings. My claim is that this does most of the work. We already know that, for example, for CNNs we can take a randomly initialized CNN and, by just adjusting the batch norm parameters, already get a non-trivial result, and I think the layer norm here is doing a lot of the work; of course the input and output layers as well. We also know that we can take a randomly initialized neural network and simply train an output layer on top, and that can already give a good performance. This is all stuff they investigate in this paper; however, I think the layer norm does a lot of the crucial work here. Still, there are some interesting things that come out of these experiments, because it's not just that. As I said, the paper is a big piece of ablation studies.

Oh yes, that's what I forgot: the interesting thing, of course, is that the fully trained transformer isn't better. If you fully train a transformer on the same tasks, and I think the paper agrees here, this is due to the fact that we are in the low-data regime, at least for the natural data sets like MNIST or CIFAR-10; we don't have that many data points, so training a big transformer with all the parameters could even be counterproductive, because we're just going to overfit and shoot ourselves in the foot. All right, let's go through these experiments.
Can pretrained language models transfer to different modalities? The answer here is going to be: yes, absolutely. Their base model is a GPT-2 that is trained on language, and it's quite interesting what happens when you transfer it to these tasks. You can see right here, and these are the results from figure one, just what you saw in the bar diagram again: it's pretty interesting that the frozen pretrained transformers match the performance of the fully trained ones and outperform the LSTMs on these tasks. Pretty cool. On some tasks, you can see right here, like the homology task, they even outperform the fully trained transformers.

The second question: what is the importance of the pre-training modality? Here they compare: what if we just randomly initialize the transformer and freeze the same layers, just untrained; or we pre-train it on the bit memory task, just that one task; or we pre-train it on ImageNet, ImageNet-21k in fact, so on images instead of language; or we pre-train on language, which is this FPT. Which one is going to be the best? This is there to counter the objection that any old pre-training would do: they're making the claim that language modeling has a specific property, that language is a particularly good task to pre-train these transformers on, better than other modalities, so you can't just pre-train the transformer on any old task. That's what they're saying here: language is somehow special, or at least the best out of these. In order to demonstrate that, you can see right here, this is the language one, and the randomly initialized one already kind of underperforms throughout, actually not by that much on some of these things, but you can see on MNIST or on CIFAR-10 it does not perform too well. The bit-memory one obviously performs well on the bit memory task, that's what it was pre-trained on, but it also kind of sucks on the rest of these tasks; it's okay on MNIST, but the performance is kind of shaky. The Vision Transformer one is better, but it still lags behind, except on CIFAR-10, because, being pre-trained as a vision model, it seems fine that it performs well on image modeling. The whole point, though, is to generalize to domains outside of your pre-training, and on those domains the language one is better than all the other ones.

Now, there are multiple questions here. I think it is a bit too early, from just this paper, to say that language modeling has this special property. What might also be an explanation, for example, is how difficult your pre-training task is. When you look at language modeling, you can simply ask how many classes it has: the number of classes in language modeling is something like 30,000, since these vocabularies are fairly large; for the random model it's absolutely nothing; for the bit memory task you have two classes; and for the Vision Transformer you have 21,000 classes, but you only need to predict once per sequence, you only have one output, whereas in language modeling every single token is a classification. So it's not necessarily more classes, but it is, let's say, more training examples per data point, because every token is a training example. So it might not be a language thing; it might just be how hard the task is in terms of the number of classes and how much training data you have available.
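To put a number on that last point, here is a toy comparison of the amount of supervision per batch in language modeling versus once-per-sequence classification; the sizes are just illustrative values I picked.

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab_size = 4, 128, 30_000

# Language modeling: a classification target at every single token position.
lm_logits = torch.randn(batch, seq_len, vocab_size)
lm_targets = torch.randint(vocab_size, (batch, seq_len))
lm_loss = F.cross_entropy(lm_logits.reshape(-1, vocab_size), lm_targets.reshape(-1))
print("LM predictions supervised per batch:", batch * seq_len)       # 512

# Image classification: one label for the whole sequence of patches or pixels.
num_classes = 21_000
cls_logits = torch.randn(batch, num_classes)
cls_targets = torch.randint(num_classes, (batch,))
cls_loss = F.cross_entropy(cls_logits, cls_targets)
print("classification predictions supervised per batch:", batch)     # 4
```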
I think there are a lot of variables they haven't necessarily controlled for here, and it might be a bit too early to say language modeling is the task. What I'm completely prepared to accept is that language modeling is a good task, in fact the best task out of these ones, but I think it could be cool to research more in this direction and ask: can we find a better task, can we find a task that is even more complex? And that depends on what is really going on here.

I see two possibilities. Possibility one, for why this even works, is to say that somehow natural signals are all somehow alike, so pre-training on language makes the attention layers of the transformer adjust themselves to the sort of natural signals that we see around us. When we then feed in an image recognition task, or any other task that humans care about in the natural world, the transformer is already sort of prepared for what that could entail, for the types of computation involved. The second possibility, and this is different, is simply what I'm going to call computational utility: with enough complexity, when you pre-train on a task, certain types of computation are going to be important for that task, and the more complex the task and the bigger your model, the more computational primitives you can encode into the attention layers. When you encode these computational primitives, it doesn't necessarily have anything to do with the type of signal; what could be happening is that these transformers simply prepare a lot of good features that are just useful for computing different things, like XOR, like remembering things, and so on. I think this could definitely be the case: that in these attention layers such computational primitives are encoded, that the harder the pre-training task, the more of these primitives need to be encoded, and that what you do when you adjust the layers in between is simply recombine these primitives in a better way, while all of the computational primitives are already there. The two are not necessarily mutually exclusive, and I think the paper hints at both playing a role here. I don't think the authors say exactly the same thing, but this would also give meaning to this phrase "universal computation engine" for these transformers, and we might even extend it to probably any machine learning model: if we could scale it up and train it correctly, it probably evolves, or trains, to have these computational primitives inside of it, and that's why we can adapt it with just a little bit of fine-tuning. Now, the authors are going to claim that there is something about language pre-training specifically, later on.

First, they ask how important the transformer architecture is, and here they simply say: if we take a randomly initialized transformer and compare it with a randomly initialized LSTM, freeze the attention layers, and then do our frozen training, the transformer performs a lot better than the LSTM in most, actually all, of the tasks. However, this is a very shaky comparison, of course, because how do you fairly compare a transformer architecture with an LSTM architecture? Do you control for the number of parameters, the amount of computation, the speed? I don't know. So I don't know what's fair.
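Just to illustrate that fairness question in numbers, here is a quick count of the parameters in one standard transformer encoder block versus one LSTM layer at the same hidden size. The sizes are GPT-2-ish values I picked for illustration; the point is only that the counts don't line up by default.

```python
import torch.nn as nn

d_model, n_heads, d_ff = 768, 12, 3072
block = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=d_ff)
lstm = nn.LSTM(input_size=d_model, hidden_size=d_model, num_layers=1)

count = lambda m: sum(p.numel() for p in m.parameters())
print("one transformer block:", count(block))   # roughly 7.1M parameters
print("one LSTM layer:       ", count(lstm))    # roughly 4.7M parameters
```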
Next: does language pre-training improve efficiency over random initialization? The answer is yes; it converges much faster if you pre-train on language.

Do the frozen attention layers attend to modality-specific tokens? Here they just look at the first attention layer, and they see that the attention matrix, for example in this bit XOR task, attends to the right things. Here are the two strings, string number one and string number two, and in the output you need to compute the XOR. You can see that the attention is first on the first string, then on the second string, and in the output it always looks at the corresponding positions. So the attention matrix already attends to the correct things for the task, which is cool, because we never trained the attention. But I think that goes back to my claim: we are still able to influence the attention matrix, even though we don't train the attention weights, by training these in-between parameters. The same goes for the bit memory task; you can see the attention matrices are very much attuned to the task right here.

Next: does freezing the transformer prevent overfitting or underfitting? Here they train this frozen transformer and compare it to training a transformer that has just three layers. They say: "Our general finding is that, in contrast to their fully trained counterparts, FPT models underfit the data, which lends them to further improvements by increasing model capacity." So if you compare to a three-layer transformer, the three-layer transformer does outperform the twelve-layer frozen transformer, but it does so by reaching a much higher training accuracy. Overfitting is much more of a problem if you fully train the transformer; with the frozen transformer you're probably underfitting, as you can see right here, so you could technically scale up and gain more power with this frozen fine-tuning.

Does performance scale with model size? Yes. You can see that as you go from small to medium to large, as you increase the number of layers, the performance increases. However, the performance also increases for a randomly initialized one, so it just seems to be that more parameters are better, in both cases.

And here is something I find interesting: can performance be attributed simply to better statistics for initialization? Here they make the point that there is something about language-model pre-training that actually makes the transformer conducive to all these tasks, and that you can't reach it just by a better initialization. That is more possibility one from above than possibility two, because possibility two you could, in principle, reach by initializing in a better way: we could characterize these computational primitives and build them in from the start, whereas natural signals we can't characterize, otherwise we wouldn't need machine learning. So what they do is simply take a fully trained transformer, which they call an oracle, compute the mean and the standard deviation, so a Gaussian, from its weights, and then initialize a new transformer from that. They compare the pre-trained one, the default, which is the randomly initialized one we've already seen, and then one that is also randomly initialized, but not with the default initialization: randomly, with the statistics they got from the oracle.
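Here is a sketch of that statistics-matched initialization; the helper below is hypothetical, my reading of the described ablation rather than the authors' code, and it assumes the fresh and oracle models share the same architecture.

```python
import torch

def init_from_oracle_stats(fresh_model, oracle_model):
    # Re-draw each weight tensor of the fresh model from a Gaussian whose mean and
    # standard deviation are measured on the corresponding trained "oracle" tensor.
    oracle = dict(oracle_model.named_parameters())
    with torch.no_grad():
        for name, param in fresh_model.named_parameters():
            mean, std = oracle[name].mean(), oracle[name].std()
            # Same first and second moments as the trained weights,
            # but none of their learned structure.
            param.copy_(torch.randn_like(param) * std + mean)
```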
So this transformer is randomly initialized, but it has the same statistics as the fully trained transformer; the statistics are correct. And that does not seem to help much. It seems to help a little bit, as you can see, but in fact here it even hurts. However, I think that's a bit of a weak experiment, and I think there is still a possibility that we could initialize these transformers much better if we could correctly capture the essence of these computational primitives that are learned by gradient descent. If we could capture those in a theoretically sound way, we might be able to initialize better, or, if we could find, not natural language, but some synthetic pre-training task that is just so hard that it instills all of these computational primitives, that might still be better. That's going to be the ultimate experiment that differentiates between option one, natural language pre-training is somehow important because of grammar and natural signals, and option two, what we're doing is just putting computational primitives into these layers.

Does fine-tuning the self-attention and feed-forward layers further improve performance? The answer is actually no; it degrades performance. You can see right here, this is worse than this, and that's probably because of overfitting: if you fine-tune the whole transformer, you're going to fall down. And here is where it really comes in that these tasks are in the low-data regime. I know, if you go back five years, that sounds ridiculous, but right now these things will overfit if you train everything.

Which parameters of the model are important to fine-tune? You can go look at the table in the appendix, but they say they run ablations here and generally find the layer norm parameters to be the most important. That gives credence to what I said: I think these layer norms carry a lot of the weight of this method. It's still pretty cool, because it's very few parameters that you need to fine-tune. They then do a bunch more ablations, like only training the output layer, which gives non-trivial performance, but not good enough performance.

And, for some reason, I have another copy of the paper right here, but this was essentially the paper. It's very cool, and I think it's well written and easy to read, because it's like: hey, here is a phenomenon we've discovered, and now we're just going to investigate all kinds of things that explain this phenomenon; we're going to rule out some hypotheses, and we're going to arrive at some kind of conclusion. And yeah, that was my two cents on this paper. I hope you enjoyed it; it's a bit of a shorter video. Bye bye.