Implementing On-Policy Distillation: Lessons from Building OPD in VeRL

Introduction

The standard recipe for distilling a strong teacher into a smaller student is straightforward: generate reasoning traces from the teacher, fine-tune the student on those traces (SFT), then run RL. Because the teacher traces are off-policy, the student may forget some of its own knowledge while fitting the teacher’s outputs, and it faces a distribution shift at inference time: it can reach states that never appear in the teacher traces and fail to recover from them. Another side effect is that the student inherits the teacher’s verbosity patterns from the static data and must then unlearn them during RL.

On-Policy Distillation (OPD) [1] seems promising as it operates on student outputs, thus removing the train-test mismatch issue. A natural idea is to integrate OPD with regular RL training since it can reuse most of the RL infrastructure. Conceptually, OPD can be directly implemented by replacing the reference policy used to compute the KL divergence term with the teacher policy. The teacher then evaluates the student’s own rollouts and provides a KL divergence signal that pushes the student toward better reasoning—selectively, on prompts where the student’s pass rate is low. This idea has been explored concurrently in KDRL [2], which proposes a unified framework combining knowledge distillation with reinforcement learning but does not release an implementation. The idea is simple. The implementation is a bit more than that.

This post is the engineering story. We built OPD for the VeRL training framework, and document the architecture, the key implementation pitfalls, the design choice between implementing KL as a loss term vs. as advantage replacement, and the constraints we discovered along the way.


Architecture Overview

Why Not Just Swap the Reference Policy?

As mentioned in the introduction, the conceptual pitch for OPD is simple: standard RL training already computes KL divergence against a reference policy (a frozen snapshot of the initial model) to prevent the student from drifting too far. Just point that KL term at the teacher instead, and the student gets pulled toward the teacher’s distribution rather than its own starting point.

The primary obstacle is hardware. The reference policy is the same architecture as the student, so it can share the same GPUs (or be offloaded cheaply). A 72B teacher serving a 4B student cannot. The teacher must run on separate hardware, which means network communication, serialization, and a client-server protocol that the reference policy path was never designed for. This is the main architectural requirement that drives everything else.

The teacher may also be a closed-source model queried through an API, or a shared service used by multiple RL workloads.

Given these requirements, OPD needs to be a separate subsystem. The teacher runs on dedicated hardware behind a vLLM server, communicating via ZeroMQ:

Student Training                        Teacher Server 
┌──────────────────────────┐            ┌──────────────────────┐
│ 1. Generate rollouts     │            │  vLLM Engine         │
│ 2. Compute rewards       │            │  - Teacher model     │
│ 3. Compute pass rates    │            │  - Independent TP    │
│ 4. Create eligibility    │  ZeroMQ    │  - Logprob compute   │
│    masks                 │───────────►│                      │
│                          │◄───────────│                      │
│ 5. Fetch teacher logprobs│            └──────────────────────┘
│ 6. RL loss + KD loss     │
└──────────────────────────┘

The eligibility mask determines which rollouts receive teacher guidance. For each prompt, we compute the pass rate—the fraction of rollouts that receive a positive reward. Prompts with low pass rates are “hard”: the student struggles with them. Only failed rollouts for hard prompts are eligible for OPD, focusing the teacher signal where it is most needed.
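
The masking rule can be sketched as follows. This is a minimal illustration rather than the VeRL code: the function name build_eligibility_mask, the [num_prompts, rollouts_per_prompt] reward layout, and the 0.5 pass-rate threshold are all assumptions made for the example, and NumPy stands in for the tensor library.

```python
import numpy as np

def build_eligibility_mask(rewards, pass_rate_threshold=0.5):
    """Return (pass_rates, mask) for rewards shaped [num_prompts, rollouts_per_prompt].

    A rollout is eligible when its prompt is "hard" (pass rate below the
    threshold, a hypothetical hyperparameter) AND the rollout itself failed.
    """
    rewards = np.asarray(rewards, dtype=float)
    passed = rewards > 0                       # positive reward counts as a pass
    pass_rates = passed.mean(axis=1)           # one pass rate per prompt
    hard_prompt = pass_rates < pass_rate_threshold
    mask = hard_prompt[:, None] & ~passed      # failed rollouts of hard prompts
    return pass_rates, mask

rewards = [
    [1, 1, 1, 0],   # easy prompt: 75% pass rate, nothing eligible
    [0, 0, 1, 0],   # hard prompt: 25% pass rate, the three failures are eligible
]
pass_rates, mask = build_eligibility_mask(rewards)
```

Note that the single failed rollout of the easy prompt stays on the pure RL path; only failures on hard prompts are routed to the teacher.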


Pitfall 1: Legacy vs. New Worker—Know Your Code Path

The Problem

We implemented OPD, launched training, and saw encouraging metrics in W&B: opd/frac_opd_samples at ~5%, opd/frac_underperforming_prompts at ~5-10%. These are computed in the trainer before the update step. They confirmed the masking logic worked.

But four worker-level metrics were missing: opd/kl_loss, opd/kd_coef, opd/num_eligible_samples, opd/frac_tokens_with_kd. No errors. No crashes. Training continued normally as pure RL—no teacher guidance was applied at all.

The root cause: VeRL has two worker implementations with different loss computation paths.

| Implementation | Config Setting | Loss Location |
|---|---|---|
| Legacy worker | use_legacy_worker_impl = "auto" (default) | Hardcoded inline in dp_actor.py:update_policy() |
| New worker | use_legacy_worker_impl = "disable" | Pluggable via losses.py:ppo_loss() with set_loss_fn() |

We implemented OPD loss in losses.py—the new worker’s pluggable loss path. But the default configuration uses the legacy worker, where loss computation is hardcoded in update_policy(): policy gradient + entropy + KL penalty, with no OPD term. Our code was never called.

Why It Was Hard to Detect

Several factors conspired to make this silent:

  1. Partial metrics created false confidence. The trainer-level metrics (opd/frac_opd_samples) are computed before the worker step and appeared correctly, suggesting the pipeline was active.

  2. No crashes or errors. The legacy worker’s key selection (select_keys) didn’t include OPD keys, so teacher logprobs were silently dropped from the batch. The loss function in losses.py was never called, so its conditions were never evaluated.

  3. Training progressed normally. Without OPD loss, training just ran pure RL. The model still learned, accuracy improved, everything looked reasonable.

The Fix

Implement OPD loss directly in the legacy worker’s update_policy() in dp_actor.py, after the existing KL penalty section.

Also add OPD keys to select_keys so the data survives batch serialization:

# In dp_actor.py: Include OPD keys in worker data
if "teacher_log_probs" in data.batch.keys():
    select_keys.extend([
        "teacher_log_probs", "opd_eligibility_mask",
        "opd_horizon_mask", "prompts"
    ])
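
One way to turn this kind of silent degradation into a loud failure is to check that the worker-level metrics actually arrive after the first update step. The helper below is a hypothetical sketch: the function names and the plain-dict metrics format are assumptions, while the metric names are the four that went missing in our run.

```python
REQUIRED_WORKER_METRICS = (
    "opd/kl_loss",
    "opd/kd_coef",
    "opd/num_eligible_samples",
    "opd/frac_tokens_with_kd",
)

def missing_opd_metrics(metrics):
    """Return the worker-level OPD metric names absent from a metrics dict."""
    return [name for name in REQUIRED_WORKER_METRICS if name not in metrics]

def assert_opd_active(metrics):
    """Raise if the worker reported no OPD metrics, i.e. the OPD loss never ran."""
    missing = missing_opd_metrics(metrics)
    if missing:
        raise RuntimeError(f"OPD enabled but worker metrics are missing: {missing}")

# Trainer-level metrics alone are not evidence that the worker applied OPD.
trainer_only = {
    "opd/frac_opd_samples": 0.05,
    "opd/frac_underperforming_prompts": 0.08,
}
missing = missing_opd_metrics(trainer_only)
```

Calling assert_opd_active on the merged metrics of the first training step would have surfaced the wrong code path immediately instead of after inspecting dashboards.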

Pitfall 2: Scalar Logprobs—Full Distributions Are Not Needed

The Math Dictates the Implementation

OPD adds a KL divergence term between student and teacher distributions. Conceptually, computing full KL divergence requires the complete vocabulary distribution from both models:

KL(p || q) = sum_x p(x) * log(p(x) / q(x))

For Qwen models with a vocabulary of ~152k tokens, this means transferring a 152k-dimensional vector per token position from the teacher server. At thousands of positions per sequence and hundreds of sequences per batch, this is impractical.
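
A back-of-the-envelope calculation makes the gap concrete. The vocabulary size is the exact Qwen figure quoted in this post and the response length of 8192 matches the shapes we discuss below; the batch size of 256 and 4-byte floats are assumptions for illustration only.

```python
VOCAB_SIZE = 151_669        # Qwen vocabulary size quoted in this post
RESPONSE_LEN = 8192         # response positions per sequence
BATCH_SIZE = 256            # assumed batch size
BYTES_PER_FLOAT = 4         # assumed fp32 logprobs

positions = BATCH_SIZE * RESPONSE_LEN
full_dist_bytes = positions * VOCAB_SIZE * BYTES_PER_FLOAT   # full distribution per position
scalar_bytes = positions * BYTES_PER_FLOAT                   # one logprob per position

print(f"full distributions: {full_dist_bytes / 1e12:.2f} TB per batch")
print(f"scalar logprobs:    {scalar_bytes / 1e6:.2f} MB per batch")
```

Under these assumptions the full-distribution payload is on the order of a terabyte per batch, while the scalar payload is a few megabytes.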

Both KL estimators we use sidestep this entirely. They approximate KL divergence using only the scalar logprobs of the actual generated tokens x_t:

K1:  KL ≈ log p_student(x_t) - log p_teacher(x_t)
K2:  KL ≈ 0.5 * (log p_student(x_t) - log p_teacher(x_t))^2

Whether you choose K1 or K2 (discussed later in the loss vs. advantage section), both require exactly one scalar per position from the teacher—not a 152k-dimensional distribution.
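
To see why one scalar per position is enough, the toy check below computes the expectation of each estimator under the student distribution for a three-token vocabulary. The two distributions are made up for illustration. The expectation of K1 equals the exact KL(student || teacher), while K2 is a non-negative approximation that stays close when the two distributions are close.

```python
import math

student = [0.7, 0.2, 0.1]   # toy p_student over a 3-token vocabulary
teacher = [0.6, 0.3, 0.1]   # toy p_teacher over the same vocabulary

def k1(log_p_s, log_p_t):
    """K1 estimator for a single sampled token."""
    return log_p_s - log_p_t

def k2(log_p_s, log_p_t):
    """K2 estimator for a single sampled token."""
    return 0.5 * (log_p_s - log_p_t) ** 2

# Exact KL over the full vocabulary.
exact_kl = sum(p * math.log(p / q) for p, q in zip(student, teacher))

# Expected value of each estimator when tokens are sampled from the student.
expected_k1 = sum(p * k1(math.log(p), math.log(q)) for p, q in zip(student, teacher))
expected_k2 = sum(p * k2(math.log(p), math.log(q)) for p, q in zip(student, teacher))

print(f"exact KL={exact_kl:.4f}  E[K1]={expected_k1:.4f}  E[K2]={expected_k2:.4f}")
```

In training, the expectation is replaced by the student's actual samples, which is why only the logprob of each generated token is needed from the teacher.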

The Mistake

Our initial teacher server code missed this and requested full vocabulary logprobs from vLLM:

# BROKEN: Requesting full vocab
sampling_params = SamplingParams(
    prompt_logprobs=vocab_size,  # 151,669 tokens!
)

This hit vLLM’s limit: "Requested prompt logprobs of 151669, which is greater than max allowed: 20".

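The fix requests only the top-1 logprob per position and then reads the logprob of the actual token at that position. The lookup relies on vLLM including the actual prompt token's entry in each position's dictionary, even when that token is not the top-1 candidate:

from vllm import SamplingParams

# FIXED: Request only scalar logprobs
sampling_params = SamplingParams(
    prompt_logprobs=1,  # top-1 per position; the actual token's entry is read below
)

# prompt_logprobs is a per-position list of {token_id: Logprob} dicts;
# position 0 has no entry because nothing precedes the first token.
token_logprobs = []
for pos in range(1, len(token_ids)):
    actual_token_id = token_ids[pos]
    logprob = prompt_logprobs[pos][actual_token_id].logprob
    token_logprobs.append(logprob)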

A note on shape alignment. Once teacher logprobs reached the worker, we hit a shape mismatch: the student worker returns response-only logprobs [batch, 8192], but the teacher server returns full-sequence logprobs such as [batch, 9216] (prompt + response). Teacher logprobs need to be sliced to the response portion (teacher_log_probs[:, prompt_len:]) before computing KL.
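
A minimal sketch of that slicing step, with an explicit shape check so a mismatch fails loudly. The helper name is hypothetical, NumPy stands in for the tensor library, and the example assumes every sequence in the batch has the same prompt length (for instance because prompts are padded to a fixed length).

```python
import numpy as np

def slice_teacher_to_response(teacher_log_probs, prompt_len, response_len):
    """Keep only the response columns of full-sequence teacher logprobs."""
    response = teacher_log_probs[:, prompt_len:]
    if response.shape[1] != response_len:
        raise ValueError(
            f"expected {response_len} response positions, got {response.shape[1]}"
        )
    return response

batch, prompt_len, response_len = 2, 1024, 8192
teacher_full = np.zeros((batch, prompt_len + response_len))   # [batch, 9216]
student_resp = np.zeros((batch, response_len))                # [batch, 8192]

teacher_resp = slice_teacher_to_response(teacher_full, prompt_len, response_len)
```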

Takeaway

Always review the mathematical requirements of your objective before implementing the data pipeline. Both K1 and K2 only need scalar logprobs—but the server code was written as if full KL were needed. Understanding the estimator formulas before writing the server would have avoided both the efficiency problem and the vLLM limit entirely.

Design Choice: KL as Loss vs. KL as Advantage

Once the infrastructure works—teacher server responds, logprobs reach the worker, shapes align—you still face a design question: how should the teacher’s KL signal enter the training objective?

We implemented two approaches. Both are viable, but they use different KL estimators and have different balancing challenges.

Approach 1: KL as a Separate Loss Term

Add a KD loss term alongside the policy gradient loss:

L = L_PG + beta * L_KD

where L_KD = mean(KL(student || teacher) * mask).

This approach should use the K2 estimator:

KL_K2 = 0.5 * (log p_student - log p_teacher)^2

K2 is appropriate here because it functions as a loss: always non-negative, with gradients that vanish when the student matches the teacher. The squared term gives it the right properties for minimization.

The balancing problem. K2 values are large when the student-teacher gap is significant. With a logprob difference of 2–3 nats (common early in training), K2 produces values of 2–4.5 per token (0.5 * 2^2 = 2 and 0.5 * 3^2 = 4.5). Meanwhile, the policy gradient loss is typically around 0.01. With a coefficient of beta = 0.1, the KD term contributes ~0.2–0.45, which is 20–45 times the PG loss.
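
The scale mismatch is easy to reproduce. The sketch below implements a masked K2 loss (the function name kd_loss_k2 is hypothetical, and NumPy stands in for the tensor library) and evaluates it at the 2 and 3 nat gaps, using the beta and PG loss values quoted in the text.

```python
import numpy as np

def kd_loss_k2(student_log_probs, teacher_log_probs, mask):
    """Masked mean of the K2 estimator over eligible response tokens."""
    k2 = 0.5 * (student_log_probs - teacher_log_probs) ** 2
    return (k2 * mask).sum() / max(mask.sum(), 1)

mask = np.ones((1, 4))
teacher = np.full((1, 4), -1.0)
beta, pg_loss = 0.1, 0.01   # values quoted in the text

results = {}
for gap in (2.0, 3.0):
    student = teacher - gap          # student logprob sits `gap` nats below the teacher
    results[gap] = kd_loss_k2(student, teacher, mask)
    weighted = beta * results[gap]
    print(f"gap={gap}: K2={results[gap]:.2f}, beta*K2={weighted:.3f}, "
          f"ratio to PG loss={weighted / pg_loss:.0f}x")
```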

This makes coefficient tuning fragile. Too high and teacher guidance overwhelms the RL signal. Too low and it has no effect. The right value depends on training stage, problem difficulty, and student-teacher gap—all of which change over time.

We also found that ratio clipping (clamping the teacher/student probability ratio, similar to PPO’s clip) helps stabilize this approach, but adds another hyperparameter.

Approach 2: KL as Advantage Replacement

Instead of adding a separate loss, replace the RL advantages for hard prompts with teacher-derived advantages:

# For hard prompts: replace RL advantages with teacher signal
# Negate K1 so that tokens the teacher favors get positive advantage
opd_advantages = -(log p_student - log p_teacher) * horizon_mask

advantages = where(
    eligible,          # hard prompt, failed rollout
    opd_advantages,    # teacher guidance
    rl_advantages      # standard RL
)

The horizon_mask limits OPD to the first K tokens of each response, since early tokens have more influence on the reasoning trajectory.
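
The pseudocode above can be made concrete as follows. This is a sketch under assumptions, not the trainer code: the function name replace_advantages and the horizon_k parameter are hypothetical, eligibility is taken per rollout, and NumPy's where plays the role of torch.where.

```python
import numpy as np

def replace_advantages(student_lp, teacher_lp, rl_adv, eligible, horizon_k):
    """Swap in teacher-derived advantages on eligible rollouts.

    student_lp, teacher_lp, rl_adv: [batch, response_len]
    eligible: [batch] bool, True for failed rollouts of hard prompts
    horizon_k: number of leading response tokens that receive the teacher signal
    """
    response_len = student_lp.shape[1]
    horizon_mask = (np.arange(response_len) < horizon_k)[None, :]   # [1, response_len]
    opd_adv = (teacher_lp - student_lp) * horizon_mask              # negated K1, zero past the horizon
    return np.where(eligible[:, None], opd_adv, rl_adv)

student_lp = np.array([[-1.0, -2.0, -0.5], [-1.0, -1.0, -1.0]])
teacher_lp = np.array([[-0.5, -3.0, -0.1], [-2.0, -2.0, -2.0]])
rl_adv = np.array([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])
eligible = np.array([True, False])

advantages = replace_advantages(student_lp, teacher_lp, rl_adv, eligible, horizon_k=2)
```

In this example the first rollout gets a positive advantage on its first token (the teacher assigns it higher probability), a negative one on the second, and zero beyond the horizon; the second rollout keeps its RL advantages unchanged.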

This approach should use the K1 estimator:

KL_K1 = log p_student - log p_teacher

K1 is the right choice here because the KL signal is being used as a reward, not a loss. The advantage is the negation of K1: -(log p_student - log p_teacher) = log p_teacher - log p_student. This is directly interpretable as a per-token reward: positive when the teacher assigns higher probability to the generated token (the teacher “approves”), negative when it assigns lower probability. Squaring this (K2) would destroy the sign, losing the directional information that makes advantages useful.

No coefficient to balance. Since teacher advantages replace RL advantages (rather than being added to a different loss), the policy gradient loss operates on a single unified set of advantages. There’s no beta to tune between competing loss terms.

Why Both Require Careful Normalization

In standard GRPO, advantages are normalized within each group of rollouts sampled for the same prompt (zero mean, unit variance). This normalization is critical for stable policy gradient updates.

When OPD advantages enter the picture, normalization must be handled carefully regardless of approach.

In the loss approach, the KD loss has a fundamentally different scale from the PG loss. The coefficient beta attempts to bridge this, but it’s a static scalar applied to a dynamic quantity.

In the advantage approach, the issue is subtler. OPD advantages (raw KL differences) and RL advantages (normalized outcome rewards) have different distributions. If GRPO normalizes the combined set of advantages—some from RL, some from teacher KL—the two populations contaminate each other’s statistics. A batch dominated by hard prompts (many OPD advantages) shifts the normalization in ways that distort the RL signal for easy prompts, and vice versa.

Normalize OPD advantages separately. Compute mean and variance over OPD-eligible tokens only, normalize those, then combine with separately-normalized RL advantages. This ensures each population is properly scaled before they enter the same policy gradient computation.
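
A minimal sketch of the separate normalization, assuming the RL advantages have already been normalized upstream and that a boolean token mask marks where the OPD signal applies. The function names and the eps constant are hypothetical, and NumPy stands in for the tensor library.

```python
import numpy as np

def normalize_masked(values, mask, eps=1e-8):
    """Standardize `values` using statistics from positions where mask is True."""
    out = np.asarray(values, dtype=float).copy()
    selected = out[mask]
    if selected.size == 0:
        return out                    # nothing eligible in this batch
    out[mask] = (selected - selected.mean()) / (selected.std() + eps)
    return out

def combine_advantages(opd_adv, rl_adv_normalized, opd_token_mask):
    """Normalize OPD advantages on their own tokens, then merge with RL advantages."""
    opd_norm = normalize_masked(opd_adv, opd_token_mask)
    return np.where(opd_token_mask, opd_norm, rl_adv_normalized)

opd_adv = np.array([[2.0, 4.0, 6.0], [0.0, 0.0, 0.0]])
rl_adv = np.array([[0.0, 0.0, 0.0], [1.0, -1.0, 0.5]])
opd_mask = np.array([[True, True, True], [False, False, False]])

combined = combine_advantages(opd_adv, rl_adv, opd_mask)
```

Because the OPD statistics are computed only over OPD tokens, a batch with many hard prompts does not shift the scale of the RL advantages, and the RL values pass through untouched.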

Summary

| | KL as Loss | KL as Advantage |
|---|---|---|
| KL estimator | K2: 0.5 * (log p_s - log p_t)^2 | K1: log p_s - log p_t |
| Why this estimator | Non-negative, proper loss for minimization | Preserves sign, interpretable as per-token reward |
| Balancing | Coefficient beta between PG and KD losses | No coefficient; separate normalization instead |
| Integration point | Worker loss function (dp_actor.py) | Trainer, after advantage computation (ray_trainer.py) |
| Hard/easy separation | Masking within loss | torch.where on advantages |

Constraint: Shared Tokenizer and Chat Template

The Requirement

Our current OPD implementation requires that the student and teacher share the same tokenizer. Both KL estimators compute a difference of logprobs at each position t for the generated token x_t. This only makes sense if both models agree on what token ID x_t represents.

But sharing a tokenizer is necessary, not sufficient. The student and teacher must also share the same chat template.

Why Chat Template Matters

Even when two models use the same tokenizer (same vocabulary, same BPE merges), they may use different chat templates. Consider Qwen3-4B-Base and Qwen3-32B-Instruct:

  • Same tokenizer: Both use the Qwen3 tokenizer with identical vocabulary
  • Different templates: The instruct model wraps generations in <|im_start|>, <|im_end|>, and may use <think> tags; the base model produces raw text

When computing KL divergence, the teacher assigns high probability to template tokens like <|im_start|> at positions where the student has never learned to produce them. The token IDs match (same tokenizer), but the probability distributions are fundamentally misaligned. KL divergence explodes, gradients become unstable, and training diverges.

Practical Solutions

We identified two approaches:

Option 1: Use a format-compatible teacher. Train the teacher with RL directly from its base checkpoint (without chat-template SFT), so that both student and teacher produce raw text. This works but requires training an additional large model, which defeats the purpose of efficient training.

Option 2: Pre-align the student via Rejection Fine-Tuning (RFT). Before OPD training, generate rollouts from the base student, filter to correct solutions, and fine-tune the student on these correct outputs formatted with the teacher’s template. The student learns to produce <think> tags and chat formatting, making its output structure compatible with the teacher’s distribution. This is the approach we use—we discuss the RFT procedure and its interaction with RL training in an upcoming post.

Current Limitation

Supporting different tokenizers between student and teacher would require token-level alignment (mapping between vocabularies), which we have not implemented. For now, OPD is restricted to student-teacher pairs that share a tokenizer and—after any necessary pre-alignment—produce structurally compatible outputs.

Relation to VeRL’s GKD Recipe

VeRL ships a GKD (Generalized Knowledge Distillation) recipe that performs on-policy distillation with a similar teacher server architecture (vLLM + ZeroMQ). Our implementation diverges in three ways.

Synchronous vs. off-policy scheduling. Our OPD runs synchronously: the teacher scores the current step’s rollouts before the actor update begins. GKD overlaps computation phases using async schedulers (one-step-off, two-step-off): the actor update for step N runs concurrently with rollout generation for step N+1, and teacher inference is pipelined across steps. This delivers significant throughput gains but means the rollout policy is one or two steps behind the current parameters—slightly off-policy.

Top-k distributions vs. scalar logprobs. GKD fetches top-k token distributions from the teacher (typically k=10), storing both logprobs and token indices per position. This enables computing KL divergence over the teacher’s top-k probability mass instead of a single token. However, the KDRL paper reported this setup to be unstable for training, so we did not implement it.

KL computed in Megatron engine vs. in update_policy(). GKD computes KL divergence inside Megatron’s forward pass via a custom logits_processor and a hand-written TP-aware autograd function that handles all-reduce across tensor-parallel ranks. This is efficient but ties the implementation to the Megatron backend. Our OPD operates on logprobs after the forward pass, in the actor’s update_policy() (or in the trainer for the advantage approach). This works with VeRL’s FSDP backend and doesn’t require custom autograd, at the cost of not leveraging Megatron’s pipeline overlap or native TP-aware KL computation.

Conclusion

The OPD implementation touched 10 files across a distributed system spanning trainer, workers, and an external teacher server. The dominant failure mode was silent degradation: every issue resulted in a condition quietly evaluating to false, with no errors and partial metrics suggesting everything was working.

Four lessons stand out:

  1. Trace the code path from config to gradient. In a framework with multiple worker implementations, the clean pluggable interface may not be the one that’s active.

  2. Let the math dictate the data pipeline. Both K1 and K2 estimators only need scalar logprobs, not full vocabulary distributions. Understanding this upfront avoids both efficiency problems and API limits.

  3. Match the KL estimator to the integration point. K2 (squared) is a proper loss for minimization. K1 (linear) preserves sign information needed for advantages. Using the wrong estimator for the wrong approach either destroys useful signal or creates unstable optimization.

  4. Format compatibility is a hard constraint. Shared tokenizer is necessary but not sufficient—student and teacher must produce structurally compatible outputs for KL divergence to be meaningful.

Two broader challenges remain. First, balancing the teacher signal with RL rewards is difficult regardless of design approach. Whether it’s a loss coefficient or advantage normalization, getting this wrong means either the teacher overwhelms RL or has no effect at all.

Second, OPD fundamentally assumes that the teacher and student policies are similar enough for the KL signal to be useful. Format compatibility (shared tokenizer and chat template) is the most visible aspect of this, but the issue runs deeper. Since we evaluate the teacher’s logprob on the student’s sampled tokens, tokens that the student generates but the teacher would rarely produce yield extreme logprob differences and noisy KL estimates. The more the student’s distribution diverges from the teacher’s—whether due to different training stages, different capabilities, or different reasoning styles—the noisier the OPD signal becomes. This is a fundamental limitation of on-policy distillation with scalar KL estimators: the signal quality depends on the overlap between the two policies.

References

[1] Agarwal et al. (2024). On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes. ICLR 2024.

[2] Xu et al. (2025). KDRL: Post-Training Reasoning LLMs via Unified Knowledge Distillation and Reinforcement Learning. https://arxiv.org/abs/2506.08946

[3] VeRL: Volcano Engine Reinforcement Learning for LLMs. https://github.com/volcengine/verl