In reinforcement learning algorithms, we often need to compute the KL divergence between the new policy distribution and the old reference policy. This is commonly used in PPO, GRPO, and other policy-gradient algorithms. This blog post covers the approximations of KL divergence proposed by John Schulman in his blog post http://joschu.net/blog/kl-approx.html.
Problem
We want to compute KL divergence (which is a measure of distance between two distributions) between two LLM policies.
When the LLM has vocab size V and sequence length S, we have logits of shape [S, V] for each sequence in the batch.
For each token position, the KL divergence between the old policy q and the new policy p is KL(q||p) = sum_{i in V} q[i] * (log q[i] - log p[i]),
and the true KL divergence for the distribution is the mean over these per-token KL divergences.
Computing the above expression directly is very memory intensive. For example, if V is 32k, S is 10k, and the batch size is 32, the output logits alone would consume about 41 GB in float32 (32,000 × 10,000 × 32 × 4 bytes). And we would need to store the same-sized matrix for the reference policy.
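To make the scale concrete, here is a quick back-of-the-envelope calculation (a sketch assuming float32 logits; halve the figure for bf16/fp16):
"""
Rough memory footprint of the full logits tensor for one forward pass
"""
V, S, B = 32_000, 10_000, 32   # vocab size, sequence length, batch size
bytes_per_element = 4          # float32
logits_gb = V * S * B * bytes_per_element / 1e9
print(f"logits: ~{logits_gb:.0f} GB")  # ~41 GB, and the same again for the reference policy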
"""
# Vocabulary size (V)
# Sequence length (S)
"""
vocab_size = 1000
seq_length = 50
"""
-----------------------------
Step 1: Generate Original Logits
-----------------------------
Simulate logits for the original model output (for S tokens)
Shape: [seq_length, vocab_size]
"""
logits_q = torch.randn(seq_length, vocab_size)
q_prob = F.softmax(logits_q, dim=1)
"""
-----------------------------
Step 2: Simulate a Small Weight Update
-----------------------------
Create a small perturbation to simulate an update in the model weights.
Adding this delta to the original logits produces the new logits.
"""
delta = 0.1 * torch.randn(seq_length, vocab_size)
logits_p = logits_q + delta
p_prob = F.softmax(logits_p, dim=1)
"""
-----------------------------
Step 3: Compute the True KL Divergence
-----------------------------
For each token position, compute the KL divergence between q and p:
KL(q||p) = sum_{i in V} q[i] * (log q[i] - log p[i])
"""
kl_per_token = (q_prob * (torch.log(q_prob) - torch.log(p_prob))).sum(dim=1)
"""
Average over all token positions
"""
true_kl = kl_per_token.mean().item()
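As an aside (my own note, not part of the original walkthrough): when working from logits it is usually safer to stay in log-space with F.log_softmax rather than taking log(softmax(...)), since exact zeros in the softmax would make the log blow up. An equivalent, more stable version of Step 3:
"""
Numerically stable variant of Step 3 using log-probabilities directly
"""
log_q = F.log_softmax(logits_q, dim=1)
log_p = F.log_softmax(logits_p, dim=1)
kl_per_token_stable = (log_q.exp() * (log_q - log_p)).sum(dim=1)
assert torch.allclose(kl_per_token_stable, kl_per_token, atol=1e-5)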
Therefore, we need to find a way to approximate the above expression.
Let's start with the fact that we don't want to store the complete matrix and instead want to approximate the KL divergence from samples drawn from this distribution. In the case of LLM decoding, the samples can be thought of as the output tokens produced by greedy or some other decoding technique. In the example below we simulate tokens drawn from a multinomial distribution.
"""
-----------------------------
Step 4: Monte Carlo Sampling per Token
-----------------------------
For each token position (each row), sample one index from the vocabulary using q's distribution.
This simulates having a 1xS vector of sampled tokens.
"""
samples = torch.multinomial(q_prob, num_samples=1, replacement=True)
samples = samples.squeeze(1)
indices = torch.arange(seq_length)
q_samples = q_prob[indices, samples]
p_samples = p_prob[indices, samples]
Approximation 1
Using the Monte Carlo formulation KL(q||p) = E_{x~q}[log q(x) - log p(x)], the first naive approximation to the true KL divergence is k1 = -log r, where r = p(x)/q(x) and x is a single token sampled from q. This means the expectation is estimated with just one sample per token position.
"""
Compute the ratio r = p(x) / q(x) and its logarithm.
"""
r = p_samples / q_samples
log_r = torch.log(r)
k1 = -log_r
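As a quick sanity check (continuing with the variables defined above), we can average k1 over the sampled token positions and compare it with the true value. k1 is an unbiased estimator of the KL divergence, but it has high variance, and individual samples can even be negative although the KL divergence itself is always non-negative.
"""
Compare the single-sample estimate against the true KL divergence
"""
approx_kl_k1 = k1.mean().item()
print(f"true KL: {true_kl:.6f}, k1 estimate: {approx_kl_k1:.6f}")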