
How To Shrink Language Models 25x Without Losing Performance

The Power of Retrieval-Enhanced Networks

DeepMind Scales Retrieval-Enhanced Networks

Today’s language models are too big!

Creating predictions with GPT-3 will cost you an arm and a leg. Even with all the bells and whistles that make inference more efficient, you need at least eleven V100 GPUs, at $9,000 each, to run meaningful amounts of inference.

Hence, a computer that allows you to make predictions with such a model costs more than $100K. Training such a model is orders of magnitude more expensive.
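If you are curious where these numbers come from, here is the rough back-of-the-envelope math. The FP16 assumption (2 bytes per parameter) and the 32 GB of memory per V100 are my own assumptions for this sketch:

```python
# Back-of-the-envelope: memory and GPU cost for GPT-3 inference.
# Assumes FP16 weights (2 bytes per parameter) and 32 GB V100 GPUs.
import math

params = 175e9                                   # GPT-3 parameter count
weight_memory_gb = params * 2 / 1e9              # ~350 GB just to hold the weights
gpus_needed = math.ceil(weight_memory_gb / 32)   # -> 11 V100s
gpu_cost = gpus_needed * 9_000                   # -> $99,000 for the GPUs alone

print(gpus_needed, f"${gpu_cost:,}")             # 11 $99,000
```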

If you are a university or a startup, that is a lot of money. If you are like me - a normal guy with sweatpants and a computer - you are out of luck.

But there is good news!

You can integrate a massive knowledge base into a large language model (LLM). This can maintain or even improve performance with networks that are up to 25x smaller.

This not only helps financially, but also has the potential to solve some of machine learning's fundamental flaws.

Let’s go!

Money Is Not The Only Problem

LLMs are not just exorbitantly expensive, they also have a tendency to hallucinate [7].

You could, for example, ask an LLM to describe what Barack Obama is doing right now. It might confidently present you with a totally false answer.

Behavior like this can be a real dealbreaker when trying to apply LLMs in practice.

There is another interesting challenge, which is actually not the model’s fault.

When an LLM is trained, it stores facts from the training data in its weights. But what if the training data becomes outdated?

Think of age-related information. The tennis player Rafael Nadal, for example, is currently 36. Naturally, he will only be 36 for one year. This is known as the temporal generalization problem [1].

In principle, you could retrain a model like GPT-3 every day. It would be super hard and prohibitively expensive, but it could be done.

However, this will never solve all problems. Many things like the weather in Tokyo can change multiple times a day.

Hence, there is a natural limit to how timely a neural network’s information can be.

Okay. So we can’t retrain GPT every second. What would an alternative look like?

Retrieval-Enhanced Neural Networks

The idea is to allow the model access to all information in a giant database.

This effectively turns the model’s job into an open-book exam.

The concept is not new. People have been integrating knowledge bases into language models for a while [2, 6]. The approaches differ, but the main idea always looks something like this:

  1. The model is prompted with an input text, e.g. “How old was Rafael Nadal when he first won the Australian Open?”

  2. Before the model processes the prompt, another part of the system retrieves relevant documents from a giant heap of data. The subsystem comes back with snippets like the following: “Winning the Australian Open at 23 years of age is a dream come true for the young Rafael Nadal…”

  3. These identified chunks of text are then fed to the model alongside the original input.

  4. The model uses the combination of input and additional information to generate an output (here, obviously, “23”).

The example is totally made up, but you get the point. The model is allowed to peek into a giant database and use what it finds to make predictions.
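To make the four steps concrete, here is a minimal sketch of this kind of prompt-level retrieval augmentation. The tiny in-memory knowledge base, the naive word-overlap retriever, and GPT-2 as a stand-in for a real LLM are all simplifications for illustration:

```python
from transformers import pipeline

# A tiny in-memory "knowledge base" standing in for the giant database.
knowledge_base = [
    "Winning the Australian Open at 23 years of age is a dream come true "
    "for the young Rafael Nadal.",
    "The Australian Open is one of the four Grand Slam tennis tournaments.",
]

generator = pipeline("text-generation", model="gpt2")  # stand-in for a real LLM

def retrieve(question, k=1):
    """Naive retriever: rank snippets by word overlap with the question."""
    overlap = lambda doc: len(set(question.lower().split()) & set(doc.lower().split()))
    return sorted(knowledge_base, key=overlap, reverse=True)[:k]

def answer(question):
    context = "\n".join(retrieve(question))                           # steps 1-2: retrieve
    prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"     # step 3: combine
    return generator(prompt, max_new_tokens=10)[0]["generated_text"]  # step 4: generate

print(answer("How old was Rafael Nadal when he first won the Australian Open?"))
```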

Let’s look at how well this actually works. (Spoiler: It is super impressive!)

DeepMind’s Retrieval-Enhanced Transformer (RETRO)

In a recent paper from DeepMind [2], this approach is scaled to an unprecedented size. Their knowledge base consists of 2 trillion tokens, an order of magnitude more data than any current language model has access to during training.

To handle this incredible size, they build a giant key-value store over their knowledge base. The values are text chunks; the keys are pre-computed BERT embeddings of those chunks.
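Here is a minimal sketch of how such a key-value store could be built. I am using the Hugging Face transformers library with bert-base-uncased as a stand-in for the frozen BERT encoder, mean pooling over token embeddings, and naive whitespace chunking; RETRO's exact chunking and embedding details differ:

```python
# Sketch of the key-value store: values = text chunks, keys = BERT embeddings.
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder = AutoModel.from_pretrained("bert-base-uncased").eval()

def chunk_text(text, chunk_size=64):
    """Split a document into fixed-size chunks (naive whitespace tokenization)."""
    words = text.split()
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]

@torch.no_grad()
def embed(chunks):
    """Mean-pooled BERT embeddings, used as the keys of the store."""
    batch = tokenizer(chunks, padding=True, truncation=True, return_tensors="pt")
    hidden = encoder(**batch).last_hidden_state       # (batch, seq, dim)
    mask = batch["attention_mask"].unsqueeze(-1)      # (batch, seq, 1)
    return (hidden * mask).sum(1) / mask.sum(1)       # mean pooling over real tokens

documents = ["Winning the Australian Open at 23 years of age is a dream come true "
             "for the young Rafael Nadal."]
values = [chunk for doc in documents for chunk in chunk_text(doc)]
keys = embed(values)                                  # pre-computed once, offline
```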

When their model performs a prediction, the input sequence is first split into several parts. For each part, they perform neural information retrieval to find similar text chunks in their knowledge base.

The retrieved text is then encoded with another transformer. The resulting embeddings are fed to the main model, where they are integrated deep inside the network via cross-attention.
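Conceptually, the integration looks something like the sketch below: the model's own hidden states attend to the encoded neighbors. This is a generic cross-attention layer in PyTorch, not RETRO's exact chunked cross-attention, and all sizes are illustrative:

```python
# Deep integration sketch: decoder states attend to encoded retrieved chunks.
import torch
import torch.nn as nn

d_model = 512                                     # hidden size (illustrative)
cross_attn = nn.MultiheadAttention(d_model, num_heads=8, batch_first=True)

hidden_states = torch.randn(1, 128, d_model)      # decoder states for the input prompt
neighbor_enc = torch.randn(1, 2 * 64, d_model)    # encoded retrieved chunks (2 neighbors)

# Queries come from the model's own states, keys/values from the retrieval.
attended, _ = cross_attn(query=hidden_states, key=neighbor_enc, value=neighbor_enc)
hidden_states = hidden_states + attended          # residual connection
```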

They are able to perform this massive information retrieval quickly because the encodings of the knowledge base are pre-computed. They then use the ScaNN library [3], which lets them retrieve approximate k-nearest neighbors within a few milliseconds.
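At toy scale, the same retrieval step can be sketched as a brute-force dot-product search over the pre-computed keys (reusing the embed, keys, and values from the store sketch above); ScaNN adds the indexing tricks that make this fast over billions of entries:

```python
# Brute-force stand-in for the approximate nearest-neighbor search.
def retrieve_neighbors(query_chunks, keys, values, k=1):
    """Return the k most similar stored chunks for each query chunk."""
    query_emb = embed(query_chunks)           # same BERT embedder as above
    scores = query_emb @ keys.T               # dot-product similarity
    top = scores.topk(k, dim=-1).indices      # k nearest neighbors per query chunk
    return [[values[j] for j in row] for row in top.tolist()]

neighbors = retrieve_neighbors(
    ["How old was Rafael Nadal when he first won the Australian Open?"],
    keys, values, k=1,
)
print(neighbors)
```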

So far so good. They condition the output on retrieved text, but how well does it work?

Great Performance At Lower Cost

The researchers find that their method outperforms models such as Jurassic-1 (178B) [4] and Gopher (280B) [5] on the majority of benchmarks.

That alone is great. But when you consider that their RETRO model has only 7B parameters, it becomes a very impressive feat.

Though 7B parameters is still a lot, it is 25x fewer than GPT-3's 175B parameters.

Furthermore, enhancing a model with a giant knowledge base not only enables great performance with smaller models. Using retrieved documents to steer the model toward the correct answer has also been shown to reduce the tendency to hallucinate [6].

That is fantastic! But there is more!

Learning To Fly Helicopters In The Matrix

In an iconic scene from the movie The Matrix, one of the characters downloads the ability to fly a helicopter directly into her brain.

Though DeepMind’s RETRO model probably can’t keep up with the almost infinite swag of the leather-wearing heroes of The Matrix, it can do basically the same thing.

By extending or even completely swapping out the knowledge base, the model can be adapted to new domains without retraining.

The researchers at DeepMind, however, did fine-tune the neighbor encoder and the cross-attention layers. This allowed them to get away with training only about 10% of the weights in the model.

The authors call this RETRO-fitting (😂). It works so well that they reach almost the same performance as when they train the model from scratch on the new tasks [2].
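In PyTorch-style pseudocode, RETRO-fitting boils down to freezing everything except the retrieval-specific modules. The module names below are illustrative, not the actual parameter names from DeepMind's implementation:

```python
import torch.nn as nn

def retrofit(model: nn.Module):
    """Freeze the pre-trained decoder; train only the retrieval-specific parts."""
    trainable = ("neighbor_encoder", "cross_attention")   # illustrative module names
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in trainable)

    n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    print(f"Training {100 * n_train / n_total:.1f}% of weights")  # roughly 10% for RETRO
```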

Pretty impressive, no?!

Some people are probably thinking:

This is a crutch! Neural networks are universal function approximators. All we need is good data and massive end-to-end training.

Though I am sympathetic to this view, I think it is a fascinating line of research. It has the potential to solve some of the fundamental problems of machine learning such as temporal generalization.

Also, tons of applications require data that will not be available during foundation model training.

Whatever comes of this, I look forward to the future. Such exciting times to be alive!

As always, I really enjoyed making this for you and I sincerely hope you found it useful!

Thank you for reading!

References:

[1] A. Lazaridou, et al., Mind the Gap: Assessing Temporal Generalization in Neural Language Models (2021), arXiv preprint arXiv:2102.01951

[2] S. Borgeaud, et al., Improving Language Models by Retrieving from Trillions of Tokens (2022), International Conference on Machine Learning

[3] R. Guo, et al., Accelerating Large-Scale Inference with Anisotropic Vector Quantization (2020), International Conference on Machine Learning

[4] O. Lieber, et al., Jurassic-1: Technical Details and Evaluation (2021), White Paper, AI21 Labs

[5] J. Rae, et al., Scaling Language Models: Methods, Analysis & Insights from Training Gopher (2021), arXiv preprint arXiv:2112.11446

[6] K. Shuster, et al., Retrieval Augmentation Reduces Hallucination in Conversation (2021), arXiv preprint arXiv:2104.07567

[7] S. Roller, et al., Recipes for Building an Open-Domain Chatbot (2020), arXiv preprint arXiv:2004.13637