Skip to content

vllm.model_executor.layers.fused_moe.oracle.unquantized

map_unquantized_backend

map_unquantized_backend(
    runner_backend: MoEBackend,
) -> UnquantizedMoeBackend

Map the user's MoEBackend to the corresponding UnquantizedMoeBackend.

Source code in vllm/model_executor/layers/fused_moe/oracle/unquantized.py
def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend:
    """Map the user's MoEBackend to the corresponding UnquantizedMoeBackend.

    Args:
        runner_backend: Backend name requested by the user. Only
            ``"triton"``, ``"flashinfer_trtllm"``, ``"flashinfer_cutlass"``
            and ``"aiter"`` are accepted; ``"auto"`` is not mapped here.

    Returns:
        The ``UnquantizedMoeBackend`` member matching ``runner_backend``.

    Raises:
        ValueError: If ``runner_backend`` has no unquantized MoE equivalent.
    """
    mapping = {
        "triton": UnquantizedMoeBackend.TRITON,
        "flashinfer_trtllm": UnquantizedMoeBackend.FLASHINFER_TRTLLM,
        "flashinfer_cutlass": UnquantizedMoeBackend.FLASHINFER_CUTLASS,
        "aiter": UnquantizedMoeBackend.AITER,
    }
    backend = mapping.get(runner_backend)
    # Compare against the ``None`` sentinel from ``dict.get`` explicitly rather
    # than relying on the truthiness of the enum member: members of an Enum
    # with a mixed-in type are truthy/falsy according to that type.
    if backend is not None:
        return backend
    raise ValueError(
        f"moe_backend='{runner_backend}' is not supported for unquantized MoE. "
        f"Expected one of {list(mapping.keys())}."
    )

select_unquantized_moe_backend

select_unquantized_moe_backend(
    moe_config: FusedMoEConfig, use_ep: bool, use_dp: bool
) -> UnquantizedMoeBackend

Select the primary unquantized MoE backend. Note: shape-specific fallbacks may still occur at runtime.

Source code in vllm/model_executor/layers/fused_moe/oracle/unquantized.py
def select_unquantized_moe_backend(
    moe_config: FusedMoEConfig,
    use_ep: bool,
    use_dp: bool,
) -> UnquantizedMoeBackend:
    """
    Select the primary unquantized MoE backend.

    An explicit ``moe_config.moe_backend`` (anything other than ``"auto"``) is
    honored after checking that it is usable for this configuration. Otherwise
    the backend is chosen from the current platform, the FlashInfer-related
    environment variables and FlashInfer availability.

    Note: shape-specific fallbacks may still occur at runtime.

    Args:
        moe_config: MoE configuration; supplies the requested backend and the
            parallel config used to pick the activation format.
        use_ep: Whether expert parallelism (EP) is in use.
        use_dp: Whether data parallelism (DP) is in use.

    Returns:
        The selected ``UnquantizedMoeBackend``.

    Raises:
        ValueError: If an explicitly requested backend is unknown or not
            available for this configuration, or if the current platform is
            none of the platforms handled here.
    """

    def _make_log_backend(backend: UnquantizedMoeBackend) -> str:
        return f"Using {backend.value} backend for Unquantized MoE"

    activation_format = (
        mk.FusedMoEActivationFormat.BatchedExperts
        if moe_config.moe_parallel_config.use_batched_activation_format
        else mk.FusedMoEActivationFormat.Standard
    )

    # Check if FlashInfer TRTLLM BF16 MoE is supported for this config.
    trtllm_supported, _ = is_supported_config_trtllm_bf16(
        moe_config=moe_config,
        activation_format=activation_format,
    )
    flashinfer_trtllm_available = has_flashinfer() and trtllm_supported
    # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUs, and
    # here only when EP is used without DP.
    flashinfer_cutlass_available = (
        has_flashinfer_cutlass_fused_moe()
        and use_ep
        and (not use_dp)
        and current_platform.has_device_capability(90)
    )
    flashinfer_trtllm_moe_enabled = (
        flashinfer_trtllm_available
        and envs.VLLM_USE_FLASHINFER_MOE_FP16
        and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
    )
    flashinfer_cutlass_moe_enabled = (
        flashinfer_cutlass_available and envs.VLLM_USE_FLASHINFER_MOE_FP16
    )
    rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

    # Handle explicit moe_backend from user.
    runner_backend = moe_config.moe_backend
    if runner_backend != "auto":
        requested_backend = map_unquantized_backend(runner_backend)
        if requested_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
            if not flashinfer_trtllm_available:
                raise ValueError(
                    "FlashInfer TRTLLM MoE backend is not available for this "
                    "configuration."
                )
        elif requested_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
            if not flashinfer_cutlass_available:
                raise ValueError(
                    "FlashInfer CUTLASS MoE backend is not available for this "
                    "configuration."
                )
        elif requested_backend == UnquantizedMoeBackend.AITER and not (
            current_platform.is_rocm() and rocm_aiter_moe_enabled
        ):
            raise ValueError(
                "ROCm AITer MoE backend is not available for this configuration."
            )
        logger.info_once(_make_log_backend(requested_backend), scope="local")
        return requested_backend

    # Automatic selection. Start unset so that a platform not handled below
    # produces a clear error instead of an UnboundLocalError.
    backend: UnquantizedMoeBackend | None = None
    if current_platform.is_rocm():
        if rocm_aiter_moe_enabled:
            backend = UnquantizedMoeBackend.AITER
        else:
            backend = UnquantizedMoeBackend.TRITON
    if current_platform.is_cuda():
        if flashinfer_trtllm_moe_enabled:
            backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM
        elif flashinfer_cutlass_moe_enabled:
            backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS
            if flashinfer_trtllm_available:
                logger.info_once(
                    "FlashInfer TRTLLM MoE is available but not enabled, "
                    "consider setting VLLM_FLASHINFER_MOE_BACKEND=latency "
                    "to enable it for better performance.",
                    scope="local",
                )
        else:
            # Only suggest a FlashInfer backend when it is actually available,
            # so the hint never recommends a setting that cannot take effect.
            if (
                not envs.VLLM_USE_FLASHINFER_MOE_FP16
                and flashinfer_trtllm_available
            ):
                logger.info_once(
                    "FlashInfer TRTLLM MoE is available but not enabled, "
                    "consider setting VLLM_USE_FLASHINFER_MOE_FP16=1 "
                    "and VLLM_FLASHINFER_MOE_BACKEND=latency "
                    "to enable it for better performance.",
                    scope="local",
                )
            elif flashinfer_cutlass_available:
                # Available (EP without DP) but not enabled: in this branch
                # that can only mean VLLM_USE_FLASHINFER_MOE_FP16 is unset.
                logger.info_once(
                    "FlashInfer MoE is available for EP"
                    " but not enabled, consider setting"
                    " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
                    scope="local",
                )
            elif use_dp:
                logger.info_once(
                    "FlashInfer CUTLASS MoE is currently not available for DP.",
                    scope="local",
                )
            backend = UnquantizedMoeBackend.TRITON
    if current_platform.is_xpu():
        backend = UnquantizedMoeBackend.XPU
    if current_platform.is_cpu():
        backend = UnquantizedMoeBackend.CPU
    if current_platform.is_tpu():
        backend = UnquantizedMoeBackend.TPU
    if current_platform.is_out_of_tree():
        backend = UnquantizedMoeBackend.OOT

    if backend is None:
        raise ValueError(
            "No unquantized MoE backend is available for the current platform."
        )

    logger.info_once(_make_log_backend(backend), scope="local")
    return backend