Skip to content

vllm.model_executor.layers.fla.ops.kda

FusedRMSNormGated

Bases: CustomOp

Source code in vllm/model_executor/layers/fla/ops/kda.py
@CustomOp.register("fused_rms_norm_gated")
class FusedRMSNormGated(CustomOp):
    """RMSNorm over the last dim followed by a gating activation.

    Normalizes ``x``, optionally scales it by a learnable weight, then
    multiplies by an activation (swish/silu or sigmoid) of the gate ``g``.
    """

    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        activation: str = "swish",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        self.activation = activation

        if self.activation not in ("swish", "silu", "sigmoid"):
            raise ValueError(f"Unsupported activation: {self.activation}")

        # Optional learnable scale; this layer never uses a bias.
        if elementwise_affine:
            self.weight = nn.Parameter(
                torch.empty(hidden_size, device=device, dtype=dtype)
            )
        else:
            self.register_parameter("weight", None)
        self.register_parameter("bias", None)

    def forward_native(
        self,
        x: torch.Tensor,
        g: torch.Tensor,
        residual: torch.Tensor | None = None,
        prenorm: bool = False,
        residual_in_fp32: bool = False,
    ) -> torch.Tensor:
        """Decomposed PyTorch ops for torch.compile/inductor fusion."""
        # TODO(https://github.com/vllm-project/vllm/issues/36175): implement
        # the native residual/prenorm path and unify with RMSNormGated.
        # Until then, those variants defer to the triton kernel.
        if prenorm or residual is not None:
            return self.forward_cuda(x, g, residual, prenorm, residual_in_fp32)

        xf = x.float()
        inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = xf * inv_rms
        if self.weight is not None:
            normed = normed * self.weight.float()

        gate = g.float()
        if self.activation == "sigmoid":
            out = normed * torch.sigmoid(gate)
        else:  # swish / silu
            out = normed * gate * torch.sigmoid(gate)
        return out.to(x.dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
        g: torch.Tensor,
        residual: torch.Tensor | None = None,
        prenorm: bool = False,
        residual_in_fp32: bool = False,
    ) -> torch.Tensor:
        # Fused triton path; also receives the residual/prenorm arguments.
        return rms_norm_gated(
            x,
            g,
            self.weight,
            self.bias,
            self.activation,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )

forward_native

forward_native(
    x: Tensor,
    g: Tensor,
    residual: Tensor | None = None,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
) -> Tensor

Decomposed PyTorch ops for torch.compile/inductor fusion.

Source code in vllm/model_executor/layers/fla/ops/kda.py
def forward_native(
    self,
    x: torch.Tensor,
    g: torch.Tensor,
    residual: torch.Tensor | None = None,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
) -> torch.Tensor:
    """Decomposed PyTorch ops for torch.compile/inductor fusion."""
    # TODO(https://github.com/vllm-project/vllm/issues/36175): implement
    # native residual/prenorm path and unify with RMSNormGated.
    # For now, fall back to the triton kernel.
    if residual is not None or prenorm:
        return self.forward_cuda(x, g, residual, prenorm, residual_in_fp32)
    x_float = x.float()
    variance = x_float.pow(2).mean(dim=-1, keepdim=True)
    x_normed = x_float * torch.rsqrt(variance + self.eps)
    if self.weight is not None:
        x_normed = x_normed * self.weight.float()
    g_float = g.float()
    if self.activation in ("swish", "silu"):
        out = x_normed * g_float * torch.sigmoid(g_float)
    else:  # sigmoid
        out = x_normed * torch.sigmoid(g_float)
    return out.to(x.dtype)

chunk_kda_scaled_dot_kkt_fwd

chunk_kda_scaled_dot_kkt_fwd(
    q: Tensor,
    k: Tensor,
    gk: Tensor | None = None,
    beta: Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: dtype = float32,
) -> tuple[Tensor, Tensor]

Compute beta * K * K^T.

Parameters:

- `k` (`Tensor`, required): The key tensor of shape `[B, T, H, K]`.
- `beta` (`Tensor`, default `None`): The beta tensor of shape `[B, T, H]`.
- `gk` (`Tensor`, default `None`): The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor.
- `cu_seqlens` (`LongTensor`, default `None`): The cumulative sequence lengths of the input tensor.
- `chunk_size` (`int`, default `64`): The chunk size.
- `output_dtype` (`dtype`, default `torch.float32`): The dtype of the output tensor.

Returns:

- `tuple[Tensor, Tensor]`: beta * K * K^T of shape `[B, T, H, BT]`, where `BT` is the chunk size.

Source code in vllm/model_executor/layers/fla/ops/kda.py
def chunk_kda_scaled_dot_kkt_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    gk: torch.Tensor | None = None,
    beta: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Compute beta * K * K^T (and the matching Q K^T matrix) per chunk.

    Args:
        q (torch.Tensor):
            The query tensor of shape `[B, T, H, K]`.
        k (torch.Tensor):
            The key tensor of shape `[B, T, H, K]`.
        gk (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H, K]`
            applied to the key tensor. Default: `None`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`. Default: `None`.
        scale (float):
            Scale factor forwarded to the kernels. Default: `None`.
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor.
            Default: `None`.
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensors. Default: `torch.float32`.

    Returns:
        A pair `(A, Aqk)`, each of shape `[B, T, H, BT]` where `BT` is the
        chunk size; `A` holds beta * K * K^T.
    """
    B, T, H, K = k.shape
    # Kernels assume the head dim fits in a single K tile.
    assert K <= 256
    BT = chunk_size

    # Varlen mode derives the chunk count from per-sequence chunk indices.
    if cu_seqlens is None:
        chunk_indices = None
        NT = cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    BC = min(16, BT)  # sub-chunk tile size
    NC = cdiv(BT, BC)  # sub-chunks per chunk
    BK = max(next_power_of_2(K), 16)

    A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
    Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)

    # Pass 1: inter-sub-chunk tiles (one program per sub-chunk pair).
    chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter[(NT, NC * NC, B * H)](
        q=q,
        k=k,
        g=gk,
        beta=beta,
        A=A,
        Aqk=Aqk,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        NC=NC,
    )

    # Pass 2: intra-sub-chunk tiles (one program per sub-chunk).
    chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra[(NT, NC, B * H)](
        q=q,
        k=k,
        g=gk,
        beta=beta,
        A=A,
        Aqk=Aqk,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        BK=BK,
    )
    return A, Aqk

fused_kda_gate

fused_kda_gate(
    g: Tensor,
    A: Tensor,
    head_k_dim: int,
    g_bias: Tensor | None = None,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> Tensor
Forward pass for KDA gate

- input `g`: `[..., H*D]`
- param `A`: `[H]` or `[1, 1, H, 1]`
- `beta`: softplus beta parameter
- `threshold`: softplus threshold parameter
- return: `[..., H, D]`

Source code in vllm/model_executor/layers/fla/ops/kda.py
def fused_kda_gate(
    g: torch.Tensor,
    A: torch.Tensor,
    head_k_dim: int,
    g_bias: torch.Tensor | None = None,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    """
    Forward pass for the KDA gate.

    Args:
        g: Gate input of shape `[..., H * head_k_dim]`.
        A: Per-head parameter with `H` elements (`[H]` or `[1, 1, H, 1]`).
        head_k_dim: Per-head key dimension `D`.
        g_bias: Optional bias tensor forwarded to the kernel.
        beta: Softplus beta parameter.
        threshold: Softplus threshold parameter.

    Returns:
        A float32 tensor of shape `[..., H, D]`.
    """
    lead_shape = g.shape[:-1]

    # Flatten all leading dims into a single token axis for the kernel.
    flat = g.view(-1, g.shape[-1])
    num_tokens, hidden = flat.shape
    num_heads = A.numel()
    assert num_heads * head_k_dim == hidden

    # Kernel writes fp32 output regardless of the input dtype.
    out = torch.empty_like(flat, dtype=torch.float32)

    def grid(meta):
        # One program per (token tile, head).
        return (cdiv(num_tokens, meta["BT"]), num_heads)

    kda_gate_fwd_kernel[grid](
        flat,
        A,
        out,
        g_bias,
        beta,
        threshold,
        num_tokens,
        num_heads,
        head_k_dim,
        BD=next_power_of_2(head_k_dim),
        HAS_BIAS=g_bias is not None,
    )

    return out.view(*lead_shape, num_heads, head_k_dim)