Mathematical Foundations of Large Language Models: Training Objectives, Token Probability, and Loss Functions
1. Training Objectives: The Rules of the Language Game 🎯
Training objectives are the guiding principles that help LLMs learn how to handle text. Think of them as the rules in a game of charades—but instead of acting, the model predicts words. There are two main flavors:
1.1. Masked Language Models (MLMs): The Fill-in-the-Blanks Expert 📝
Masked Language Models, like BERT, are trained to predict missing words (or tokens) in a sentence. It’s like a game of Mad Libs, where the model sees:
- “I ___ AI.”
- And guesses: “love” (hopefully).
Objective for MLMs: Maximize the probability of the correct token (\( w_i \)) at a masked position:
\[ \mathcal{L}_{\text{MLM}} = - \sum_{i \in \text{masked}} \log P(w_i | w_{\text{context}}) \]
Where:
- \( w_{\text{context}} \): Unmasked tokens in the input.
- \( P(w_i | w_{\text{context}}) \): Model’s predicted probability for the missing word.
Example:
Input: “The cat sat on the ___.”
Masked token: “mat”
The model predicts probabilities for possible tokens:
- “mat”: 0.8 (good job!)
- “dog”: 0.1
- “carpet”: 0.1
The loss penalizes the model for assigning lower probabilities to the correct word.
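As a quick sanity check (using the made-up probabilities above, not real model outputs), the loss at a single masked position is just the negative log of the probability assigned to the correct token:
import math
# Hypothetical probabilities for the masked position in "The cat sat on the ___."
predicted = {"mat": 0.8, "dog": 0.1, "carpet": 0.1}
# MLM loss at one masked position: -log P(correct token | context)
loss = -math.log(predicted["mat"])
print(f"Loss at this masked position: {loss:.3f}")  # ≈ 0.223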
1.2. Causal Language Models (CLMs): The Word Fortunetellers 🔮
Causal Language Models, like GPT, are trained to predict the next word in a sequence. The twist? They only see the words before the target (causality matters, folks).
Objective for CLMs: Maximize the probability of the next token:
\[ \mathcal{L}_{\text{CLM}} = - \sum_{t=1}^{T} \log P(w_t | w_{<t}) \]
Where:
- \( w_{<t} \): All tokens that come before position \( t \).
- \( T \): Total length of the sequence.
Example:
Input: “Once upon a time, there was a ___.”
The model predicts:
- “princess”: 0.9 (nailed it!)
- “dragon”: 0.05
- “robot”: 0.05
The training process tweaks the model to assign higher probabilities to the right answer.
Key Differences Between MLMs and CLMs
| Feature | MLMs | CLMs |
|---|---|---|
| Direction | Bidirectional (sees both sides) | Unidirectional (causal) |
| Goal | Predict masked tokens | Predict next token |
| Example Models | BERT, RoBERTa | GPT, GPT-3 |
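One way to picture the direction difference is through the attention mask each objective typically uses. The sketch below is an illustration under that assumption, not code from BERT or GPT: a full mask for the bidirectional case and a lower-triangular mask for the causal case.
import torch
seq_len = 4  # e.g., four tokens in a short sentence
# Bidirectional (MLM-style): every token may attend to every other token.
bidirectional_mask = torch.ones(seq_len, seq_len).bool()
# Causal (CLM-style): token t may only attend to positions <= t.
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
print(causal_mask)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])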
Numerical Example: Causal Language Model
Imagine training a causal language model on the sentence: “AI is amazing.”
Step-by-Step:
- Tokenize the Input:
- Tokens: [“AI”, “is”, “amazing”].
- Predictions:
- \( P(\text{is} | \text{AI}) = 0.9 \)
- \( P(\text{amazing} | \text{AI is}) = 0.8 \)
- Loss: Use the negative log probabilities: \[ \mathcal{L}_{\text{CLM}} = - \left[ \log(0.9) + \log(0.8) \right] \] Approximately: \[ \mathcal{L}_{\text{CLM}} = -(-0.105 - 0.223) = 0.328 \]
The smaller the loss, the better the model is at predicting!
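You can verify this arithmetic in a couple of lines (the probabilities are the illustrative values above, not outputs of a trained model):
import math
p_is = 0.9       # P("is" | "AI")
p_amazing = 0.8  # P("amazing" | "AI is")
loss = -(math.log(p_is) + math.log(p_amazing))
print(f"CLM loss: {loss:.4f}")  # ≈ 0.3285, matching the approximation above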
Fun Analogy
Training objectives are like exam prep:
- Masked Language Models (MLMs): A fill-in-the-blank test. The model studies the whole sentence, then guesses the missing piece.
- Causal Language Models (CLMs): A guessing game where the model predicts what comes next, like completing a story.
Mermaid.js Diagram: MLM vs. CLM Workflow
graph TD
    Input[Input Sentence] --> MLM[Masked Language Model]
    Input --> CLM[Causal Language Model]
    MLM --> MaskedTokens[Predict Missing Tokens]
    CLM --> NextToken[Predict Next Token]
    MaskedTokens --> MLM_Loss[Calculate MLM Loss]
    NextToken --> CLM_Loss[Calculate CLM Loss]
2. Token Probability Computation: Math Behind the Magic 🧮✨
How Token Probabilities Are Computed
The journey to token probabilities involves these steps:
- Input Embedding: Convert tokens into vectors.
- Contextual Understanding: Apply self-attention to capture relationships between tokens.
- Logits Calculation: Use linear layers to get raw scores (logits) for each token.
- Softmax: Transform logits into probabilities.
Let’s break it down! 🛠️
1. Input Embedding: Words Become Numbers
Tokens are converted into dense vectors using an embedding matrix \( E \):
\[ \mathbf{h}_0 = E^\top \mathbf{x} \]
Where:
- \( \mathbf{x} \): One-hot encoded token (e.g., “AI”).
- \( E \): Embedding matrix of size \( V \times d \) (\( V \): vocab size, \( d \): embedding dimension).
2. Contextual Understanding: Attention to Details
The self-attention mechanism learns relationships between tokens. For each token:
\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \]
Where:
- \( Q \): Query matrix.
- \( K \): Key matrix.
- \( V \): Value matrix.
- \( d_k \): Dimension of the key vectors, used to scale the dot products.
The attention mechanism ensures the model knows which tokens to focus on. For example:
- In “AI is amazing,” “is” relates most to “AI.”
3. Logits Calculation: Predict Raw Scores
Once contextualized, the hidden state \( \mathbf{h}_T \) (for the last token) is passed through a linear layer:
\[ \text{logits} = W \cdot \mathbf{h}_T + b \]
Where:
- \( W \): Weight matrix.
- \( b \): Bias vector.
- Logits are raw, unnormalized scores for all tokens in the vocabulary.
4. Softmax: Turning Logits into Probabilities
Softmax normalizes logits into probabilities:
\[ P(w_t | w_{<t}) = \frac{\exp(\text{logit}_{w_t})}{\sum_{w'=1}^{V} \exp(\text{logit}_{w'})} \]
Where:
- \( P(w_t | w_{<t}) \): Probability of token \( w_t \) given the preceding context.
- The denominator sums over all vocab tokens to ensure probabilities add up to 1.
Numerical Example: Token Probability Computation
Suppose we’re predicting the next token in “AI is ___.”
- Embedding: Token “is” → \( \mathbf{h}_0 = [0.5, 0.2] \).
- Attention Scores: Attention focuses on “AI” and “is.”
- Logits: Raw scores from the final layer: \( \text{logits} = [2.0, 1.5, 0.5] \) (for the tokens “amazing,” “boring,” “fun”).
- Softmax:
\[ P(\text{amazing}) = \frac{\exp(2.0)}{\exp(2.0) + \exp(1.5) + \exp(0.5)} \]
Approximately:
\[ P(\text{amazing}) = \frac{7.39}{7.39 + 4.48 + 1.65} \approx 0.55 \]
“AI is amazing” gets the highest probability! 🎉
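To double-check these numbers (the logits are the made-up values from the example):
import torch
# Illustrative logits for the tokens "amazing", "boring", "fun"
logits = torch.tensor([2.0, 1.5, 0.5])
probs = torch.softmax(logits, dim=-1)
print(probs)  # ≈ [0.55, 0.33, 0.12]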
Why Token Probability Matters
- Next-Word Prediction:
- LLMs generate text by sampling from these probabilities.
- Training Signal:
- Probabilities are compared to the ground truth during training.
- Interpretability:
- Helps explain why the model prefers certain words.
Code Example: Token Probability Computation
Here’s how to compute token probabilities with PyTorch:
import torch
import torch.nn as nn
# Example vocab size and embedding dim
vocab_size = 3
embedding_dim = 2
# Embedding layer
embedding = nn.Embedding(vocab_size, embedding_dim)
tokens = torch.tensor([0, 1]) # Example tokens
# Forward pass
embedded = embedding(tokens)
# Attention (simplified)
Q = K = V = embedded
attention_scores = torch.softmax(torch.matmul(Q, K.T) / (embedding_dim ** 0.5), dim=-1)
context = torch.matmul(attention_scores, V)
# Logits and softmax
linear = nn.Linear(embedding_dim, vocab_size)
logits = linear(context[-1]) # Take the last token's context
probs = torch.softmax(logits, dim=-1)
print("Logits:", logits)
print("Probabilities:", probs)
Fun Analogy
Token probability computation is like ordering pizza 🍕:
- Embedding: You describe what you want (ingredients).
- Attention: The chef remembers your preferences (“extra cheese!”).
- Logits: Raw scores for available pizzas (“Margherita: 8, Pepperoni: 5”).
- Softmax: The probabilities decide your final choice (“Margherita it is!”).
Mermaid.js Diagram: Token Probability Flow
graph TD
    Token[Input Token] --> Embedding[Embedding Layer: Convert to Vectors]
    Embedding --> Attention[Self-Attention: Contextualize Tokens]
    Attention --> Logits[Linear Layer: Compute Logits]
    Logits --> Softmax[Apply Softmax]
    Softmax --> Probability[Token Probabilities]
3. Loss Functions in LLMs: The AI Report Card 📚✨
What Are Loss Functions?
A loss function measures the difference between the model’s predictions and the ground truth. Smaller loss = better model.
In LLMs, the loss is computed at every token prediction step. For example:
- Input: “AI is amazing.”
- Prediction: “AI is boring.”
- Loss: “Nope, let’s fix that!” 🚫
Types of Loss Functions in LLMs
1. Cross-Entropy Loss: The Gold Standard 🌟
The most common loss function for LLMs is Cross-Entropy Loss, which measures how far the predicted probability distribution is from the true one.
The Math Behind Cross-Entropy
For a single token \( w_t \), the cross-entropy loss is:
\[ \mathcal{L}_{\text{token}} = -\log P(w_t | w_{<t}) \]
Summed over the whole sequence:
\[ \mathcal{L}_{\text{sequence}} = - \sum_{t=1}^{T} \log P(w_t | w_{<t}) \]
Where:
- \( P(w_t | w_{<t}) \): Predicted probability of the correct token, given the preceding tokens.
- \( T \): Length of the sequence.
Example: For the sentence “AI is amazing,” let’s say the model predicts:
- \( P(\text{is} | \text{AI}) = 0.9 \)
- \( P(\text{amazing} | \text{AI is}) = 0.8 \)
The loss:
\[ \mathcal{L}_{\text{sequence}} = - \left( \log(0.9) + \log(0.8) \right) \]
Approximately:
\[ \mathcal{L}_{\text{sequence}} = -(-0.105 - 0.223) = 0.328 \]
2. Label Smoothing: Making It Less Harsh 🤗
To avoid overconfidence in predictions, we can use label smoothing, which adjusts the true distribution slightly:
\[ P_{\text{true}}(w) = (1 - \epsilon) \cdot 1_{\text{correct}} + \frac{\epsilon}{V} \]
Where:
- \( \epsilon \): Smoothing factor.
- \( V \): Vocabulary size.
This encourages the model to spread probabilities slightly, reducing overfitting.
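In PyTorch, label smoothing is built into the cross-entropy loss via the label_smoothing argument. Here is a minimal sketch with made-up logits and \( \epsilon = 0.1 \):
import torch
import torch.nn as nn
logits = torch.tensor([[2.0, 1.0, 0.1]])  # one prediction over a 3-token vocab
target = torch.tensor([0])                # index of the correct token
plain_loss = nn.CrossEntropyLoss()
smoothed_loss = nn.CrossEntropyLoss(label_smoothing=0.1)  # epsilon = 0.1
print("Without smoothing:", plain_loss(logits, target).item())
print("With smoothing:   ", smoothed_loss(logits, target).item())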
3. Negative Log-Likelihood (NLL): A Close Cousin
Cross-Entropy Loss is essentially Negative Log-Likelihood (NLL) for categorical distributions. It works the same way:
\[ \mathcal{L}_{\text{NLL}} = - \sum_{t=1}^{T} \log P(w_t | w_{<t}) \]
Why Loss Functions Matter
- Feedback Mechanism:
- Loss functions guide the model to improve over time.
- Training Stability:
- Well-designed loss functions prevent issues like vanishing gradients.
- Performance Metric:
- Lower loss correlates with better predictions.
Numerical Example: Loss Calculation
Let’s compute the loss for a simple sentence.
Input Sentence:
“AI rocks.”
Token Probabilities:
- \( P(\text{rocks} | \text{AI}) = 0.7 \)
- \( P(\langle \text{EOS} \rangle | \text{AI rocks}) = 0.6 \) (the end-of-sequence token)
Cross-Entropy Loss:
\[ \mathcal{L} = -\left[\log(0.7) + \log(0.6)\right] \]
Approximately:
\[ \mathcal{L} = -(-0.357 - 0.511) = 0.868 \]
Code Example: Cross-Entropy Loss in PyTorch
Here’s how to compute Cross-Entropy Loss:
import torch
import torch.nn as nn
# Example token predictions (logits) and true labels
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.2]]) # Shape: (batch_size=2, vocab_size=3)
true_labels = torch.tensor([0, 1]) # Correct tokens for each example
# Cross-Entropy Loss
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, true_labels)
print("Cross-Entropy Loss:", loss.item())
Fun Analogy
Loss functions are like grades in school 📚:
- If you ace the test (high probability for the correct word), your loss is low.
- If you flunk (assign high probability to the wrong word), your loss skyrockets. Over time, the model learns to cheat less and study harder! 🎓
Mermaid.js Diagram: Loss Computation in LLMs
graph TD
    Input[Input Tokens] --> Embedding[Embedding Layer]
    Embedding --> Attention[Self-Attention: Contextualize Tokens]
    Attention --> Logits[Linear Layer: Compute Logits]
    Logits --> Softmax[Softmax: Convert Logits to Probabilities]
    Softmax --> LossFunction[Cross-Entropy Loss]
    LossFunction --> Backprop[Backpropagation to Update Model]