vllm.compilation.passes.fusion.sequence_parallelism

SequenceParallelismPass

Bases: VllmPatternMatcherPass

This pass enables sequence parallelism for models. It identifies patterns where an AllReduce operation is followed by an RMSNorm (or RMSNorm and then Quantization) operation. These patterns are replaced with a ReduceScatter operation, followed by a local RMSNorm/Quantization, and then an AllGather operation.

The general transformation is:

Input -> AllReduce -> RMSNorm -> Output

becomes

Input -> ReduceScatter -> RMSNorm -> AllGather -> Output
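The rewrite is numerically sound because RMSNorm normalizes each token (row) independently, so it commutes with splitting the token dimension across ranks. The following minimal single-process sketch simulates the collectives with plain tensor operations to check the equivalence; the shapes and names are illustrative, not vLLM internals.

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Row-wise RMSNorm: each token (row) is normalized independently.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

tp_size, num_tokens, hidden = 4, 8, 16
partials = [torch.randn(num_tokens, hidden) for _ in range(tp_size)]  # per-rank partial sums
weight = torch.randn(hidden)

# Original pattern: AllReduce -> RMSNorm (every rank sees all tokens).
expected = rms_norm(torch.stack(partials).sum(dim=0), weight)

# Rewritten pattern: ReduceScatter(dim=0) -> local RMSNorm -> AllGather(dim=0).
chunk = num_tokens // tp_size
shards = [
    torch.stack([p[r * chunk : (r + 1) * chunk] for p in partials]).sum(dim=0)
    for r in range(tp_size)
]
actual = torch.cat([rms_norm(s, weight) for s in shards], dim=0)

torch.testing.assert_close(expected, actual)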

While this pass itself does not directly yield performance improvements, it lays the groundwork for subsequent fusion passes, such as GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can significantly reduce communication overhead and improve overall model performance.

This pass splits up the residual tensor across TP ranks and hence divides its size. Because the pattern matcher starts at the end of the graph, the replacement contains a slice that temporarily conforms the input residual to the correct size. After all patterns have been matched, we use a NoOpEliminationPass to clean up what have now become no-op slices.

Note that an older version of the pass did not need this slice, as it operated only on the custom rms_norm and fused_rms_norm_add ops, which did not complain about mismatched shapes during replacement. This approach retains the same assumption: correctness is only maintained if all rms_norm operations are split across ranks.

Correctness-wise, this approach is strictly better than before: previously, the graph was incorrect both semantically and shape-wise during the pass. With this approach, there is only semantic incorrectness during the pass. Both approaches restore a correct graph once all patterns have been matched.
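To make the mid-rewrite shape bookkeeping concrete, here is a small illustrative sketch; the shapes and variable names are assumptions for exposition, not the pass's actual graph nodes.

import torch

num_tokens, hidden, tp_size = 8, 16, 4
local_len = num_tokens // tp_size

# Mid-rewrite: the producer of the residual has not been rewritten yet, so the
# residual still has the full token dimension. The inserted slice conforms it
# to the per-rank size expected by the local RMSNorm.
residual = torch.randn(num_tokens, hidden)
conformed = residual[:local_len]  # real slice: (8, 16) -> (2, 16)

# After the upstream pattern is also replaced, the residual already arrives
# split across ranks; the same slice is now an identity operation, which
# NoOpEliminationPass removes.
residual = torch.randn(local_len, hidden)
conformed = residual[:local_len]  # no-op slice: (2, 16) -> (2, 16)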

Source code in vllm/compilation/passes/fusion/sequence_parallelism.py
class SequenceParallelismPass(VllmPatternMatcherPass):
    """
    This pass enables sequence parallelism for models.
    It identifies patterns where an AllReduce operation is followed by
    an RMSNorm (or RMSNorm and then Quantization) operation.
    These patterns are replaced with a ReduceScatter operation, followed by
    a local RMSNorm/Quantization, and then an AllGather operation.

    The general transformation is:
    Input -> AllReduce -> RMSNorm -> Output
    becomes
    Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    While this pass itself does not directly yield performance improvements,
    it lays the groundwork for subsequent fusion passes, such as
    GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
    significantly reduce communication overhead and improve overall model
    performance.


    This pass splits up the residual tensor across TP ranks and hence divides its size.
    Because the pattern matcher starts at the end of the graph, the replacement
    contains a slice that temporarily conforms the input residual to the correct size.
    After all patterns have been matched, we use a NoOpEliminationPass to clean up
    what have now become no-op slices.

    Note that an older version of the pass did not need this slice, as it
    operated only on the custom rms_norm and fused_rms_norm_add ops, which did
    not complain about mismatched shapes during replacement. This approach
    retains the same assumption: correctness is only maintained if all
    rms_norm operations are split across ranks.

    Correctness-wise, this approach is strictly better than before - before,
    the graph was incorrect semantically and shape-wise during the pass.
    With this approach there's only semantic incorrectness during the pass.
    Both approaches restore a correct graph once all patterns are matched.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        # Read min_token_num from config (calculated during config init)
        self.min_token_num = None
        if config.model_config is not None:
            pass_config = config.compilation_config.pass_config
            self.min_token_num = pass_config.sp_min_token_num

            if self.min_token_num is not None:
                # Take the min to avoid exceeding max_num_batched_tokens
                max_batched = config.scheduler_config.max_num_batched_tokens
                if max_batched is not None:
                    self.min_token_num = min(self.min_token_num, max_batched)
                logger.debug_once(
                    f"Sequence parallelism min token threshold: {self.min_token_num}",
                    scope="global",
                )

        # Used to clean up redundant views created temporarily
        # to circumvent residual shape change issues
        self.noop_cleanup = NoOpEliminationPass(config)
        self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass"
        )

        for epsilon in [1e-5, 1e-6]:
            # RMSNorm + Static FP8 quantization patterns
            FirstAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)
            MiddleAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

            # Normal RMSNorm patterns
            FirstAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

            MiddleAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """
        Determines if sequence parallelism should be applied for the given
        compile range.

        SP is only beneficial for larger batch sizes where the communication
        overhead is amortized. For small batches, the overhead of splitting
        and gathering tensors across TP ranks outweighs the benefits.

        Returns False (SP disabled) when:
        - Using piecewise compilation with non-concrete or TP-indivisible sizes
        - min_token_num is None (SP disabled for this device/config)
        - The compile range starts below the minimum token threshold
        """
        # For piecewise compilation (not using inductor graph partition),
        # we need concrete sizes that are divisible by TP for correct splitting
        if (
            not self.compilation_config.use_inductor_graph_partition
            and self.compilation_config.splitting_ops
        ):
            tp_size = get_tensor_model_parallel_world_size()
            if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
                return False

        # min_token_num is None when SP is disabled for this device/config
        # (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
        if self.min_token_num is None:
            return False

        # Only apply SP when batch size meets the minimum threshold
        return compile_range.start >= self.min_token_num

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
        # Clean up reshape nodes
        self.noop_cleanup(graph)

is_applicable_for_range

is_applicable_for_range(compile_range: Range) -> bool

Determines if sequence parallelism should be applied for the given compile range.

SP is only beneficial for larger batch sizes where the communication overhead is amortized. For small batches, the overhead of splitting and gathering tensors across TP ranks outweighs the benefits.

Returns False (SP disabled) when:

- Using piecewise compilation with non-concrete or TP-indivisible sizes
- min_token_num is None (SP disabled for this device/config)
- The compile range starts below the minimum token threshold

Source code in vllm/compilation/passes/fusion/sequence_parallelism.py
def is_applicable_for_range(self, compile_range: Range) -> bool:
    """
    Determines if sequence parallelism should be applied for the given
    compile range.

    SP is only beneficial for larger batch sizes where the communication
    overhead is amortized. For small batches, the overhead of splitting
    and gathering tensors across TP ranks outweighs the benefits.

    Returns False (SP disabled) when:
    - Using piecewise compilation with non-concrete or TP-indivisible sizes
    - min_token_num is None (SP disabled for this device/config)
    - The compile range starts below the minimum token threshold
    """
    # For piecewise compilation (not using inductor graph partition),
    # we need concrete sizes that are divisible by TP for correct splitting
    if (
        not self.compilation_config.use_inductor_graph_partition
        and self.compilation_config.splitting_ops
    ):
        tp_size = get_tensor_model_parallel_world_size()
        if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
            return False

    # min_token_num is None when SP is disabled for this device/config
    # (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
    if self.min_token_num is None:
        return False

    # Only apply SP when batch size meets the minimum threshold
    return compile_range.start >= self.min_token_num
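To illustrate how the three checks interact, here is a hypothetical walk-through with a simplified stand-in for the Range type; the FakeRange class and applicable function below are assumptions for exposition, not vLLM APIs.

from dataclasses import dataclass

@dataclass
class FakeRange:
    start: int
    end: int

    def is_single_size(self) -> bool:
        return self.start == self.end

def applicable(
    r: FakeRange, tp_size: int, min_token_num: int | None, piecewise: bool
) -> bool:
    # Piecewise compilation needs a concrete, TP-divisible size.
    if piecewise and (not r.is_single_size() or r.end % tp_size != 0):
        return False
    # min_token_num is None when SP is disabled for this device/config.
    if min_token_num is None:
        return False
    # The batch must meet the minimum token threshold.
    return r.start >= min_token_num

assert not applicable(FakeRange(128, 128), tp_size=4, min_token_num=256, piecewise=True)
assert applicable(FakeRange(512, 512), tp_size=4, min_token_num=256, piecewise=True)
assert not applicable(FakeRange(514, 514), tp_size=4, min_token_num=256, piecewise=True)  # 514 % 4 != 0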

_SequenceParallelPatternHelper

Helper for sequence parallelism patterns.

Source code in vllm/compilation/passes/fusion/sequence_parallelism.py
class _SequenceParallelPatternHelper:
    """Helper for sequence parallelism patterns."""

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        self.epsilon = epsilon
        self.dtype = dtype
        self.device = device
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.vllm.reduce_scatter.default(
            x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
        )

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.vllm.all_gather.default(
            x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
        )

get_sequence_parallelism_threshold

get_sequence_parallelism_threshold(
    hidden_size: int, tp_size: int, element_size: int
) -> int | None

Calculate the minimum token threshold for applying sequence parallelism.

Returns None if sequence parallelism should not be applied for the current device or model size.

Branching logic based on device capability:

- Check if hidden_size >= SP_MIN_HIDDEN_SIZE[device_capability]
- If not, returns None (SP disabled for small models on this device)
- If yes, calculates threshold based on per-GPU size

min_token_num = (min_per_gpu_size_mb * tp_size * MiB) // (hidden_size * element_size)
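As a worked example of this formula, assume a configured per-GPU threshold of 16 MiB (an illustrative value, not an actual SP_MIN_PER_GPU_SIZE_MB entry), hidden_size=8192, tp_size=8, and element_size=2 (bf16/fp16):

hidden_size, tp_size, element_size = 8192, 8, 2  # element_size=2 bytes for bf16/fp16
min_per_gpu_size_mb = 16  # assumed threshold; actual values live in SP_MIN_PER_GPU_SIZE_MB
MiB = 1024 * 1024
min_token_num = (min_per_gpu_size_mb * MiB * tp_size) // (hidden_size * element_size)
print(min_token_num)  # 8192

That is, under these assumptions SP would only apply once a batch carries at least 8192 tokens.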

Source code in vllm/compilation/passes/fusion/sequence_parallelism.py
def get_sequence_parallelism_threshold(
    hidden_size: int,
    tp_size: int,
    element_size: int,
) -> int | None:
    """
    Calculate the minimum token threshold for applying sequence parallelism.

    Returns None if sequence parallelism should not be applied based on model size.

    Branching logic based on device capability:
    - Check if hidden_size >= SP_MIN_HIDDEN_SIZE[device_capability]
    - If not, returns None (SP disabled for small models on this device)
    - If yes, calculates threshold based on per-GPU size

    Formula: min_token_num = (min_per_gpu_size_mb * tp_size * MiB) //
             (hidden_size * element_size)
    """
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        return None

    capability = current_platform.get_device_capability()
    if capability is None:
        return None
    device_capability = capability.to_int()

    # Check if device has configured thresholds
    min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability)
    min_per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(device_capability)

    if min_hidden_size is None or min_per_gpu_size_mb is None:
        return None

    # Only apply sequence parallelism for models meeting the size threshold
    if hidden_size < min_hidden_size:
        return None

    MiB = 1024 * 1024
    min_size = min_per_gpu_size_mb * MiB * tp_size
    return int(min_size // (hidden_size * element_size))