Data Parallelism in Attention in SGLang¶
What is Data Parallelism in Attention?¶
⚠ 这里的 dp_size 不是传统的数据并行而是 attentnion dp
 其目的是将一个大的张量并行(TP)组重组为几个较小的 TP 组 ,这些较小的 TP 组又形成一个专门用于注意力层的新的数据并行(DP)组。
- 对于模型中除 MLP 层以外的部分(如 Embedding, Self-Attention),每个数据并行单元(DP_Rank)内部的张量并行规模(TP_Size)设置为 1,即每个 DP_Rank 独立计算这些部分。
 - 对于 MLP 层,所有 DP_Rank 则共同组成一个大的张量并行组(TP_Group),该组的大小等于数据并行的规模(DP_Size)。
 
启动流程¶
- engine 会启动一个子进程执行 
run_data_parallel_control_process - 该函数启动 DataParallelController 并执行 
launch_dp_attention_schedulers- 为每个 dp rank 获取单独的 zmq 通信接口
 - 调用 
launch_tensor_parallel_group 
 -  
launch_tensor_parallel_group对每个 (pp_rank, tp_rank) 启动一个 Schedular 子进程 - 父进程在这里阻塞等待每个子进程通过 pipe 发送初始化信息(通常包含:模型已加载完成、模型的 buffer/kv 配置、最大 token 长度等Text Only```pythonscheduler = Scheduler(
server_args,
port_args,
gpu_id,
tp_rank,
moe_ep_rank,
pp_rank,
dp_rank,
) 
pipe_writer.send(
             {
                 "status": "ready",
                 "max_total_num_tokens": scheduler.max_total_num_tokens,
                 "max_req_input_len": scheduler.max_req_input_len,
             }
         )
 ┌─────────────── Pipeline Stage 0 ───────────────┐
 TP=0→ │ TP Rank 0 │ TP Rank 1 │ TP Rank 2 │
 ├─────────────── Pipeline Stage 1 ───────────────┤
 TP=1→ │ TP Rank 3 │ TP Rank 4 │ TP Rank 5 │
 ├─────────────── Pipeline Stage 2 ───────────────┤
 TP=2→ │ TP Rank 6 │ TP Rank 7 │ TP Rank 8 │
 └────────────────────────────────────────────────┘
4. ModelRunner 初始化时执行 `initialize_dp_attention`,
   - attn_tp_size = tp_size // dp_size
   - attn_dp_rank = tp_rank // attn_tp_size
   - attn_tp_rank = tp_rank % attn_tp_size
   - 并据此创建 attention-specific group coordinator(attn_tp_group)
   - 设置 Gather 时候的 Buffer
## 执行流程
与正常执行类似,都先执行 `recv_requests` 和 `process_input_requests`
1. 在 `get_next_batch_to_run` 中执行 `prepare_mlp_sync_batch`
2. `prepare_mlp_sync_batch` 在 (dp_size, attn_tp_size)维度上收集每个 worker 的 batch/forward-mode 信息(包括 token 数、是否能跑 CUDA graph、logprob 相关长度、是否为 extend、以及 TBO 相关的 all-gather 元数据),然后合并回每个本地 `local_batch`
   - 若本地没任务,但任一其它 DP 副本有任务(表示全局不是完全空闲),则通过 `get_idle_batch()` 获取一个占位 `ScheduleBatch`,便于该 worker 参与后续同步/并行(例如同步 MLP gather 或 KV cache 操作),避免因为某些 worker 缺失导致收集/通信异常。
   - 最后返回 `local_batch` 作为此次的 new_batch
     ```python
         def get_idle_batch(self):
             idle_batch = ScheduleBatch.init_new(
                 [],
                 self.req_to_token_pool,
                 self.token_to_kv_pool_allocator,
                 self.tree_cache,
                 self.model_config,
                 self.enable_overlap,
                 self.spec_algorithm,
             )
             idle_batch.prepare_for_idle()
             return idle_batch
     ```
def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )
- 变为 
ModelWorkerBatch后调用TpModelWorker::forward_batch_generation转化为ForwardBatch,调用ModelRunner::forward->ModelRunner::_forward_raw在这里进行 padding - 实际上是调用 
FowardBatch::prepare_mlp_sync_batch
- 把各个 dp rank 的 num_token 信息对齐 Attention-tp-size- 确保每个 DP rank 的 token 数量是 
attn_tp_size的倍数 - 为后续可能的 reduce-scatter 操作做准备(reduce-scatter 要求可被整除的分割)
 - 每个 dp rank 本地的 batch 会根据填充模式被 padding 到指定的大小(bs = self.batchsize = num_tokens).
 - 当 
dp_padding_mode设置为 max 时,padding 到的 batchsize 大小为所有 dp rank 上最大的 local batch,同时对齐到 attention_tp_size; - 当 
dp_padding_mode设置为 sum 时,padding 到的 batchsize 大小为所有 dp rank 上 local batch 之和,即 global batch size - is_extend_in_batch 设置 SUM,cuda_graph 设置 MAX
 
 - 确保每个 DP rank 的 token 数量是 
 - 在完成 dp attention batch padding 后,根据 batch 类型调用
self.forward_decode进行模型推理。在完成推理后,调用forward_batch.post_forward_mlp_sync_batch(ret)
- 复原被临时修改的属性
- 将 padding 后的 ForwardBatch 还原,具体做法是在 prepare 时会记录原 batchsize,此时对结果进行切片 
# prepare
setattr(self, "_original_forward_mode", self.forward_mode)
setattr(self, "_original_batch_size", self.batch_size)
# post
self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
self.batch_size = getattr(self, "_original_batch_size", self.batch_size)
logits_output.next_token_logits = logits_output.next_token_logits[:bs]
logits_output.hidden_states = logits_output.hidden_states[:bs]
- 最后的处理流程与常规无差别,调用 
process_batch_result结束此轮调度 
Why DP Attention?¶
它用一次相对廉价的内存通信(All-Gather KV Cache)换取了Attention 的并行加速,同时还保持了 MLP 层(\(O(N)\))的零通信数据并行效率。
Why Padding?¶
在 Attention 层与 MoE 层之间需要进行同步通信,对各 dp rank atttention 部分计算得到的 hidden_state 集合进行同步,如果采用 Allgather 或者 Allreduce 通信,其对数据性状有一定的要求
- allgather:要求每个 sub-batchsize 大小相同。因此需要把 local batchsize padding 到 max batchsize,且最后 gatheredbuffer 大小为:max_batchsize * sync_group_size

 - allreduce:每个 dp rank 上都有完整的数据集合(对应 padding 到 sum batchsize). 最后 gatheredbuffer 大小为:sum_batchsize

 
Padding Problems¶
这种为了“正确性”而做的 padding,反过来又会严重影响性能,主要体现在以下两方面:
A. 负载不均与“掉队者” (Load Imbalance & Stragglers)¶
GPU_0 有 4 个真实的请求需要计算,而 GPU_1, 2, 3 只有 3 个真实的请求(和 1 个几乎不耗时的虚拟请求)。
- 结果:
GPU_1, 2, 3会很快完成它们的本地计算。 - 等待:然后,它们在 
All-Gather这个同步点(Barrier)上被迫空转(Idle),等待GPU_0完成它那份更重的工作。 GPU_0成为了这个批次的“掉队者” (Straggler),它拖慢了所有其他 GPU 的进度。
结论: 这导致了极低的 GPU 利用率。在 dp_size=4 的情况下,可能有 ¾ 的 GPU 在大部分时间里处于空闲等待状态。
B. 计算和内存的额外开销 (Overhead)¶
- 计算开销:
 
- 尽管是“虚拟”请求,但系统仍然需要启动 CUDA Kernel 来处理它们,这会带来一定的调度开销。
 - 更重要的是,
All-Gather操作本身。所有 GPU 现在都必须交换一个“大小为 4”的批次数据,而不是它们“实际”的[4, 3, 3, 3]。通信量被放大了。 
- 内存开销:
 
All-Gather之后,每个 GPU 都需要分配足够大的内存缓冲区来接收来自所有其他 rank 的数据。GPU_1明明只需要4+3+3+3 = 13个请求的数据,但它必须按照4+4+4+4 = 16个请求(即max_batch_per_rank * dp_size)的最大可能性来分配内存。- 这导致了显存(VRAM)的浪费。
 
