Summary
Tuned speculative decoding with Llama 3 8B as the draft model and Llama 3 70B as the target to reduce decode latency. Measured performance across batch sizes, speculative token counts, and tasks.
What I Did
- Set up speculative decoding with Llama 3 8B as the draft model and Llama 3 70B as the target.
- Tuned the number of speculative tokens from 3 to 10 and measured acceptance rate and latency.
- Compared speculative decoding against standard autoregressive decoding at batch sizes 1, 4, and 8.
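The core draft-then-verify step can be sketched as follows. This is a toy illustration of the greedy acceptance rule, not the production implementation inside TensorRT-LLM; the function name and list-of-ints interface are assumptions for clarity.

```python
def accept_draft(draft_tokens, target_tokens):
    """Greedy acceptance sketch (assumed logic, not the TensorRT-LLM code).

    draft_tokens:  k tokens proposed by the draft model.
    target_tokens: k+1 tokens the target model would emit at the same
                   positions (the extra one enables a 'bonus' token).
    Returns the tokens actually committed this step: the longest matching
    prefix, plus one target token (either the correction at the first
    mismatch, or the bonus token when all drafts are accepted).
    """
    accepted = []
    for d, t in zip(draft_tokens, target_tokens):
        if d == t:
            accepted.append(d)          # draft token accepted
        else:
            accepted.append(t)          # target's correction, then stop
            return accepted
    accepted.append(target_tokens[len(draft_tokens)])  # bonus token
    return accepted
```

Even on a full rejection this commits one (target-quality) token per step, which is why output quality is unchanged and only latency varies with the acceptance rate.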
Commands Used
`trtllm-build --checkpoint_dir ./llama3-70b --spec_decode_mode draft_target --draft_model_dir ./llama3-8b --num_draft_tokens 5`
`python spec_decode_bench.py --target llama3-70b --draft llama3-8b --num_spec_tokens 5 --batch_size 1 --output_len 256`
`python acceptance_rate_eval.py --dataset humaneval --spec_tokens 5`
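The full sweep over speculative token counts and batch sizes was just repeated invocations of the bench command above. A sketch that assembles those command lines (assuming `spec_decode_bench.py` takes the same flags shown above):

```python
from itertools import product

def sweep_commands(spec_tokens=range(3, 11), batch_sizes=(1, 4, 8)):
    """Build the bench command for every (num_spec_tokens, batch_size) pair.

    Flags mirror the single invocation above; the ranges match the sweep
    described in "What I Did" (spec tokens 3-10, batch sizes 1/4/8).
    """
    cmds = []
    for k, bs in product(spec_tokens, batch_sizes):
        cmds.append(
            "python spec_decode_bench.py --target llama3-70b --draft llama3-8b "
            f"--num_spec_tokens {k} --batch_size {bs} --output_len 256"
        )
    return cmds
```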
Lessons Learned
- Speculative decoding is a latency optimization, not a throughput optimization.
- Draft and target models should share vocabulary and have similar output distributions.
- The sweet spot for num_spec_tokens is typically 4 to 6, but it can vary by task type.
- Only use speculative decoding for interactive, latency-sensitive applications.
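A back-of-envelope model makes the sweet-spot lesson concrete. The constants here are assumptions, not my measurements: an i.i.d. per-token acceptance rate of 0.8, a draft/target cost ratio of 8/70 from parameter counts, and one target verification pass per step.

```python
def expected_tokens(a, k):
    """Expected tokens committed per step with per-token acceptance rate a
    and k draft tokens: 1 + a + a^2 + ... + a^k (geometric series),
    counting the bonus/correction token."""
    return (1 - a ** (k + 1)) / (1 - a)

def speedup(a, k, r):
    """Latency speedup vs. plain autoregressive decoding.
    Step cost: k draft passes (each r target-passes worth) + 1 target pass,
    normalized to a baseline of 1 target pass per token."""
    return expected_tokens(a, k) / (k * r + 1)

# Assumed numbers: a=0.8 acceptance, r=8/70 cost ratio for 8B draft / 70B target.
best_k = max(range(3, 11), key=lambda k: speedup(a=0.8, k=k, r=8 / 70))
```

Under these assumed numbers the optimum lands in the 4-6 range, consistent with the tuning result; a lower acceptance rate or a relatively costlier draft model pushes the optimum down.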