Skip to content

vllm.compilation.piecewise_backend

PiecewiseBackend

Source code in vllm/compilation/piecewise_backend.py
class PiecewiseBackend:
    def __init__(
        self,
        graph: fx.GraphModule | None,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        vllm_backend: VllmBackend,
        returns_tuple: bool,
        compiled_runnables: dict[str, Callable[..., Any]] | None = None,
        submod_name: str = "",
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on runtime shape.

        One range entry is created for every range returned by
        `compilation_config.get_compile_ranges()` and for every concrete
        size in `compilation_config.compile_sizes`; each entry is compiled
        (or loaded) up front, and `__call__` dispatches on the runtime shape.

        This class supports two mutually exclusive modes:
        1. Compilation (graph is set, compiled_runnables is None):
           Used during initial compilation when we have the FX graph
           and need to compile it for each shape range.
        2. Precompilation (graph is None, compiled_runnables is set):
           Used when loading from cache/AOT artifacts where we already
           have pre-compiled callables and don't need the original graph.

        Exactly one of graph or compiled_runnables must be provided.

        Args:
            graph: FX graph of this piecewise submodule (compilation mode).
            vllm_config: vLLM config; compile ranges and sizes are read from
                its `compilation_config`.
            piecewise_compile_index: Index of this submodule among all
                piecewise submodules.
            total_piecewise_compiles: Total number of piecewise submodules.
            sym_shape_indices: Positions in the call arguments that hold
                symbolic sizes; the first one is used as the runtime shape
                for dispatch in `__call__`.
            vllm_backend: Owning backend; supplies `compiler_manager`,
                `inductor_config` and `is_encoder`.
            returns_tuple: If False, loaded runnables unwrap a tuple/list
                output to its first element (see get_compiled_graph_wrapper).
            compiled_runnables: Pre-compiled callables keyed by
                `str(compile_range)` (precompilation mode).
            submod_name: Submodule name, used in trace logging.

        Raises:
            AssertionError: if both or neither of graph/compiled_runnables
                are given, or a required pre-compiled runnable is missing.
            NotImplementedError: if compile_sizes contains the string
                "cudagraph_capture_sizes".
        """
        assert (graph is None) != (compiled_runnables is None), (
            "exactly one of graph and compiled_runnables should be set."
        )

        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend
        self.compiled_runnables = compiled_runnables
        self.submod_name = submod_name

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.is_full_graph = total_piecewise_compiles == 1
        self.is_encoder_compilation = vllm_backend.is_encoder

        # Take a private copy so the encoder override below can never mutate
        # a list that might be shared with the compilation config.
        self.compile_ranges = list(self.compilation_config.get_compile_ranges())
        if self.is_encoder_compilation:
            # For encoder compilation we use the max int32 value
            # to set the upper bound of the compile ranges
            max_int32 = 2**31 - 1
            last_compile_range = self.compile_ranges[-1]
            assert (
                last_compile_range.end
                == vllm_config.scheduler_config.max_num_batched_tokens
            )
            self.compile_ranges[-1] = Range(
                start=last_compile_range.start, end=max_int32
            )

        log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
        logger.debug_once(log_string)

        self.compile_sizes = self.compilation_config.compile_sizes
        log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
        logger.debug_once(log_string)

        self.sym_shape_indices = sym_shape_indices
        self.returns_tuple = returns_tuple

        # The entries for ranges that we need to either compile (cold start)
        # or load from pre-compiled runnables (warm start).
        self.range_entries: dict[Range, RangeEntry] = {}

        # We only keep compilation management inside this class directly.
        # Single-size entries are inserted first, so they are also compiled
        # first; sizes already covered by an identical range are skipped.
        if self.compile_sizes is not None:
            for size in self.compile_sizes:
                if isinstance(size, str):
                    assert size == "cudagraph_capture_sizes"
                    raise NotImplementedError(
                        "cudagraph_capture_sizes not supported in compile_sizes. "
                        "This should be handled in `post_init_cudagraph_sizes`."
                    )
                assert isinstance(size, int)
                size_range = Range(start=size, end=size)
                if size_range not in self.compile_ranges:
                    self.range_entries[size_range] = RangeEntry(
                        compile_range=size_range,
                    )

        for compile_range in self.compile_ranges:
            self.range_entries[compile_range] = RangeEntry(
                compile_range=compile_range,
            )

        # Track whether we've logged the graph for this subgraph (only log once)
        self._graph_logged = False

        if self.graph is not None:
            self.compile_all_ranges()
        else:
            self.load_all_ranges()

    def get_compiled_graph_wrapper(
        self, compiled_graph: Callable[..., Any]
    ) -> Callable[..., Any]:
        """Wrap a compiled callable so its output matches `returns_tuple`.

        When `self.returns_tuple` is False and the callable returns a
        tuple/list, only its first element is returned; otherwise the
        output is passed through unchanged.
        """

        def compiled_graph_wrapper(*args: Any) -> Any:
            graph_output = compiled_graph(*args)
            # unpack the tuple if needed
            # TODO(rzou): the implication is that we're not
            # reading the python bytecode correctly in vLLM?
            if self.returns_tuple or not isinstance(graph_output, (tuple, list)):
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph_wrapper

    def to_bytes(self) -> dict[str, bytes]:
        """Serialize every compiled entry whose runnable has `serialize()`.

        Returns:
            Mapping from `str(compile_range)` to the pickled artifact.
            Entries that are not compiled, or whose runnable has no
            `serialize` attribute, are skipped.
        """

        class StandaloneCompiledArtifactsPickler(Pickler):
            def reducer_override(self, obj: object) -> Any:
                # Autotuners are pickled via their own prepare_for_pickle()
                # hook and embedded as a nested pickle payload.
                if isinstance(obj, CachingAutotuner):
                    obj.prepare_for_pickle()
                    return pickle.loads, (
                        pickle.dumps(
                            obj,
                        ),
                    )
                return NotImplemented

        def serialize(fn: Callable[..., Any]) -> bytes:
            assert hasattr(fn, "serialize"), "fn must have serialize method"
            with torch._functorch.config.patch("bundled_autograd_cache", True):
                entry = fn.serialize()

                f = io.BytesIO()
                StandaloneCompiledArtifactsPickler(f).dump(entry)
                result = f.getvalue()
            return result

        out = {}

        for range_key, entry in self.range_entries.items():
            if not entry.compiled:
                logger.debug(
                    "entry with range %s not compiled, so cannot get its bytes",
                    range_key,
                )
                continue
            if hasattr(entry.runnable, "serialize"):
                out[str(range_key)] = serialize(entry.runnable)

        return out

    def compile_all_ranges(self) -> None:
        """Compile all range entries for this piecewise subgraph up front.

        Single-size ranges are compiled against concrete example inputs of
        that size; wider ranges against the graph's fake placeholder values.
        Entries already marked compiled are skipped.
        """
        assert self.graph is not None, (
            "Cannot compile without a graph. "
            "When loading from cache/AOT artifacts, "
            "compile_all_ranges should not be called."
        )

        # TODO(https://github.com/vllm-project/vllm/issues/35766)
        # Can we remove strict_autograd_cache and
        # force_non_lazy_backward_lowering overrides?
        # I added them explicitly because this is what they are
        # set to before the refactor
        # (https://github.com/vllm-project/vllm/pull/35472).
        # They affect the aotautograd cache key computation
        # but they shouldn't have any effect on the actual
        # compilation.
        # The patch set does not depend on the range, so build it once.
        config_patches: dict[str, bool] = dict(
            bundled_autograd_cache=True,
            strict_autograd_cache=False,
        )
        if hasattr(torch._functorch.config, "force_non_lazy_backward_lowering"):
            config_patches["force_non_lazy_backward_lowering"] = False

        for range_entry in self.range_entries.values():
            if range_entry.compiled:
                continue

            self._log_compile_start(range_entry.compile_range)

            if range_entry.compile_range.is_single_size():
                args_list = create_concrete_args(
                    self.graph, range_entry.compile_range.start
                )
            else:
                args_list = get_fake_args_from_graph(self.graph)

            with torch._functorch.config.patch(**config_patches):
                range_entry.runnable = self.vllm_backend.compiler_manager.compile(
                    self.graph,
                    args_list,
                    self.vllm_backend.inductor_config,
                    self.compilation_config,
                    compile_range=range_entry.compile_range,
                    graph_index=self.piecewise_compile_index,
                    num_graphs=self.total_piecewise_compiles,
                )

            range_entry.compiled = True

    def _log_compile_start(self, compile_range: Range):
        """Log compilation event for TORCH_TRACE/tlparse."""
        is_cudagraph_size = (
            self.compile_sizes is not None and compile_range.start in self.compile_sizes
        )
        subgraph_index = self.piecewise_compile_index
        submod_name = self.submod_name
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "vllm_piecewise_compile_start",
                "encoding": "json",
            },
            payload_fn=lambda: json.dumps(
                {
                    "piecewise_index": subgraph_index,
                    "submod_name": submod_name,
                    "total_piecewise_compiles": self.total_piecewise_compiles,
                    "compile_range_start": compile_range.start,
                    "compile_range_end": compile_range.end,
                    "is_single_size": compile_range.is_single_size(),
                    "is_cudagraph_capture_size": is_cudagraph_size,
                }
            ),
        )

        # Log the subgraph graph dump only once per subgraph (not per size)
        # to reduce log file size. The graph code is the same for all sizes.
        if not self._graph_logged:
            self._graph_logged = True
            assert self.graph is not None
            trace_structured(
                "graph_dump",
                metadata_fn=lambda: {
                    "name": f"vllm_{submod_name}",
                },
                payload_fn=lambda: self.graph.print_readable(print_output=False),
            )

    def load_all_ranges(self) -> None:
        """Load all pre-compiled runnables for this piecewise subgraph.

        Called during warm start to wrap all cached compiled_runnables
        into range_entry.runnable up front, analogous to compile_all_ranges()
        for the cold start path.
        """
        assert self.compiled_runnables is not None, (
            "load_all_ranges should only be called when compiled_runnables "
            "is set (warm start / cache loading path)."
        )
        for range_entry in self.range_entries.values():
            if range_entry.compiled:
                continue
            key = str(range_entry.compile_range)
            assert key in self.compiled_runnables, (
                f"Missing compiled runnable for range {range_entry.compile_range}. "
                f"Available keys: {list(self.compiled_runnables.keys())}"
            )
            range_entry.runnable = self.get_compiled_graph_wrapper(
                self.compiled_runnables[key]
            )
            range_entry.compiled = True

    def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
        """Return the entry that should serve `runtime_shape`, or None.

        An exact single-size entry from `compile_sizes` is preferred;
        otherwise the first compile range containing the shape is used.
        Compile ranges are searched even when `compile_sizes` is None,
        because their entries are always created in `__init__`.
        """
        if self.compile_sizes is not None and runtime_shape in self.compile_sizes:
            return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
        for compile_range in self.compile_ranges:
            if runtime_shape in compile_range:
                return self.range_entries[compile_range]
        return None

    def __call__(self, *args: Any) -> Any:
        """Dispatch to the runnable compiled for the runtime shape of `args`."""
        runtime_shape = args[self.sym_shape_indices[0]]
        range_entry = self._find_range_for_shape(runtime_shape)

        assert range_entry is not None, (
            f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
        )
        assert range_entry.compiled, (
            "All ranges should be compiled or loaded up front in "
            "PiecewiseBackend.__init__. "
            f"range_entry={range_entry.compile_range}"
        )
        return range_entry.runnable(*args)

__init__

__init__(
    graph: GraphModule | None,
    vllm_config: VllmConfig,
    piecewise_compile_index: int,
    total_piecewise_compiles: int,
    sym_shape_indices: list[int],
    vllm_backend: VllmBackend,
    returns_tuple: bool,
    compiled_runnables: dict[str, Callable[..., Any]]
    | None = None,
    submod_name: str = "",
)

The backend for piecewise compilation. It mainly handles the compilation of static shapes and dispatching based on runtime shape.

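`self.graph` is compiled once for every range returned by `compilation_config.get_compile_ranges()` and once for every concrete size listed in `compilation_config.compile_sizes`; at runtime, calls are dispatched to the entry matching the input shape.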

This class supports two mutually exclusive modes:

1. Compilation (graph is set, compiled_runnables is None): used during initial compilation, when we have the FX graph and need to compile it for each shape range.
2. Precompilation (graph is None, compiled_runnables is set): used when loading from cache/AOT artifacts, where we already have pre-compiled callables and don't need the original graph.

Exactly one of graph or compiled_runnables must be provided.

Source code in vllm/compilation/piecewise_backend.py
def __init__(
    self,
    graph: fx.GraphModule | None,
    vllm_config: VllmConfig,
    piecewise_compile_index: int,
    total_piecewise_compiles: int,
    sym_shape_indices: list[int],
    vllm_backend: VllmBackend,
    returns_tuple: bool,
    compiled_runnables: dict[str, Callable[..., Any]] | None = None,
    submod_name: str = "",
):
    """
    The backend for piecewise compilation.
    It mainly handles the compilation of static shapes and
    dispatching based on runtime shape.

    One range entry is created for every range returned by
    `compilation_config.get_compile_ranges()` and for every concrete
    size in `compilation_config.compile_sizes`; each entry is compiled
    (or loaded) up front, and `__call__` dispatches on the runtime shape.

    This class supports two mutually exclusive modes:
    1. Compilation (graph is set, compiled_runnables is None):
       Used during initial compilation when we have the FX graph
       and need to compile it for each shape range.
    2. Precompilation (graph is None, compiled_runnables is set):
       Used when loading from cache/AOT artifacts where we already
       have pre-compiled callables and don't need the original graph.

    Exactly one of graph or compiled_runnables must be provided.

    Raises:
        AssertionError: if both or neither of graph/compiled_runnables
            are given, or a required pre-compiled runnable is missing.
        NotImplementedError: if compile_sizes contains the string
            "cudagraph_capture_sizes".
    """
    assert (graph is None) != (compiled_runnables is None), (
        "exactly one of graph and compiled_runnables should be set."
    )

    self.graph = graph
    self.vllm_config = vllm_config
    self.compilation_config = vllm_config.compilation_config
    self.piecewise_compile_index = piecewise_compile_index
    self.total_piecewise_compiles = total_piecewise_compiles
    self.vllm_backend = vllm_backend
    self.compiled_runnables = compiled_runnables
    self.submod_name = submod_name

    self.is_first_graph = piecewise_compile_index == 0
    self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

    self.is_full_graph = total_piecewise_compiles == 1
    self.is_encoder_compilation = vllm_backend.is_encoder

    # Take a private copy so the encoder override below can never mutate
    # a list that might be shared with the compilation config.
    self.compile_ranges = list(self.compilation_config.get_compile_ranges())
    if self.is_encoder_compilation:
        # For encoder compilation we use the max int32 value
        # to set the upper bound of the compile ranges
        max_int32 = 2**31 - 1
        last_compile_range = self.compile_ranges[-1]
        assert (
            last_compile_range.end
            == vllm_config.scheduler_config.max_num_batched_tokens
        )
        self.compile_ranges[-1] = Range(
            start=last_compile_range.start, end=max_int32
        )

    log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
    logger.debug_once(log_string)

    self.compile_sizes = self.compilation_config.compile_sizes
    log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
    logger.debug_once(log_string)

    self.sym_shape_indices = sym_shape_indices
    self.returns_tuple = returns_tuple

    # The entries for ranges that we need to either compile (cold start)
    # or load from pre-compiled runnables (warm start).
    self.range_entries: dict[Range, RangeEntry] = {}

    # We only keep compilation management inside this class directly.
    # Single-size entries are inserted first, so they are also compiled
    # first; sizes already covered by an identical range are skipped.
    if self.compile_sizes is not None:
        for size in self.compile_sizes:
            if isinstance(size, str):
                assert size == "cudagraph_capture_sizes"
                raise NotImplementedError(
                    "cudagraph_capture_sizes not supported in compile_sizes. "
                    "This should be handled in `post_init_cudagraph_sizes`."
                )
            assert isinstance(size, int)
            size_range = Range(start=size, end=size)
            if size_range not in self.compile_ranges:
                self.range_entries[size_range] = RangeEntry(
                    compile_range=size_range,
                )

    for compile_range in self.compile_ranges:
        self.range_entries[compile_range] = RangeEntry(
            compile_range=compile_range,
        )

    # Track whether we've logged the graph for this subgraph (only log once)
    self._graph_logged = False

    if self.graph is not None:
        self.compile_all_ranges()
    else:
        self.load_all_ranges()

_log_compile_start

_log_compile_start(compile_range: Range)

Log compilation event for TORCH_TRACE/tlparse.

Source code in vllm/compilation/piecewise_backend.py
def _log_compile_start(self, compile_range: Range):
    """Emit TORCH_TRACE/tlparse events announcing a compile of one range.

    A JSON "artifact" describing the range is emitted on every call. The
    readable FX graph of this subgraph is dumped only on the first call,
    because the graph code is identical for every range.
    """
    index = self.piecewise_compile_index
    name = self.submod_name
    sizes = self.compile_sizes
    in_compile_sizes = sizes is not None and compile_range.start in sizes

    def _metadata() -> dict[str, str]:
        return {"name": "vllm_piecewise_compile_start", "encoding": "json"}

    def _payload() -> str:
        return json.dumps(
            {
                "piecewise_index": index,
                "submod_name": name,
                "total_piecewise_compiles": self.total_piecewise_compiles,
                "compile_range_start": compile_range.start,
                "compile_range_end": compile_range.end,
                "is_single_size": compile_range.is_single_size(),
                "is_cudagraph_capture_size": in_compile_sizes,
            }
        )

    trace_structured("artifact", metadata_fn=_metadata, payload_fn=_payload)

    if self._graph_logged:
        return

    # One graph dump per subgraph (not per range) keeps trace files small.
    self._graph_logged = True
    assert self.graph is not None
    trace_structured(
        "graph_dump",
        metadata_fn=lambda: {"name": f"vllm_{name}"},
        payload_fn=lambda: self.graph.print_readable(print_output=False),
    )

compile_all_ranges

compile_all_ranges() -> None

Compile all range entries for this piecewise subgraph up front.

Source code in vllm/compilation/piecewise_backend.py
def compile_all_ranges(self) -> None:
    """Compile all range entries for this piecewise subgraph up front."""
    assert self.graph is not None, (
        "Cannot compile without a graph. "
        "When loading from cache/AOT artifacts, "
        "compile_all_ranges should not be called."
    )

    for range_entry in self.range_entries.values():
        if range_entry.compiled:
            continue

        self._log_compile_start(range_entry.compile_range)

        if range_entry.compile_range.is_single_size():
            args_list = create_concrete_args(
                self.graph, range_entry.compile_range.start
            )
        else:
            args_list = get_fake_args_from_graph(self.graph)

        # TODO(https://github.com/vllm-project/vllm/issues/35766)
        # Can we remove strict_autograd_cache and
        # force_non_lazy_backward_lowering overrides?
        # I added them explicitly because this is what they are
        # set to before the refactor
        # (https://github.com/vllm-project/vllm/pull/35472).
        # They affect the aotautograd cache key computation
        # but they shouldn't have any effect on the actual
        # compilation.
        config_patches = dict(
            bundled_autograd_cache=True,
            strict_autograd_cache=False,
        )
        if hasattr(torch._functorch.config, "force_non_lazy_backward_lowering"):
            config_patches["force_non_lazy_backward_lowering"] = False
        with torch._functorch.config.patch(**config_patches):
            range_entry.runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args_list,
                self.vllm_backend.inductor_config,
                self.compilation_config,
                compile_range=range_entry.compile_range,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
            )

        range_entry.compiled = True

load_all_ranges

load_all_ranges() -> None

Load all pre-compiled runnables for this piecewise subgraph.

Called during warm start to wrap all cached compiled_runnables into range_entry.runnable up front, analogous to compile_all_ranges() for the cold start path.

Source code in vllm/compilation/piecewise_backend.py
def load_all_ranges(self) -> None:
    """Populate every range entry from pre-compiled runnables (warm start).

    This is the cache/AOT counterpart of compile_all_ranges(). Each entry
    not yet marked compiled is looked up in `self.compiled_runnables` by
    `str(compile_range)`, wrapped with get_compiled_graph_wrapper(), and
    marked compiled. Entries already compiled are left untouched.

    Raises:
        AssertionError: if compiled_runnables is None, or a needed key is
            absent from it.
    """
    runnables = self.compiled_runnables
    assert runnables is not None, (
        "load_all_ranges should only be called when compiled_runnables "
        "is set (warm start / cache loading path)."
    )
    for entry in self.range_entries.values():
        if not entry.compiled:
            cache_key = str(entry.compile_range)
            assert cache_key in runnables, (
                f"Missing compiled runnable for range {entry.compile_range}. "
                f"Available keys: {list(runnables.keys())}"
            )
            entry.runnable = self.get_compiled_graph_wrapper(runnables[cache_key])
            entry.compiled = True

create_concrete_args

create_concrete_args(
    graph: GraphModule, size: int
) -> list[Any]

Create example inputs with symbolic dims replaced by a concrete size.

Used for single-size eager compilation where we need concrete-shaped inputs but don't have real runtime tensors yet.

Source code in vllm/compilation/piecewise_backend.py
def create_concrete_args(graph: fx.GraphModule, size: int) -> list[Any]:
    """Create example inputs with symbolic dims replaced by a concrete size.

    Used for single-size eager compilation where we need concrete-shaped
    inputs but don't have real runtime tensors yet.
    """
    from torch._prims_common import compute_required_storage_length
    from torch.fx.experimental.symbolic_shapes import is_symbolic

    def concretize(sym_val: Any) -> int:
        """Replace all symbolic variables in a SymInt expression with size."""
        if not is_symbolic(sym_val):
            return int(sym_val)
        expr = sym_val.node.expr
        return int(expr.subs({s: size for s in expr.free_symbols}))

    args: list[Any] = []
    for node in graph.graph.nodes:
        if node.op != "placeholder":
            break
        val = node.meta["example_value"]
        if isinstance(val, torch.SymInt):
            args.append(concretize(val))
        elif isinstance(val, torch.Tensor):
            new_shape = tuple(concretize(d) for d in val.shape)
            new_strides = tuple(concretize(s) for s in val.stride())
            new_storage_offset = concretize(val.storage_offset())
            needed_size = compute_required_storage_length(
                new_shape, new_strides, new_storage_offset
            )
            t = torch.empty(needed_size, dtype=val.dtype, device=val.device)
            t = t.as_strided(new_shape, new_strides, new_storage_offset)
            args.append(t)
        else:
            args.append(val)
    return args

get_fake_args_from_graph

get_fake_args_from_graph(graph: GraphModule) -> list[Any]

Get fake args directly from graph placeholder nodes.

Source code in vllm/compilation/piecewise_backend.py
def get_fake_args_from_graph(graph: fx.GraphModule) -> list[Any]:
    """Return the `example_value` metadata of the graph's leading placeholders.

    Iteration stops at the first non-placeholder node, since FX graphs list
    their inputs first.
    """
    placeholders: list[Any] = []
    for node in graph.graph.nodes:
        if node.op != "placeholder":
            break
        placeholders.append(node.meta["example_value"])
    return placeholders