Mathematical Foundations of Large Language Models: Training Objectives, Token Probability, and Loss Functions



Raj Shaikh

1. Training Objectives: The Rules of the Language Game 🎯

Training objectives define what an LLM is asked to predict during training, and therefore what it learns about text. Think of them as the rules in a game of charades, except that instead of acting, the model predicts words. There are two main flavors:

1.1. Masked Language Models (MLMs): The Fill-in-the-Blanks Expert 📝

Masked Language Models, like BERT, are trained to predict missing words (or tokens) in a sentence. It’s like a game of Mad Libs, where the model sees:

  • “I ___ AI.”
  • And guesses: “love” (hopefully).

Objective for MLMs: Maximize the probability of the correct token (\( w_i \)) at each masked position, which is the same as minimizing the negative log-likelihood:

\[ \mathcal{L}_{\text{MLM}} = - \sum_{i \in \text{masked}} \log P(w_i | w_{\text{context}}) \]

Where:

  • \( w_{\text{context}} \): Unmasked tokens in the input.
  • \( P(w_i | w_{\text{context}}) \): Model’s predicted probability for the missing word.

Example: Input: “The cat sat on the ___.”
Masked token: “mat”
The model predicts probabilities for possible tokens:

  • “mat”: 0.8 (good job!)
  • “dog”: 0.1
  • “carpet”: 0.1

The loss penalizes the model for assigning lower probabilities to the correct word.
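A minimal plain-Python sketch of this loss, assuming the model's predicted distribution at each masked position is already available as a dictionary (the helper `mlm_loss` and the `predictions` layout are hypothetical, chosen for illustration). Only masked positions contribute, mirroring the sum over \( i \in \text{masked} \):

```python
import math

def mlm_loss(tokens, masked_positions, predictions):
    """Sum of -log P(w_i | context) over the masked positions only.

    tokens: the original (unmasked) tokens, used as ground truth.
    masked_positions: indices that were hidden from the model.
    predictions: hypothetical dict mapping a masked index to {token: probability}.
    """
    loss = 0.0
    for i in masked_positions:
        true_token = tokens[i]
        loss -= math.log(predictions[i][true_token])
    return loss

tokens = ["The", "cat", "sat", "on", "the", "mat"]
masked_positions = [5]  # "mat" was replaced by the mask token
predictions = {5: {"mat": 0.8, "dog": 0.1, "carpet": 0.1}}

print(round(mlm_loss(tokens, masked_positions, predictions), 3))  # 0.223
```

Had the model put only 0.1 on “mat”, the loss would have been \( -\log(0.1) \approx 2.303 \), roughly ten times larger, which is how the loss punishes low probability on the correct word.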


1.2. Causal Language Models (CLMs): The Word Fortunetellers 🔮

Causal Language Models, like GPT, are trained to predict the next word in a sequence. The twist? They only see the words before the target (causality matters, folks).


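Objective for CLMs: Maximize the probability of each token given the tokens before it, which is the same as minimizing:

\[ \mathcal{L}_{\text{CLM}} = - \sum_{t=1}^{T} \log P(w_t | w_{<t}) \]

Where:

  • \( w_{<t} \): All tokens before position \( t \).
  • \( T \): Total length of the sequence.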
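The \( w_{<t} \) condition means every target is scored against its own prefix only. A minimal sketch, assuming a hypothetical `prob_fn(prefix, token)` that returns the model's probability of `token` after `prefix`, shows how one sequence turns into \( T \) prediction problems whose negative log-probabilities are summed. The first token is conditioned on an empty prefix, matching the sum starting at \( t = 1 \):

```python
import math

def clm_loss(tokens, prob_fn):
    """Sum over t of -log P(w_t | w_<t) for a single sequence."""
    loss = 0.0
    for t in range(len(tokens)):
        prefix = tokens[:t]   # w_<t: only the tokens strictly before position t
        target = tokens[t]    # w_t: the token the model must predict
        loss -= math.log(prob_fn(prefix, target))
    return loss

# Hypothetical toy "model": uniform over a 4-token vocabulary, ignoring the prefix.
def uniform_prob(prefix, token):
    return 0.25

tokens = ["AI", "is", "amazing"]
for t in range(len(tokens)):
    print(tokens[:t], "->", tokens[t])  # the (context, target) pairs the loss scores
print(round(clm_loss(tokens, uniform_prob), 3))  # 3 * log(4) ≈ 4.159
```

A real Transformer does not loop like this: it produces all \( T \) predictions in one forward pass and uses a causal attention mask to keep each position from seeing later tokens.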

Example: Input: “Once upon a time, there was a ___.”
The model predicts:

  • “princess”: 0.9 (nailed it!)
  • “dragon”: 0.05
  • “robot”: 0.05

The training process tweaks the model to assign higher probabilities to the right answer.


Key Differences Between MLMs and CLMs

| Feature | MLMs | CLMs |
| --- | --- | --- |
| Direction | Bidirectional (sees both sides) | Unidirectional (causal) |
| Goal | Predict masked tokens | Predict next token |
| Example Models | BERT, RoBERTa | GPT, GPT-3 |

Numerical Example: Causal Language Model

Imagine training a causal language model on the sentence: “AI is amazing.”

Step-by-Step:

  1. Tokenize the Input:
    • Tokens: [“AI”, “is”, “amazing”].
  2. Predictions:
    • \( P(\text{is} | \text{AI}) = 0.9 \)
    • \( P(\text{amazing} | \text{AI is}) = 0.8 \)
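  3. Loss: Use the negative log probabilities. The first token, “AI”, has no preceding context in this example, so only “is” and “amazing” are scored:
    \[ \mathcal{L}_{\text{CLM}} = - \left[ \log(0.9) + \log(0.8) \right] \]
    Approximate:
    \[ \mathcal{L}_{\text{CLM}} = - (-0.105 - 0.223) = 0.328 \]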

The smaller the loss, the better the model is at predicting!
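The arithmetic above can be checked in a few lines of plain Python. The probabilities are the illustrative values from the example, not outputs of a real model:

```python
import math

# P(is | AI) and P(amazing | AI is) from the example.
step_probs = [0.9, 0.8]

# Sequence loss = sum of the negative log-probabilities of each target token.
loss = -sum(math.log(p) for p in step_probs)
print(f"{loss:.4f}")  # 0.3285
```

Summing the rounded terms (0.105 + 0.223) gives the 0.328 shown above; the unrounded value is about 0.3285.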


Fun Analogy

Training objectives are like exam prep:

  • Masked Language Models (MLMs): A fill-in-the-blank test. The model studies the whole sentence, then guesses the missing piece.
  • Causal Language Models (CLMs): A guessing game where the model predicts what comes next, like completing a story.

Mermaid.js Diagram: MLM vs. CLM Workflow

```mermaid
graph TD
    Input[Input Sentence] --> MLM[Masked Language Model]
    Input --> CLM[Causal Language Model]
    MLM --> MaskedTokens[Predict Missing Tokens]
    CLM --> NextToken[Predict Next Token]
    MaskedTokens --> MLM_Loss[Calculate MLM Loss]
    NextToken --> CLM_Loss[Calculate CLM Loss]
```

2. Token Probability Computation: Math Behind the Magic 🧮✨

How Token Probabilities Are Computed

The journey to token probabilities involves these steps:

  1. Input Embedding: Convert tokens into vectors.
  2. Contextual Understanding: Apply self-attention to capture relationships between tokens.
  3. Logits Calculation: Use a linear layer to get a raw score (logit) for every token in the vocabulary.
  4. Softmax: Transform logits into probabilities.

Let’s break it down! 🛠️


1. Input Embedding: Words Become Numbers

Tokens are converted into dense vectors using an embedding matrix \( E \):

\[ \mathbf{h}_0 = E^\top \mathbf{x} \]

Where:

  • \( \mathbf{x} \): One-hot vector of length \( V \) for the token (e.g., “AI”).
  • \( E \): Embedding matrix of size \( V \times d \) (\( V \): vocab size, \( d \): embedding dimension), so \( \mathbf{h}_0 \) is the row of \( E \) that belongs to the token.
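Multiplying by a one-hot vector (written with the transpose, \( E^\top \mathbf{x} \), so the shapes line up) simply selects one row of \( E \), which is why frameworks implement embeddings as a table lookup (PyTorch's `nn.Embedding`, for instance). A plain-Python sketch with a made-up \( 3 \times 2 \) matrix, whose row for “is” matches the vector used in the numerical example later in this section:

```python
# Hypothetical 3-token vocabulary and a made-up V x d = 3 x 2 embedding matrix.
vocab = ["AI", "is", "amazing"]
E = [
    [0.1, 0.3],   # row for "AI"
    [0.5, 0.2],   # row for "is"
    [0.9, -0.4],  # row for "amazing"
]

def one_hot(index, size):
    return [1.0 if i == index else 0.0 for i in range(size)]

def embed_via_matmul(x, E):
    """Compute E^T x, a d-dimensional vector."""
    d = len(E[0])
    return [sum(x[v] * E[v][j] for v in range(len(E))) for j in range(d)]

idx = vocab.index("is")
h0_matmul = embed_via_matmul(one_hot(idx, len(vocab)), E)
h0_lookup = E[idx]  # direct row lookup, what embedding layers do in practice
print(h0_matmul, h0_lookup)  # [0.5, 0.2] [0.5, 0.2]
```

The lookup avoids building a length-\( V \) one-hot vector, which matters when \( V \) is tens of thousands of tokens.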

2. Contextual Understanding: Attention to Details

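The self-attention mechanism learns relationships between tokens. Stacking the token vectors into a matrix \( H \), it computes queries \( Q = H W_Q \), keys \( K = H W_K \), and values \( V = H W_V \) with learned projection matrices, and then:

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \]

Where:

  • \( Q \): Query matrix.
  • \( K \): Key matrix.
  • \( V \): Value matrix (not to be confused with the vocabulary size \( V \) above).
  • \( d_k \): Dimension of each key vector; dividing by \( \sqrt{d_k} \) keeps the dot products from growing too large as the dimension increases.

The softmax term produces attention weights that decide how much each token draws on every other token. For example:

  • In “AI is amazing,” the representation of “is” may attend strongly to “AI.”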
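The formula can be sketched in plain Python. The hidden states and the projection matrices `W_Q`, `W_K`, `W_V` below are made-up numbers (a trained model learns them), and the optional `causal` flag is an illustrative addition showing the mask a causal language model applies so that each position only attends to itself and earlier positions:

```python
import math

def matmul(A, B):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V, causal=False):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(K[0])
    K_T = [list(col) for col in zip(*K)]
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, K_T)]
    if causal:
        # CLM-style mask: position i may only attend to positions j <= i.
        scores = [[s if j <= i else float("-inf") for j, s in enumerate(row)]
                  for i, row in enumerate(scores)]
    weights = [softmax(row) for row in scores]  # each row sums to 1
    return matmul(weights, V), weights

# Made-up hidden states for "AI", "is", "amazing" (one row per token, d = 2).
H = [[0.1, 0.3], [0.5, 0.2], [0.9, -0.4]]
# Hypothetical projection matrices.
W_Q = [[1.0, 0.5], [0.0, 1.0]]
W_K = [[0.5, 0.0], [1.0, 1.0]]
W_V = [[1.0, 0.0], [0.0, 1.0]]

Q, K, V = matmul(H, W_Q), matmul(H, W_K), matmul(H, W_V)
context, weights = attention(Q, K, V, causal=True)
for token, row in zip(["AI", "is", "amazing"], weights):
    print(token, [round(w, 3) for w in row])
```

With the mask on, “AI” can only attend to itself, so its first weight row is exactly `[1.0, 0.0, 0.0]`; without it, every token attends to every other token, as in an MLM.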

3. Logits Calculation: Predict Raw Scores

Once contextualized, the hidden state \( \mathbf{h}_T \) (for the last token) is passed through a linear layer:

\[ \text{logits} = W \cdot \mathbf{h}_T + b \]

Where:

  • \( W \): Weight matrix of size \( V \times d \).
  • \( b \): Bias vector of length \( V \).
  • Logits are raw, unnormalized scores, one for each token in the vocabulary.
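A plain-Python sketch of this step. The weights, bias, and hidden state are made-up numbers for a hypothetical three-token vocabulary, chosen so the result matches the logits used in the numerical example below:

```python
def linear(W, h, b):
    """logits = W h + b, with W of shape (V, d), h of length d, b of length V."""
    return [sum(w_ij * h_j for w_ij, h_j in zip(row, h)) + b_i for row, b_i in zip(W, b)]

# Hypothetical weights for the vocabulary ("amazing", "boring", "fun") with d = 2.
W = [
    [1.0, 2.0],  # row that scores "amazing"
    [0.5, 2.0],  # row that scores "boring"
    [0.0, 1.0],  # row that scores "fun"
]
b = [0.0, 0.0, 0.0]
h_T = [1.0, 0.5]  # made-up hidden state of the last position

logits = linear(W, h_T, b)
print(logits)  # [2.0, 1.5, 0.5]
```

Each row of \( W \) acts as a learned template for one vocabulary token, and the logit measures how well the hidden state matches it.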

4. Softmax: Turning Logits into Probabilities

Softmax normalizes logits into probabilities:

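\[ P(w_t = i | w_{<t}) = \frac{\exp(\text{logits}_i)}{\sum_{j=1}^{V} \exp(\text{logits}_j)} \]

Here:

  • \( P(w_t = i | w_{<t}) \): Probability that the next token is vocabulary entry \( i \), given the preceding tokens.
  • The denominator sums over all vocab tokens to ensure probabilities add up to 1.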

Numerical Example: Token Probability Computation

Suppose we’re predicting the next token in “AI is ___.”

  1. Embedding:

    • Token “is” → \(\mathbf{h}_0 = [0.5, 0.2]\).
  2. Attention Scores:

    • Attention focuses on “AI” and “is.”
  3. Logits:

    • Raw scores from the final layer: \( \text{logits} = [2.0, 1.5, 0.5] \) (for tokens: “amazing,” “boring,” “fun”).
  4. Softmax:

    \[ P(\text{amazing}) = \frac{\exp(2.0)}{\exp(2.0) + \exp(1.5) + \exp(0.5)} \]

    Approximate:

    \[ P(\text{amazing}) = \frac{7.39}{7.39 + 4.48 + 1.65} \approx 0.55 \]

“AI is amazing” gets the highest probability! 🎉
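The softmax step of the example can be reproduced in plain Python. This sketch subtracts the largest logit before exponentiating, a common trick that leaves the probabilities unchanged (the shared factor cancels in the fraction) while keeping `math.exp` from overflowing on large logits:

```python
import math

def softmax(logits):
    """Numerically stable softmax: shifting by the max logit does not change the result."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

vocab = ["amazing", "boring", "fun"]
probs = softmax([2.0, 1.5, 0.5])
for token, p in zip(vocab, probs):
    print(f"{token}: {p:.3f}")
# amazing: 0.547
# boring: 0.331
# fun: 0.122
```

Without the shift, `math.exp(1000.0)` raises `OverflowError`, while `softmax([1000.0, 999.0])` works fine.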


Why Token Probability Matters

  1. Next-Word Prediction:
    • LLMs generate text by choosing each next token from these probabilities, either by sampling from them or by picking the most likely token.
  2. Training Signal:
    • Probabilities are compared to the ground truth during training.
  3. Interpretability:
    • Helps explain why the model prefers certain words.
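The two ways of turning probabilities into text can be sketched with Python's standard `random` module. The three probabilities are the rounded values from the numerical example; a real decoder works with the model's full-vocabulary distribution:

```python
import random

vocab = ["amazing", "boring", "fun"]
probs = [0.547, 0.331, 0.122]  # rounded softmax output from the example

# Greedy decoding: always pick the most probable token.
greedy = vocab[max(range(len(vocab)), key=lambda i: probs[i])]

# Sampling: draw tokens with probability proportional to their weights.
rng = random.Random(0)  # fixed seed so the run is reproducible
samples = rng.choices(vocab, weights=probs, k=1000)

print("greedy:", greedy)
print({tok: samples.count(tok) for tok in vocab})  # counts roughly track the probabilities
```

Greedy decoding always returns “amazing”, while sampling occasionally produces “boring” or “fun”, which is where the variety in generated text comes from.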

Code Example: Token Probability Computation

Here’s how to compute token probabilities with PyTorch:

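```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # make the randomly initialized weights reproducible

# Example vocab size and embedding dim
vocab_size = 3
embedding_dim = 2

# Embedding layer: a learnable (vocab_size x embedding_dim) lookup table
embedding = nn.Embedding(vocab_size, embedding_dim)
tokens = torch.tensor([0, 1])  # Example token IDs, shape (seq_len,)

# Forward pass: one vector per token, shape (seq_len, embedding_dim)
embedded = embedding(tokens)

# Attention (simplified: no learned Q/K/V projections; no causal mask is needed
# here because only the last position is used, and it may see every earlier token)
Q = K = V = embedded
attention_scores = torch.softmax(torch.matmul(Q, K.T) / (embedding_dim ** 0.5), dim=-1)
context = torch.matmul(attention_scores, V)  # shape (seq_len, embedding_dim)

# Logits and softmax
linear = nn.Linear(embedding_dim, vocab_size)  # maps a hidden state to vocab-size logits
logits = linear(context[-1])  # Take the last token's context
probs = torch.softmax(logits, dim=-1)  # sums to 1 over the vocabulary

print("Logits:", logits)
print("Probabilities:", probs)
```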

Fun Analogy

Token probability computation is like ordering pizza 🍕:

  1. Embedding: You describe what you want (ingredients).
  2. Attention: The chef remembers your preferences (“extra cheese!”).
  3. Logits: Raw scores for available pizzas (“Margherita: 8, Pepperoni: 5”).
  4. Softmax: The probabilities decide your final choice (“Margherita it is!”).

Mermaid.js Diagram: Token Probability Flow

```mermaid
graph TD
    Token[Input Token] --> Embedding[Embedding Layer: Convert to Vectors]
    Embedding --> Attention[Self-Attention: Contextualize Tokens]
    Attention --> Logits[Linear Layer: Compute Logits]
    Logits --> Softmax[Apply Softmax]
    Softmax --> Probability[Token Probabilities]
```

3. Loss Functions in LLMs: The AI Report Card 📚✨

What Are Loss Functions?

A loss function measures the difference between the model’s predictions and the ground truth. Smaller loss = better model.

In LLMs, the loss is computed at every token prediction step. For example:

  • Target: “AI is amazing.”
  • Prediction: “AI is boring.”
  • Loss: “Nope, let’s fix that!” 🚫

Types of Loss Functions in LLMs

1. Cross-Entropy Loss: The Gold Standard 🌟

The most common loss function for LLMs is Cross-Entropy Loss, which measures how far the predicted probability distribution is from the true one.


The Math Behind Cross-Entropy

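At a single position \( t \), the cross-entropy between the true distribution \( P_{\text{true}} \) and the model's prediction is \( -\sum_{w} P_{\text{true}}(w) \log P(w | w_{<t}) \). With a one-hot target that puts all its mass on the true token \( w_t \), only one term survives, so the loss for that token is:

\[ \mathcal{L}_{\text{token}} = -\log P(w_t | w_{<t}) \]

For an entire sequence:

\[ \mathcal{L}_{\text{sequence}} = -\sum_{t=1}^T \log P(w_t | w_{<t}) \]

Where:

  • \( P(w_t | w_{<t}) \): The model's predicted probability for the true token \( w_t \), given the preceding tokens.
  • \( T \): Length of the sequence.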
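A plain-Python sketch of the general cross-entropy sum, using a hypothetical prediction over a toy three-token vocabulary, shows that a one-hot target collapses it to \( -\log P(w_t | w_{<t}) \):

```python
import math

def cross_entropy(target_dist, predicted_dist):
    """H(target, predicted) = -sum_w target(w) * log predicted(w).

    Terms with zero target weight contribute nothing (0 * log p is taken as 0),
    so they are skipped, which also avoids log(0).
    """
    return -sum(y * math.log(p) for y, p in zip(target_dist, predicted_dist) if y > 0)

vocab = ["AI", "is", "amazing"]
predicted = [0.05, 0.9, 0.05]  # hypothetical P(. | "AI") over this toy vocabulary
one_hot = [0.0, 1.0, 0.0]      # the true next token is "is"

print(cross_entropy(one_hot, predicted))  # same value as the line below, about 0.105
print(-math.log(predicted[1]))
```

Keeping the general form is useful because label smoothing replaces the one-hot target with a softer distribution, and the same function still applies.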

Example: For the sentence “AI is amazing,” let’s say the model predicts:

  • \( P(\text{is} | \text{AI}) = 0.9 \)
  • \( P(\text{amazing} | \text{AI is}) = 0.8 \)

The loss:

\[ \mathcal{L}_{\text{sequence}} = - \left( \log(0.9) + \log(0.8) \right) \]

Approximate:

\[ \mathcal{L}_{\text{sequence}} = -(-0.105 - 0.223) = 0.328 \]

2. Label Smoothing: Making It Less Harsh 🤗

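To avoid overconfident predictions, we can use label smoothing, which replaces the one-hot target with a slightly softened distribution:

\[ P_{\text{true}}(w) = (1 - \epsilon) \cdot \mathbb{1}[w = w_t] + \frac{\epsilon}{V} \]

Where:

  • \( \epsilon \): Smoothing factor (a small value such as 0.1).
  • \( V \): Vocabulary size.
  • \( \mathbb{1}[w = w_t] \): 1 for the correct token, 0 otherwise.

Because the target no longer demands probability 1 on a single token, the model is encouraged to keep a little probability on other tokens, which can reduce overfitting.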
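A plain-Python sketch of the smoothed target and its effect on the loss. The helper names and the two candidate predictions are hypothetical; PyTorch's `nn.CrossEntropyLoss` exposes the same idea through its `label_smoothing` argument:

```python
import math

def smoothed_targets(correct_index, vocab_size, epsilon):
    """(1 - eps) * one_hot + eps / V, following the label-smoothing formula."""
    return [
        (1 - epsilon) * (1.0 if i == correct_index else 0.0) + epsilon / vocab_size
        for i in range(vocab_size)
    ]

def cross_entropy(target_dist, predicted_dist):
    """-sum_w target(w) * log predicted(w), skipping zero-weight targets."""
    return -sum(y * math.log(p) for y, p in zip(target_dist, predicted_dist) if y > 0)

targets = smoothed_targets(correct_index=0, vocab_size=3, epsilon=0.1)
print([round(t, 3) for t in targets])  # [0.933, 0.033, 0.033]

confident = [0.999, 0.0005, 0.0005]  # nearly all mass on the correct token
moderate = [0.9, 0.05, 0.05]         # still correct, but less extreme
# Under smoothed targets, the over-confident prediction gets the higher loss.
print(cross_entropy(targets, confident), cross_entropy(targets, moderate))
```

With \( \epsilon = 0 \), `smoothed_targets` returns the ordinary one-hot vector and the loss reduces to plain cross-entropy.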


3. Negative Log-Likelihood (NLL): A Close Cousin

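With one-hot targets, Cross-Entropy Loss is exactly the Negative Log-Likelihood (NLL) of the correct tokens under a categorical distribution, so the formula is the same:

\[ \mathcal{L}_{\text{NLL}} = - \sum_{t=1}^T \log P(w_t | w_{<t}) \]

The practical difference is the input each function expects. In PyTorch, for example, nn.NLLLoss takes log-probabilities (such as the output of log_softmax), while nn.CrossEntropyLoss takes raw logits and applies log_softmax internally.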


Why Loss Functions Matter

  1. Feedback Mechanism:
    • Loss functions guide the model to improve over time.
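  2. Training Stability:
    • A well-chosen loss, such as cross-entropy on top of softmax, keeps gradients informative even when the model is badly wrong, which helps training stay stable.
  3. Performance Metric:
    • Lower loss, especially on held-out text, generally means better predictions.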

Numerical Example: Loss Calculation

Let’s compute the loss for a simple sentence.

Input Sentence:

“AI rocks.”

Token Probabilities:

  • \( P(\text{rocks} | \text{AI}) = 0.7 \)
  • \( P(\langle\text{eos}\rangle | \text{AI rocks}) = 0.6 \), where \( \langle\text{eos}\rangle \) is the end-of-sequence token.

Cross-Entropy Loss:

\[ \mathcal{L} = -\left[\log(0.7) + \log(0.6)\right] \]

Approximate:

\[ \mathcal{L} = -(-0.357 - 0.511) = 0.868 \]

Code Example: Cross-Entropy Loss in PyTorch

Here’s how to compute Cross-Entropy Loss:

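```python
import torch
import torch.nn as nn

# Example token predictions (logits) and true labels
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.2]])  # Shape: (batch_size=2, vocab_size=3)
true_labels = torch.tensor([0, 1])  # Correct token ID for each example, shape (batch_size,)
# For sequences, flatten logits to (batch * seq_len, vocab_size) and labels to (batch * seq_len,)

# Cross-Entropy Loss: applies log_softmax to the logits internally,
# then averages the per-example negative log-likelihoods (reduction="mean")
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, true_labels)

print("Cross-Entropy Loss:", loss.item())
```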
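To see what nn.CrossEntropyLoss computes for these logits, here is a plain-Python re-implementation, assuming the default mean reduction and no class weights: it applies log-softmax to each row, takes the negative log-probability of the true label (the NLL step), and averages over the batch. Its result should agree with `loss.item()` from the PyTorch snippet up to floating-point rounding:

```python
import math

def log_softmax(logits):
    """log p_i = z_i - log(sum_j exp(z_j)), computed stably via the max logit."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]

def nll(log_probs, label):
    """Negative log-likelihood of the true label, given log-probabilities."""
    return -log_probs[label]

logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.2]]
true_labels = [0, 1]

# Cross-entropy = NLL applied to log_softmax(logits), averaged over the batch.
per_example = [nll(log_softmax(z), y) for z, y in zip(logits, true_labels)]
loss = sum(per_example) / len(per_example)
print(f"{loss:.4f}")  # 0.3143
```

This split is also why nn.NLLLoss must be fed log-probabilities: the log-softmax step has to happen somewhere, either inside nn.CrossEntropyLoss or before calling nn.NLLLoss.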

Fun Analogy

Loss functions are like grades in school 📚:

  • If you ace the test (high probability for the correct word), your loss is low.
  • If you flunk (assign low probability to the correct word), your loss skyrockets. Over time, the model learns to guess less and study harder! 🎓

Mermaid.js Diagram: Loss Computation in LLMs

```mermaid
graph TD
    Input[Input Tokens] --> Embedding[Embedding Layer]
    Embedding --> Attention[Self-Attention: Contextualize Tokens]
    Attention --> Logits[Linear Layer: Compute Logits]
    Logits --> Softmax[Softmax: Convert Logits to Probabilities]
    Softmax --> LossFunction[Cross-Entropy Loss]
    LossFunction --> Backprop[Backpropagation to Update Model]
```