vllm.model_executor.layers.fla.ops.layernorm_guard ¶
LayerNormGated ¶
Bases: Module
Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
__init__ ¶
__init__(
hidden_size,
eps: float = 1e-05,
group_size: int | None = None,
norm_before_gate: bool = True,
device: device | None = None,
dtype: dtype | None = None,
)
If group_size is not None, GroupNorm is applied with each group containing group_size elements; group_size=None is equivalent to group_size=hidden_size (i.e. a single group).
Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
forward ¶
If z is not None, computes norm(x) * silu(z) when norm_before_gate is True, else norm(x * silu(z)); if z is None, computes norm(x).
Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
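The grouped-normalization and gating semantics described above can be sketched as a pure-NumPy CPU reference. This is an illustrative reimplementation of the documented behavior, not the Triton kernel; the function name `layernorm_gated_ref` is made up for this sketch.

```python
import numpy as np

def layernorm_gated_ref(x, weight, bias, z=None, eps=1e-5,
                        group_size=None, norm_before_gate=True):
    """NumPy sketch of LayerNormGated semantics (not the Triton kernel)."""
    d = x.shape[-1]
    gs = d if group_size is None else group_size  # None -> one group of size d

    def silu(t):
        return t / (1.0 + np.exp(-t))

    def norm(t):
        # Normalize each group of `gs` elements independently (GroupNorm).
        tg = t.reshape(*t.shape[:-1], d // gs, gs)
        mu = tg.mean(axis=-1, keepdims=True)
        var = tg.var(axis=-1, keepdims=True)
        out = ((tg - mu) / np.sqrt(var + eps)).reshape(t.shape)
        return out * weight + bias  # affine params span the full hidden dim

    if z is None:
        return norm(x)
    return norm(x) * silu(z) if norm_before_gate else norm(x * silu(z))
```

With weight=1 and bias=0, each group of the output has mean ~0, and `group_size=None` gives the same result as `group_size=hidden_size`, matching the docstring.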
RMSNormGated ¶
Bases: Module
Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
__init__ ¶
__init__(
hidden_size,
eps: float = 1e-05,
group_size: int | None = None,
norm_before_gate: bool = False,
device: device | None = None,
dtype: dtype | None = None,
activation: str = "swish",
)
If group_size is not None, GroupNorm is applied with each group containing group_size elements; group_size=None is equivalent to group_size=hidden_size (i.e. a single group).
Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
forward ¶
If z is not None, computes norm(x) * silu(z) when norm_before_gate is True, else norm(x * silu(z)); if z is None, computes norm(x).
Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
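The RMS variant drops the mean subtraction and bias of LayerNorm, and note that its default is norm_before_gate=False (gate applied before the norm). The sketch below is a hedged NumPy reference of those semantics, assuming "swish" means SiLU gating; the function name `rmsnorm_gated_ref` is made up here, and other activation strings are not covered by this sketch.

```python
import numpy as np

def rmsnorm_gated_ref(x, weight, z=None, eps=1e-5, group_size=None,
                      norm_before_gate=False, activation="swish"):
    """NumPy sketch of RMSNormGated semantics: per-group RMS, no mean/bias."""
    d = x.shape[-1]
    gs = d if group_size is None else group_size
    if activation not in ("swish", "silu"):
        raise ValueError("this sketch only models swish/silu gating")

    def silu(t):
        return t / (1.0 + np.exp(-t))

    def norm(t):
        tg = t.reshape(*t.shape[:-1], d // gs, gs)
        rms = np.sqrt((tg ** 2).mean(axis=-1, keepdims=True) + eps)
        return (tg / rms).reshape(t.shape) * weight

    if z is None:
        return norm(x)
    return norm(x) * silu(z) if norm_before_gate else norm(x * silu(z))
```

With weight=1, the RMS of each normalized group is ~1, and the two gating orders generally give different outputs.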
_layer_norm_fn_impl ¶
_layer_norm_fn_impl(
x,
weight,
bias,
z=None,
eps=1e-06,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
activation: str = "swish",
)
Triton layer/RMS norm with optional gating.
If z is not None, computes norm(x) * silu(z) when norm_before_gate, else norm(x * silu(z)).
This calls the Triton kernel directly. The original code wrapped it in a torch.autograd.Function (LayerNormFn) to save tensors for the backward pass, but vLLM is inference-only, so there is no backward pass. The autograd wrapper also prevented torch.compile/Dynamo from tracing through the function because of its @staticmethod forward.
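The full computation dispatched by this function, covering both the LayerNorm and RMSNorm paths via is_rms_norm, can be mirrored by a CPU reference with the same documented parameters. This is a hedged sketch (the name `layer_norm_fn_ref` is invented, the activation is assumed to be swish/SiLU, and it is NumPy rather than the actual Triton kernel):

```python
import numpy as np

def layer_norm_fn_ref(x, weight, bias, z=None, eps=1e-6, group_size=None,
                      norm_before_gate=True, is_rms_norm=False,
                      activation="swish"):
    """CPU reference mirroring the documented signature (not the Triton kernel)."""
    d = x.shape[-1]
    gs = d if group_size is None else group_size
    silu = lambda t: t / (1.0 + np.exp(-t))  # activation assumed swish/silu

    def norm(t):
        tg = t.reshape(*t.shape[:-1], d // gs, gs)
        if is_rms_norm:
            out = tg / np.sqrt((tg ** 2).mean(axis=-1, keepdims=True) + eps)
        else:
            mu = tg.mean(axis=-1, keepdims=True)
            out = (tg - mu) / np.sqrt(tg.var(axis=-1, keepdims=True) + eps)
        out = out.reshape(t.shape) * weight
        return out if bias is None else out + bias

    if z is None:
        return norm(x)
    return norm(x) * silu(z) if norm_before_gate else norm(x * silu(z))
```

Since there is no autograd.Function wrapper, a call like this is an ordinary Python function call that graph tracers can step through, which is the design point made above.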