# DeepEP
DeepEP is a communication library tailored for Mixture-of-Experts (MoE) and expert parallelism (EP). It provides high-throughput and low-latency all-to-all GPU kernels, which are also known as MoE dispatch and combine. The library also supports low-precision operations, including FP8.
To align with the group-limited gating algorithm proposed in the DeepSeek-V3 paper, DeepEP offers a set of kernels optimized for asymmetric-domain bandwidth forwarding, such as forwarding data from the NVLink domain to the RDMA domain. These kernels deliver high throughput, making them suitable for both training and inference prefilling tasks. They also support controlling the number of SMs (Streaming Multiprocessors) used.
For latency-sensitive inference decoding, DeepEP includes a set of low-latency kernels with pure RDMA to minimize delays. The library also introduces a hook-based communication-computation overlapping method that does not occupy any SM resource.
Notice: the implementation in this library may have some slight differences from the DeepSeek-V3 paper.
## Performance
### Normal kernels with NVLink and RDMA forwarding
We test normal kernels on H800 GPUs (~160 GB/s maximum NVLink bandwidth), each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth), following the DeepSeek-V3/R1 pretraining setting (4096 tokens per batch, 7168 hidden, top-4 groups, top-8 experts, FP8 dispatching and BF16 combining).
Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth |
---|---|---|---|---|
Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
Internode | 32 | 58 GB/s (RDMA) | 32 | 57 GB/s (RDMA) |
Internode | 64 | 51 GB/s (RDMA) | 64 | 50 GB/s (RDMA) |
News (2025.04.22): with optimizations from Tencent Network Platform Department, performance was enhanced by up to 30%, see #130 for more details. Thanks for the contribution!
### Low-latency kernels with pure RDMA
We test low-latency kernels on H800 GPUs, each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth), following a typical DeepSeek-V3/R1 production setting (128 tokens per batch, 7168 hidden, top-8 experts, FP8 dispatching and BF16 combining).
Dispatch #EP | Latency | RDMA bandwidth | Combine #EP | Latency | RDMA bandwidth |
---|---|---|---|---|---|
8 | 77 us | 98 GB/s | 8 | 114 us | 127 GB/s |
16 | 118 us | 63 GB/s | 16 | 195 us | 74 GB/s |
32 | 155 us | 48 GB/s | 32 | 273 us | 53 GB/s |
64 | 173 us | 43 GB/s | 64 | 314 us | 46 GB/s |
128 | 192 us | 39 GB/s | 128 | 369 us | 39 GB/s |
256 | 194 us | 39 GB/s | 256 | 360 us | 40 GB/s |
News (2025.06.05): low-latency kernels now leverage NVLink as much as possible, see #173 for more details. Thanks for the contribution!
## Quick start
### Requirements
- Ampere (SM80) or Hopper (SM90) GPUs, or other architectures with SM90 PTX ISA support
- Python 3.8 and above
- CUDA version
  - CUDA 11.0 and above for SM80 GPUs
  - CUDA 12.3 and above for SM90 GPUs
- PyTorch 2.1 and above
- NVLink for intranode communication
- RDMA network for internode communication
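As a quick sanity check against these requirements, the snippet below (plain PyTorch, nothing DeepEP-specific) prints the installed versions and the device's compute capability:

```python
# Quick environment check against the requirements above (plain PyTorch)
import torch

print(f"PyTorch: {torch.__version__}")          # expect 2.1 or above
print(f"CUDA: {torch.version.cuda}")            # expect 11.0+ (SM80) or 12.3+ (SM90)
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: SM{major}{minor}")  # expect SM80/SM90, or SM90 PTX ISA support
```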
### Download and install NVSHMEM dependency
DeepEP also depends on our modified NVSHMEM. Please refer to our NVSHMEM Installation Guide for instructions.
### Development
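The commands below are a hedged sketch of a development workflow rather than the repository's verbatim instructions; the NVSHMEM path and the test file names are assumptions, so check the repository for the authoritative steps.

```bash
# Hedged sketch: build in place, pointing at your NVSHMEM installation
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py build

# Run the test cases (assumed file names; adjust launch settings for your cluster)
python tests/test_intranode.py
python tests/test_internode.py
python tests/test_low_latency.py
```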
### Installation
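A minimal installation sketch, assuming the `setup.py`-based build referenced later in this README; the NVSHMEM path is a placeholder (see the environment variables below):

```bash
# Hedged sketch; replace the placeholder with your NVSHMEM installation path
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install
```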
#### Installation environment variables
- `NVSHMEM_DIR`: the path to the NVSHMEM directory; all internode and low-latency features are disabled if not specified
- `DISABLE_SM90_FEATURES`: `0` or `1`, whether to disable SM90 features; disabling is required for devices without SM90 support or for CUDA 11
- `TORCH_CUDA_ARCH_LIST`: the list of target architectures, e.g. `TORCH_CUDA_ARCH_LIST="9.0"`
- `DISABLE_AGGRESSIVE_PTX_INSTRS`: `0` or `1`, whether to disable aggressive load/store instructions; see Undefined-behavior PTX usage below for more details
Then, `import deep_ep` in your Python project, and enjoy!
## Network configurations
DeepEP is fully tested with InfiniBand networks. However, it is theoretically compatible with RDMA over Converged Ethernet (RoCE) as well.
### Traffic isolation
Traffic isolation is supported by InfiniBand through Virtual Lanes (VL).
To prevent interference between different types of traffic, we recommend segregating workloads across different virtual lanes as follows:
- workloads using normal kernels
- workloads using low-latency kernels
- other workloads
For DeepEP, you can control the virtual lane assignment by setting the `NVSHMEM_IB_SL` environment variable.
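For example (a sketch with placeholder service-level values and hypothetical script names; valid values depend on your fabric configuration):

```bash
# Placeholder SL values and hypothetical scripts -- adapt to your own jobs
NVSHMEM_IB_SL=0 python run_training.py   # workloads using normal kernels
NVSHMEM_IB_SL=1 python run_decoding.py   # workloads using low-latency kernels
```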
### Adaptive routing
Adaptive routing is an advanced routing feature provided by InfiniBand switches that can evenly distribute traffic across multiple paths. Enabling adaptive routing can completely eliminate network congestion caused by routing conflicts, but it also introduces additional latency. We recommend the following configuration for optimal performance:
- enable adaptive routing in environments with heavy network loads
- use static routing in environments with light network loads
### Congestion control
Congestion control is disabled as we have not observed significant congestion in our production environment.
## Interfaces and examples
### Example use in model training or inference prefilling
The normal kernels can be used in model training or the inference prefilling phase (without the backward part), as the example code below shows.
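A minimal sketch of the dispatch/combine flow; the `deep_ep.Buffer` type and the exact dispatch/combine signatures used here are assumptions, so consult the interface docstrings shipped with the package.

```python
import torch
import deep_ep

# NOTE: a minimal sketch, not the library's verbatim example. The `Buffer`
# type and the dispatch/combine signatures are assumptions; consult the
# interface docstrings shipped with `deep_ep`.
def moe_forward(x: torch.Tensor, topk_idx: torch.Tensor,
                topk_weights: torch.Tensor, num_experts: int,
                buffer: "deep_ep.Buffer", run_experts) -> torch.Tensor:
    # Dispatch: route each token to the ranks owning its top-k experts.
    # The number of tokens this rank receives is only known GPU-side,
    # hence the implicit CPU wait described below.
    recv_x, recv_topk_idx, recv_topk_weights, handle = buffer.dispatch(
        x, topk_idx=topk_idx, topk_weights=topk_weights, num_experts=num_experts)
    # Expert computation on the tokens this rank received
    expert_out = run_experts(recv_x, recv_topk_idx)
    # Combine: send expert outputs back to their source ranks and reduce
    # them with the top-k gate weights
    combined_x, _ = buffer.combine(expert_out, handle,
                                   topk_weights=recv_topk_weights)
    return combined_x
```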
Moreover, inside the dispatch function, we may not know how many tokens the current rank will receive, so an implicit CPU wait on a GPU receive-count signal is involved, as the following figure shows.
### Example use in inference decoding
The low-latency kernels can be used in the inference decoding phase, as the example code below shows.
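A minimal sketch of the low-latency flow with the receiving hook; the method names `low_latency_dispatch`/`low_latency_combine`, the `return_recv_hook` flag, and the signatures are assumptions, so consult the interface docstrings shipped with the package.

```python
import torch
import deep_ep

# NOTE: a minimal sketch, not the library's verbatim example. The method
# names, the `return_recv_hook` flag, and the signatures are assumptions;
# consult the interface docstrings shipped with `deep_ep`.
def decode_step(hidden_states: torch.Tensor, topk_idx: torch.Tensor,
                topk_weights: torch.Tensor, num_experts: int,
                buffer: "deep_ep.Buffer", run_experts) -> torch.Tensor:
    # Dispatch over pure RDMA; requesting a receive hook lets the traffic
    # proceed in the background without occupying any SMs.
    recv_x, handle, dispatch_hook = buffer.low_latency_dispatch(
        hidden_states, topk_idx, num_experts, return_recv_hook=True)
    dispatch_hook()  # overlap other computation before this, then wait for the data
    expert_out = run_experts(recv_x)
    # Combine: return expert outputs and reduce with the gate weights
    combined_x, combine_hook = buffer.low_latency_combine(
        expert_out, topk_idx, topk_weights, handle, return_recv_hook=True)
    combine_hook()
    return combined_x
```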
For two-micro-batch overlapping, you can refer to the following figure. With our receiving-hook interface, the RDMA network traffic happens in the background without costing any GPU SMs from the computation part. Note that the overlapped parts can be adjusted: the four stages of attention/dispatch/MoE/combine may not have exactly the same execution time, so you may tune the stage settings according to your workload.
## Roadmap
- AR (adaptive routing) support
  - Refactor low-latency mode AR code
- A100 support (intranode only)
- Support BF16 for the low-latency dispatch kernel
- Support NVLink protocol for intranode low-latency kernels
- TMA copy instead of LD/ST
  - Intranode kernels
  - Internode kernels
  - Low-latency kernels
- SM-free kernels and refactors
## Notices
### Easier potential overall design
The current DeepEP implementation uses queues for communication buffers, which saves memory but introduces complexity and potential deadlocks. If you're implementing your own version based on DeepEP, consider using fixed-size buffers allocated to maximum capacity for simplicity and better performance. For a detailed discussion of this alternative approach, see https://github.com/deepseek-ai/DeepEP/issues/39.
### Undefined-behavior PTX usage
- For extreme performance, we discovered and use an undefined-behavior PTX usage: the read-only PTX instruction `ld.global.nc.L1::no_allocate.L2::256B` to read volatile data. The PTX modifier `.nc` indicates that a non-coherent cache is used, but correctness is tested to be guaranteed with `.L1::no_allocate` on Hopper architectures, and performance is much better. Our guess at the reason: the non-coherent cache is unified with L1, and the L1 modifier is not just a hint but a strong option, so correctness is guaranteed by there being no dirty data in L1.
- Initially, because NVCC could not automatically unroll volatile-read PTX, we tried using `__ldg` (i.e., `ld.nc`). Even compared to manually unrolled volatile reads, it was significantly faster (likely due to additional compiler optimizations), but the results could be incorrect or dirty. After consulting the PTX documentation, we discovered that L1 and the non-coherent cache are unified on Hopper architectures, and we speculated that `.L1::no_allocate` might resolve the issue, leading to this discovery.
- If you find the kernels not working on some other platform, you may add `DISABLE_AGGRESSIVE_PTX_INSTRS=1` to `setup.py` to disable this behavior, or file an issue.
### Auto-tuning on your cluster
For better performance on your cluster, we recommend running all the tests and using the best auto-tuned configuration. The default configurations are optimized for DeepSeek's internal cluster.
## License
This code repository is released under the MIT License, except for code that references NVSHMEM (including `csrc/kernels/ibgda_device.cuh` and `third-party/nvshmem.patch`), which is subject to the NVSHMEM SLA.
## Community Forks
- Infrawaves/DeepEP_ibrc_dual-ports_multiQP - Adds multi-QP solution and dual-port NIC support in IBRC transport
## Citation
If you use this codebase or otherwise find our work valuable, please cite this repository: https://github.com/deepseek-ai/DeepEP.