Training a Language Model From Scratch
Trying to train an LM from scratch
I did this more for the lulz and for the academic value than for any practical value. I’d understood the math behind the transformer architecture. I had a decent understanding of the practical considerations. I knew how neural nets (MLPs) and convolutional neural nets were trained. I also knew Andrej Karpathy had a whole walkthrough video on this. How hard could it be?
The short answer
It depends on:
- How deeply you want to understand the code - syntax, semantics, everything
- How many libraries you want to reuse and how well you know them
- How many mathematical improvements you want to build in
Results
| Method | Time to implement | Customisable parameters | Comments |
|---|---|---|---|
| Using Hugging Face transformers | 3 hours | Model architecture (layers, heads, embedding dim), tokenizer choice, training hyperparameters | Standard, most practical approach |
| Using PyTorch/TensorFlow | 3 days | All architecture components, training loop, optimizer, loss function, tokenizer | Best approach to learn the under-the-hood workings |
| Using plain Python operations | 5 days | Every mathematical operation, data pipeline, weight initialization, training schedule | Esoteric exercise without practical benefits |
Actual steps needed
Here is roughly how a transformer works when you have a text input and want to keep generating text from it.
- The text is first converted into ‘tokens’: numeric IDs assigned to small subparts of your text. How text is split into these subparts is mostly standardised (e.g. OpenAI’s tiktoken library), but more on that below.
- Each token is then mapped to a vector of size embedding_dimension, which is your first model hyperparameter. For GPT-3 175B, the embedding dimension was 12,288. The embedding values for each token are learned during training.
- These token embeddings are added to ‘positional embeddings’, which encode the position of each token, and the sum acts as the input to the transformer. There’s a newer approach called ‘rotary embeddings’ which I don’t fully understand yet.
- The combined embeddings enter the attention block. The maximum number of tokens the model processes at once is fixed here and becomes your context window.
- The attention block has 3 weight matrices per head (Query, Key, and Value), where the value matrix is usually factored into two smaller matrices for efficiency. These matrices are multiplied with the input matrix to ‘enrich’ your input embeddings with context from the other tokens. This step is the heart of the GPT architecture. See the ‘Encoders, Decoders, and Causal LMs’ note below for more details.
- Now that each token knows more about its context in the block of text, each embedding is passed independently through a feedforward layer (an MLP) whose hidden layer is typically 4x the embedding dimension. The output is projected back to the embedding dimension and becomes the input for the next attention block.
- Once your input has passed through a stack of attention-MLP blocks, the last vector in the output of the final block is decoded into raw logits, one per token in the model’s vocabulary. A softmax turns these logits into a probability distribution over the next token, and the temperature and top-k parameters control how the next token is sampled from it.
- The sampled token is appended to the input and the process repeats, so the model generates text one token at a time. This behaviour is called autoregressive generation, hence the name autoregressive transformers. Generating one token at a time is a property of how the model is used, not an inherent restriction of the transformer architecture.
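The first few steps above (tokenise, look up token embeddings, add positional embeddings) fit in a few lines of plain Python. This is a minimal sketch under stated assumptions: it uses a character-level tokeniser (each distinct character gets an ID) instead of a BPE tokeniser like tiktoken, and learned absolute positional embeddings in the style of GPT-2. The tables are random rather than trained, and every function name here is illustrative.

```python
import random

def build_vocab(text):
    """Character-level tokeniser (an assumption): each distinct character gets an integer ID."""
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    return stoi, itos

def encode(text, stoi):
    return [stoi[ch] for ch in text]

def decode(ids, itos):
    return "".join(itos[i] for i in ids)

def init_table(rows, cols, rng, std=0.02):
    """Randomly initialised lookup table; in a real model these values are learned."""
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]

def embed(ids, tok_table, pos_table):
    """Input to the first block: token embedding + positional embedding, element-wise."""
    if len(ids) > len(pos_table):
        raise ValueError("sequence longer than the context window (block_size)")
    return [
        [t + p for t, p in zip(tok_table[tok], pos_table[pos])]
        for pos, tok in enumerate(ids)
    ]

text = "Harry Potter and the Methods of Rationality"
stoi, itos = build_vocab(text)
n_embd, block_size = 8, 16  # toy sizes for illustration only
rng = random.Random(0)
tok_table = init_table(len(stoi), n_embd, rng)   # one row per vocabulary entry
pos_table = init_table(block_size, n_embd, rng)  # one row per position in the context window

ids = encode("Harry", stoi)
x = embed(ids, tok_table, pos_table)
print(ids, decode(ids, itos), len(x), len(x[0]))
```

Both ‘r’s in "Harry" share a token ID and a token embedding. They still enter the model as different vectors because different positional rows are added to them.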
There are some mathematical nuances here around standardisation, variance normalisation, and ways to simplify backpropagation which I’m leaving out.
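The last two steps above (sample the next token from the logits with temperature and top-k, then feed it back in) can also be sketched in plain Python. This is a minimal sketch, not a real model: `toy_next_logits` is a hypothetical stand-in for a trained model's forward pass. The top-k rule shown keeps the k largest logits and also keeps any ties at the cutoff, which is one common way to do it.

```python
import math
import random

def sample_next(logits, temperature=1.0, top_k=None, rng=random):
    """Sample one token ID from raw logits using temperature and optional top-k."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [l / temperature for l in logits]  # <1 sharpens, >1 flattens the distribution
    if top_k is not None:
        # keep only the k largest logits (ties at the cutoff kept); the rest get probability 0
        cutoff = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [s if s >= cutoff else float("-inf") for s in scaled]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]  # softmax, shifted for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

def generate(next_logits, prompt_ids, max_new_tokens, temperature=1.0, top_k=None, seed=0):
    """Autoregressive loop: predict logits, sample a token, append it, repeat."""
    rng = random.Random(seed)
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = next_logits(ids)  # stand-in for a forward pass over the whole sequence so far
        ids.append(sample_next(logits, temperature, top_k, rng))
    return ids

VOCAB = 5

def toy_next_logits(ids):
    """Hypothetical 'model': strongly favours (last ID + 1) mod VOCAB."""
    return [10.0 if i == (ids[-1] + 1) % VOCAB else 0.0 for i in range(VOCAB)]

print(generate(toy_next_logits, [0], 6, top_k=1))  # → [0, 1, 2, 3, 4, 0, 1]
```

With `top_k=1`, sampling collapses to always picking the most likely token (greedy decoding). Larger `top_k` values and higher temperatures let less likely tokens through, which is where the variety (and the nonsense) in small models' output comes from.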
Practical learnings
- All the concerns about neural network training remain valid, especially for smaller models: overfitting, learning rate, batch size, number of epochs and steps. Hyperparameter tuning is real.
- The operations are basic math and can be implemented ‘manually’. It’s the sheer scale of these operations that necessitates the use of highly efficient libraries, mathematical tricks, and hardware-optimised code that can extract the last drop of speed out of your setup.
- For an individual or a small organisation, I only see academic merit in trying to train a base model. That said, the academic benefit is significant and everyone should try to train a simple model.
- The ‘mathematical nuances’ I’ve ignored have a huge impact on model performance. Things to worry about, which I did not worry about, include weight initialisation, learning rate warmup, layer normalisation, batch size and padding, and so on.
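Learning rate warmup is usually paired with a decay schedule. Here is a minimal sketch of one common choice: linear warmup followed by cosine decay. The `learning_rate` and `max_iters` defaults match the model parameters listed further down. `warmup_iters` and `min_lr` are hypothetical values chosen for illustration, not anything that model used.

```python
import math

def get_lr(it, learning_rate=6e-4, max_iters=6000, warmup_iters=200, min_lr=6e-5):
    """Linear warmup to learning_rate, then cosine decay down to min_lr.

    warmup_iters and min_lr are illustrative assumptions, not tuned values.
    """
    if it < warmup_iters:
        # ramp up linearly so early, noisy gradients only take small steps
        return learning_rate * (it + 1) / warmup_iters
    if it >= max_iters:
        return min_lr
    progress = (it - warmup_iters) / (max_iters - warmup_iters)  # goes 0 -> 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))           # goes 1 -> 0
    return min_lr + coeff * (learning_rate - min_lr)

for it in (0, 100, 199, 1000, 3100, 5999):
    print(it, f"{get_lr(it):.2e}")
```

In a training loop you would set the optimiser's learning rate to `get_lr(it)` at the start of each iteration.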
The model itself
While this was never about the specifics, here’s how it went. More details on my GitHub.
Model parameters
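```python
block_size = 256       # context window: maximum number of tokens seen at once
n_layer = 2            # number of attention + MLP blocks
n_head = 4             # attention heads per block
n_embd = 128           # embedding dimension
dropout = 0.0          # no dropout
bias = False           # whether linear/normalisation layers use bias terms
learning_rate = 6e-4   # learning rate
max_iters = 6000       # total training iterations
weight_decay = 1e-1    # weight decay applied by the optimiser
```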
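For a sense of scale, here is a back-of-the-envelope parameter count for these settings. It is a sketch that assumes a GPT-2-style decoder: learned positional embeddings, weight-only LayerNorms (matching `bias = False`), an MLP hidden size of 4 × n_embd, and an output head tied to the token embedding. The vocabulary size isn't listed above, so `vocab_size=65` is a hypothetical placeholder for a small character-level vocabulary.

```python
def count_params(n_layer, n_embd, block_size, vocab_size):
    """Approximate parameter count for a GPT-2-style decoder without bias terms.

    Assumes learned positional embeddings, weight-only LayerNorms, a 4x MLP,
    and an output head tied to the token embedding (adding no extra parameters).
    """
    d = n_embd
    attention = 3 * d * d + d * d      # Q, K, V projections + output projection
    mlp = d * (4 * d) + (4 * d) * d    # expand to 4x, project back down
    layer_norms = 2 * d                # two LayerNorm weight vectors per block
    per_block = attention + mlp + layer_norms
    embeddings = vocab_size * d + block_size * d  # token table + positional table
    final_norm = d                     # LayerNorm before the output head
    return n_layer * per_block + embeddings + final_norm

# vocab_size=65 is a hypothetical placeholder; the rest are the parameters above
print(count_params(n_layer=2, n_embd=128, block_size=256, vocab_size=65))  # → 434944
```

Each block costs roughly 12 × n_embd² parameters, so model size grows quadratically with the embedding dimension. Plugging in GPT-3's published settings (96 layers, n_embd = 12,288, a 2,048-token context and a ~50k vocabulary) gives roughly 174.6 billion, in line with its 175B headline figure.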
The input dataset was ‘Harry Potter and the Methods of Rationality’ by Less Wrong.
In 2 hours of training on my 8GB MacBook Air M2, the loss plateaued at about 1.26 on the training data.
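To put a loss of 1.26 in perspective, assume it is the mean per-token cross-entropy in nats (the natural-log convention PyTorch's cross-entropy uses). Exponentiating it gives the perplexity, which you can read as the number of equally likely choices the model is still hedging between at each step. A model guessing uniformly over a vocabulary of V tokens has loss ln V. For a hypothetical 65-character vocabulary that is ln 65 ≈ 4.17.

$$
\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \ln p_\theta(x_i \mid x_{<i}), \qquad
\mathrm{PPL} = e^{\mathcal{L}} = e^{1.26} \approx 3.53
$$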
Here’s what the output model has to say when I prompt it with “This LM I trained from scratch thinks that my website is”.
“This LM I trained from scratch thinks that my website is facting into dange to spoke were that stop an excusion and possible where a spolited, and when something a coming laught away.”
Laught away indeed.
More details if you’re interested
Choice of tokeniser
There are pre-existing tokenisers, many of them open source. You could also create your own, which sometimes has practical advantages: in niche domains, you want your model to understand the jargon. For example, in chemistry you might want ‘nitrous’ to be a single token instead of being split into ‘nitr’ and ‘ous’.
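Here is a toy sketch of how a custom byte-pair-encoding (BPE) tokeniser picks up domain words. It starts from single characters and repeatedly merges the most frequent adjacent pair. This is a simplified, hypothetical illustration: it splits on whitespace, works on characters rather than bytes, has no special tokens, and is not how tiktoken is implemented. With enough ‘nitrous’ in the training corpus, the word ends up as a single token.

```python
from collections import Counter

def pair_counts(words):
    """Count adjacent symbol pairs across the corpus, weighted by word frequency."""
    counts = Counter()
    for symbols, freq in words.items():
        for pair in zip(symbols, symbols[1:]):
            counts[pair] += freq
    return counts

def apply_merge(symbols, pair):
    """Replace every adjacent occurrence of `pair` with one merged symbol."""
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            out.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return tuple(out)

def train_bpe(corpus, num_merges):
    """Learn up to num_merges merge rules, starting from single characters."""
    words = Counter(tuple(word) for word in corpus.split())
    merges = []
    for _ in range(num_merges):
        counts = pair_counts(words)
        if not counts:
            break
        # most frequent pair; ties broken by comparing the pairs themselves, for determinism
        best = max(counts, key=lambda p: (counts[p], p))
        merges.append(best)
        new_words = Counter()
        for symbols, freq in words.items():
            new_words[apply_merge(symbols, best)] += freq
        words = new_words
    return merges

def tokenize(word, merges):
    """Apply the learned merges, in the order they were learned, to a new word."""
    symbols = tuple(word)
    for pair in merges:
        symbols = apply_merge(symbols, pair)
    return list(symbols)

corpus = "nitrous nitrous nitrous nitrous nitrous oxide"  # toy 'chemistry' corpus
merges = train_bpe(corpus, num_merges=6)
print(tokenize("nitrous", merges))  # → ['nitrous']
print(tokenize("oxide", merges))    # → ['o', 'x', 'i', 'd', 'e']
```

Frequent domain words earn their own tokens, while rare ones stay split into smaller pieces. That is why a tokeniser trained on general web text may chop up jargon that a domain-specific one would keep whole.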
Heads of attention and parallelisation
The attention block I described above is a single head of attention. In every block, many such sets of Q, K and V matrices are trained in parallel, each forming an attention head of its own, hence the name ‘multi-headed attention’. The head dimension is usually the embedding dimension divided by the number of heads: with an embedding dimension of 128 and 4 heads, each head works with 32 dimensions.
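Here is a minimal plain-Python sketch of multi-headed self-attention. It projects the input to Q, K and V, slices each into `n_head` chunks of `n_embd // n_head` dimensions, runs scaled dot-product attention per head, and concatenates the results. To keep it short, it assumes no causal mask and omits the output projection that a GPT block applies after concatenation. The weights are random placeholders.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def matmul(rows, w):
    """Multiply a (T, d_in) list-of-rows matrix by a (d_in, d_out) weight matrix."""
    cols = list(zip(*w))
    return [[dot(r, c) for c in cols] for r in rows]

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q, k, v):
    """One head of scaled dot-product attention: softmax(q k^T / sqrt(head_dim)) v."""
    scale = math.sqrt(len(q[0]))
    out = []
    for qi in q:
        weights = softmax([dot(qi, kj) / scale for kj in k])  # how much token i looks at each token j
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out

def multi_head_attention(x, w_q, w_k, w_v, n_head):
    """Split Q, K, V into n_head slices, attend per head, concatenate (no mask, no W_O)."""
    n_embd = len(x[0])
    if n_embd % n_head:
        raise ValueError("n_embd must be divisible by n_head")
    head_dim = n_embd // n_head
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    heads = []
    for h in range(n_head):
        s = slice(h * head_dim, (h + 1) * head_dim)
        heads.append(attend([r[s] for r in q], [r[s] for r in k], [r[s] for r in v]))
    # concatenate the per-head outputs back into one n_embd-wide vector per token
    return [[val for head in heads for val in head[t]] for t in range(len(x))]

rng = random.Random(0)
T, n_embd, n_head = 3, 8, 2  # toy sizes; the model in this post uses n_embd=128, n_head=4

def rand_matrix(rows, cols):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

w_q, w_k, w_v = rand_matrix(n_embd, n_embd), rand_matrix(n_embd, n_embd), rand_matrix(n_embd, n_embd)
x = rand_matrix(T, n_embd)
out = multi_head_attention(x, w_q, w_k, w_v, n_head)
print(len(out), len(out[0]))  # → 3 8
```

Every head is independent of the others, and every step is a matrix product. That is exactly what lets real implementations batch all heads into a few large, parallel matrix multiplications.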
The popularity and success of attention is in large part due to the fact that this entire process consists of matrix multiplications, which are highly efficient and highly parallelisable, and can thus be easily scaled. And scaling laws, as we know, held for a long long time.
Encoders, Decoders, and Causal LMs
There’s one ‘small’ decision in the design of your attention head: whether to mask the attention scores (the product of Q and K) so that each token can only attend to itself and earlier tokens. If you mask them, this becomes a decoder block. Otherwise, it becomes an encoder block. This is about 2 lines of code but has significant implications.
A language model built from decoder blocks is called a causal LM.
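Here is a minimal sketch of what those couple of lines do, written in plain Python. Disallowed positions in the score matrix are set to -inf before the softmax, so they receive exactly zero weight. In PyTorch this is commonly written as `masked_fill` with `float('-inf')` on the score tensor. The helper names below are illustrative.

```python
import math

def causal_mask(T):
    """mask[i][j] is True when token i may attend to token j, i.e. j <= i (decoder)."""
    return [[j <= i for j in range(T)] for i in range(T)]

def full_mask(T):
    """No masking: every token sees every other token (encoder)."""
    return [[True] * T for _ in range(T)]

def masked_softmax(scores, mask):
    """Row-wise softmax over attention scores, with disallowed positions set to -inf first."""
    weights = []
    for row, allowed in zip(scores, mask):
        masked = [s if ok else float("-inf") for s, ok in zip(row, allowed)]
        m = max(masked)  # finite here, since each token can always attend to itself
        exps = [math.exp(s - m) for s in masked]  # exp(-inf) == 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

scores = [[0.0] * 4 for _ in range(4)]  # equal scores make the mask's effect easy to see
for row in masked_softmax(scores, causal_mask(4)):
    print([round(w, 2) for w in row])
```

Row i spreads its attention only over tokens 0 to i: the first row is [1.0, 0.0, 0.0, 0.0] and the last is uniform. With `full_mask(4)` every row would be uniform, so every token could "see the future", which is fine for an encoder but would let a next-token predictor cheat during training.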
Cross attention
In use cases like translation or text-to-speech, where you have to look at two streams of data in tandem, the attention block is built so that the enrichment happens across the two sets of vector embeddings: the queries come from one data stream, while the keys and values come from the other. These are typically encoder-decoder models, where the decoder attends to the encoder’s output.
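Here is a minimal plain-Python sketch of cross-attention. It assumes a single head and no output projection. `x_target` and `x_source` are hypothetical names for, say, the partially generated translation and the encoded source sentence. The only change from self-attention is where Q, K and V come from. The output has one row per target token, however long the source is.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def project(rows, w):
    """Multiply a (T, d_in) list-of-rows matrix by a (d_in, d_out) weight matrix."""
    cols = list(zip(*w))
    return [[dot(r, c) for c in cols] for r in rows]

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(x_target, x_source, w_q, w_k, w_v):
    """Queries from one stream, keys and values from the other; no causal mask on the source."""
    q = project(x_target, w_q)  # (T_target, d)
    k = project(x_source, w_k)  # (T_source, d)
    v = project(x_source, w_v)  # (T_source, d_v)
    scale = math.sqrt(len(q[0]))
    out = []
    for qi in q:
        weights = softmax([dot(qi, kj) / scale for kj in k])  # one weight per source token
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out  # one enriched vector per target token: (T_target, d_v)

identity = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
x_source = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]  # e.g. 4 source tokens
x_target = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # e.g. 2 target tokens generated so far
out = cross_attention(x_target, x_source, identity, identity, identity)
print(len(out), len(out[0]))  # → 2 3
```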