
vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe

TrtLlmNvFp4ExpertsBase

NvFp4 TRTLLM-Gen MoE kernels. Supports both the modular and monolithic interfaces.

Source code in vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
class TrtLlmNvFp4ExpertsBase:
    """
    NvFp4 TRTLLM-Gen MoE kernels. Supports both the modular and monolithic interfaces.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        self.moe_config = moe_config
        self.quant_config = quant_config

        self.routing_method_type = self.moe_config.routing_method
        self.topk = moe_config.experts_per_token
        self.intermediate_size_per_partition = (
            moe_config.intermediate_size_per_partition
        )
        self.hidden_dim = moe_config.hidden_dim
        self.local_num_experts = moe_config.num_local_experts
        self.ep_rank = moe_config.moe_parallel_config.ep_rank

        assert self.quant_config.g1_alphas is not None
        assert self.quant_config.a2_gscale is not None
        if moe_config.is_act_and_mul:
            # g1_alphas = a13_scale * w13_scale_2
            # a2_gscale = 1 / a2_scale
            # g1_scale_c = g1_alphas * a2_gscale = a13_scale * w13_scale_2 / a2_scale
            self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
        else:
            self.g1_scale_c = (
                torch.ones_like(self.quant_config.a1_gscale)
                * self.quant_config.a2_gscale
            )

    @staticmethod
    def _supports_current_device() -> bool:
        """Supports only Blackwell-family GPUs."""
        p = current_platform
        return p.is_cuda() and p.is_device_capability_family(100)

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """Supports non-gated MoE (i.e. Nemotron-Nano)."""
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Supports Nvfp4 quantization."""
        SUPPORTED_W_A = [
            (kNvfp4Static, kNvfp4Dynamic),
        ]
        return (weight_key, activation_key) in SUPPORTED_W_A

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Supports only SiLU and RELU^2 non-gated activation."""
        return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]

    @staticmethod
    def _supports_shape(hidden_dim: int) -> bool:
        """Requires hidden dim to be multiple of 512."""
        return hidden_dim % 512 == 0

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    def supports_chunking(self) -> bool:
        return False

    def supports_expert_map(self) -> bool:
        return False

_supports_activation staticmethod

_supports_activation(activation: MoEActivation) -> bool

Supports only SiLU and non-gated ReLU^2 activations.

Source code in vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
    """Supports only SiLU and RELU^2 non-gated activation."""
    return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]

_supports_current_device staticmethod

_supports_current_device() -> bool

Supports only Blackwell-family GPUs.

Source code in vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
@staticmethod
def _supports_current_device() -> bool:
    """Supports only Blackwell-family GPUs."""
    p = current_platform
    return p.is_cuda() and p.is_device_capability_family(100)
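
For reference, "device capability family 100" corresponds to SM 10.x (Blackwell, e.g. B200). A rough equivalent of the check in plain PyTorch (a sketch, not the platform API itself):

import torch

major, minor = torch.cuda.get_device_capability()  # e.g. (10, 0) on B200
is_blackwell_family = major == 10  # capability family 100 == SM 10.x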

_supports_no_act_and_mul staticmethod

_supports_no_act_and_mul() -> bool

Supports non-gated MoE (e.g. Nemotron-Nano).

Source code in vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
@staticmethod
def _supports_no_act_and_mul() -> bool:
    """Supports non-gated MoE (i.e. Nemotron-Nano)."""
    return True

_supports_quant_scheme staticmethod

_supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool

Supports NvFp4 quantization (static weight scales, dynamic activation scales).

Source code in vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
@staticmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Supports Nvfp4 quantization."""
    SUPPORTED_W_A = [
        (kNvfp4Static, kNvfp4Dynamic),
    ]
    return (weight_key, activation_key) in SUPPORTED_W_A
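
In other words, only statically quantized NvFp4 weights paired with dynamically quantized NvFp4 activations are accepted. A usage sketch:

TrtLlmNvFp4ExpertsBase._supports_quant_scheme(kNvfp4Static, kNvfp4Dynamic)  # True
TrtLlmNvFp4ExpertsBase._supports_quant_scheme(kNvfp4Static, kNvfp4Static)   # False
TrtLlmNvFp4ExpertsBase._supports_quant_scheme(None, None)                   # False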

_supports_shape staticmethod

_supports_shape(hidden_dim: int) -> bool

Requires the hidden dimension to be a multiple of 512.

Source code in vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
@staticmethod
def _supports_shape(hidden_dim: int) -> bool:
    """Requires hidden dim to be multiple of 512."""
    return hidden_dim % 512 == 0

TrtLlmNvFp4ExpertsModular

Bases: TrtLlmNvFp4ExpertsBase, FusedMoEExpertsModular

Modular version of the implementation (just the experts).

Source code in vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModular):
    """
    Modular version of the implementation (just the experts).
    """

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """The modular implementation supports all parallel configs."""
        return True

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        # The workspaces for this implementation are managed by flashinfer.
        workspace1 = (0,)
        workspace2 = (0,)

        # Hidden states are NvFp4, with two 4-bit values packed per int8
        # element, so K must be multiplied by 2 to recover the hidden dim.
        assert self.hidden_dim == K * 2
        output = (M, self.hidden_dim)

        return (workspace1, workspace2, output)

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        return TopKWeightAndReduceNoOP()

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
        assert a1q_scale is not None
        assert self.quant_config.w1_scale is not None
        assert self.quant_config.w2_scale is not None

        # Pack topk ids and weights into the format expected by the kernel.
        packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
            torch.bfloat16
        ).view(torch.int16)

        # trtllm_fp4_block_scale_routed_moe does not support autotuning,
        # so skip this kernel during the dummy autotuning run.
        import vllm.utils.flashinfer as fi_utils

        if fi_utils._is_fi_autotuning:
            return hidden_states

        # Invoke kernel.
        flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
            topk_ids=packed_tensor,
            routing_bias=None,
            hidden_states=hidden_states,
            hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
                *hidden_states.shape[:-1], -1
            ),
            gemm1_weights=w1,
            gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
            gemm1_bias=None,
            gemm1_alpha=None,
            gemm1_beta=None,
            gemm1_clamp_limit=None,
            gemm2_weights=w2,
            gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
            gemm2_bias=None,
            output1_scale_scalar=self.g1_scale_c,
            output1_scale_gate_scalar=self.quant_config.g1_alphas,
            output2_scale_scalar=self.quant_config.g2_alphas,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=0,
            topk_group=0,
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.ep_rank * self.local_num_experts,
            local_num_experts=self.local_num_experts,
            routed_scaling_factor=None,
            routing_method_type=1,
            do_finalize=True,
            activation_type=activation_to_flashinfer_int(activation),
            output=output,
        )
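
The packed_tensor built in apply encodes each (expert id, routing weight) pair in a single 32-bit value: the expert id in the high 16 bits and the routing weight's bfloat16 bit pattern in the low 16 bits. A minimal round-trip sketch (positive weights keep the bfloat16 sign bit clear, so the int16 view is non-negative and the bitwise OR is safe):

import torch

topk_ids = torch.tensor([[3, 7]], dtype=torch.int32)
topk_weights = torch.tensor([[0.75, 0.25]])

low = topk_weights.to(torch.bfloat16).view(torch.int16).to(torch.int32)
packed = (topk_ids << 16) | low

# Unpack to verify the round trip.
ids_back = packed >> 16
weights_back = (packed & 0xFFFF).to(torch.int16).view(torch.bfloat16)

assert torch.equal(ids_back, topk_ids)
assert torch.equal(weights_back, topk_weights.to(torch.bfloat16))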

_supports_parallel_config staticmethod

_supports_parallel_config(
    moe_parallel_config: FusedMoEParallelConfig,
) -> bool

The modular implementation supports all parallel configs.

Source code in vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """The modular implementation supports all parallel configs."""
    return True

TrtLlmNvFp4ExpertsMonolithic

Bases: TrtLlmNvFp4ExpertsBase, FusedMoEExpertsMonolithic

Monolithic version of the kernel (router + experts).

Source code in vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
class TrtLlmNvFp4ExpertsMonolithic(
    TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsMonolithic
):
    """
    Monolithic version of the kernel (router + experts).
    """

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """The modular implementation should be used for the Dp/Ep or EPLB case."""
        return (
            not moe_parallel_config.use_all2all_kernels
            and not moe_parallel_config.enable_eplb
        )

    @staticmethod
    def _supports_routing_method(
        routing_method_type: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        # NOTE(rob): this is a conservative list.
        return routing_method_type in [
            RoutingMethodType.DeepSeekV3,
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
            RoutingMethodType.Llama4,
        ]

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        """
        The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
        Only DeepSeekV3 routing supports float32 router_logits (which is converted
        internally in the kernel).
        """
        if router_logits_dtype == torch.float32:
            # Only DeepSeekV3 routing handles float32 logits
            # https://github.com/flashinfer-ai/flashinfer/issues/2469
            return routing_method == RoutingMethodType.DeepSeekV3
        return True

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
        assert a1q_scale is not None
        assert self.quant_config.w1_scale is not None
        assert self.quant_config.w2_scale is not None
        assert (
            apply_router_weight_on_input
            and self.routing_method_type == RoutingMethodType.Llama4
        ) or (
            not apply_router_weight_on_input
            and self.routing_method_type != RoutingMethodType.Llama4
        )

        # Prepare routing bias into kernel format.
        routing_bias = e_score_correction_bias
        if routing_bias is not None:
            routing_bias = routing_bias.to(torch.bfloat16)
        router_logits = (
            router_logits.to(torch.float32)
            if self.routing_method_type == RoutingMethodType.DeepSeekV3
            else router_logits
        )

        # Invoke kernel.
        return flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
            routing_logits=router_logits,
            routing_bias=routing_bias,
            hidden_states=hidden_states,
            hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
                *hidden_states.shape[:-1], -1
            ),
            gemm1_weights=w1,
            gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
            gemm1_bias=None,
            gemm1_alpha=None,
            gemm1_beta=None,
            gemm1_clamp_limit=None,
            gemm2_weights=w2,
            gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
            gemm2_bias=None,
            output1_scale_scalar=self.g1_scale_c,
            output1_scale_gate_scalar=self.quant_config.g1_alphas,
            output2_scale_scalar=self.quant_config.g2_alphas,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=(num_expert_group or 0),
            topk_group=(topk_group or 0),
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.ep_rank * self.local_num_experts,
            local_num_experts=self.local_num_experts,
            routed_scaling_factor=routed_scaling_factor,
            routing_method_type=self.routing_method_type,
            do_finalize=True,
            activation_type=activation_to_flashinfer_int(activation),
        )[0]
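
The four-clause assertion on apply_router_weight_on_input in apply is an if-and-only-if: router weights are applied on the input exactly when Llama4 routing is used. An equivalent, more compact form would be:

assert apply_router_weight_on_input == (
    self.routing_method_type == RoutingMethodType.Llama4
)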

_supports_parallel_config staticmethod

_supports_parallel_config(
    moe_parallel_config: FusedMoEParallelConfig,
) -> bool

The modular implementation should be used for the DP/EP or EPLB cases.

Source code in vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """The modular implementation should be used for the Dp/Ep or EPLB case."""
    return (
        not moe_parallel_config.use_all2all_kernels
        and not moe_parallel_config.enable_eplb
    )

_supports_router_logits_dtype staticmethod

_supports_router_logits_dtype(
    router_logits_dtype: dtype | None,
    routing_method: RoutingMethodType,
) -> bool

The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default. Only DeepSeekV3 routing supports float32 router_logits (which is converted internally in the kernel).

Source code in vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
@staticmethod
def _supports_router_logits_dtype(
    router_logits_dtype: torch.dtype | None,
    routing_method: RoutingMethodType,
) -> bool:
    """
    The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
    Only DeepSeekV3 routing supports float32 router_logits (which is converted
    internally in the kernel).
    """
    if router_logits_dtype == torch.float32:
        # Only DeepSeekV3 routing handles float32 logits
        # https://github.com/flashinfer-ai/flashinfer/issues/2469
        return routing_method == RoutingMethodType.DeepSeekV3
    return True
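
Example outcomes of this gate (a usage sketch; note it is a staticmethod):

cls = TrtLlmNvFp4ExpertsMonolithic
cls._supports_router_logits_dtype(torch.float32, RoutingMethodType.DeepSeekV3)    # True
cls._supports_router_logits_dtype(torch.float32, RoutingMethodType.Renormalize)   # False
cls._supports_router_logits_dtype(torch.bfloat16, RoutingMethodType.Renormalize)  # True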