Skip to content

vllm.model_executor.models.qwen3_vl

Inference-only Qwen3VL model compatible with HuggingFace weights.

Qwen3VLForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE, SupportsEagle3, SupportsMultiModalPruning

Source code in vllm/model_executor/models/qwen3_vl.py
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
@MULTIMODAL_REGISTRY.register_processor(
    Qwen3VLMultiModalProcessor,
    info=Qwen3VLProcessingInfo,
    dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3VLForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsLoRA,
    SupportsPP,
    SupportsMRoPE,
    SupportsEagle3,
    SupportsMultiModalPruning,
):
    # Fused parameter name -> names of the original checkpoint sub-modules
    # whose weights are packed into that single fused parameter.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        "qkv": ["qkv"],  # For vision tower's already-packed QKV
    }

    # NOTE(review): capability flag read by the framework; presumably it
    # advertises support for the "data" vision-encoder TP mode that
    # `__init__` checks via `mm_encoder_tp_mode == "data"` — confirm.
    supports_encoder_tp_data = True

    # To ensure correct weight loading and mapping.
    # Rewrites HF checkpoint name prefixes onto this module's attribute tree:
    #   model.visual.*          -> visual.*
    #   lm_head.*               -> language_model.lm_head.*
    #   model.language_model.*  -> language_model.model.*
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.visual.": "visual.",
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an item of ``modality``.

        Any modality name starting with ``"image"`` or ``"video"`` is accepted.
        The item index ``i`` is unused: all items of a modality share the same
        placeholder text.

        Raises:
            ValueError: If ``modality`` is neither an image nor a video modality.
        """
        placeholder_by_prefix = (
            ("image", "<|vision_start|><|image_pad|><|vision_end|>"),
            ("video", "<|vision_start|><|video_pad|><|vision_end|>"),
        )
        for modality_prefix, placeholder in placeholder_by_prefix:
            if modality.startswith(modality_prefix):
                return placeholder
        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
        """Build the Qwen3 vision tower and the Qwen3 text decoder.

        Args:
            vllm_config: Global vLLM configuration; its ``model_config``
                provides the HF config, multimodal options and tokenizer.
            prefix: Parameter-name prefix for the created sub-modules.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config: Qwen3VLConfig = model_config.hf_config
        mm_config = model_config.multimodal_config
        vision_cfg = hf_config.vision_config

        self.config = hf_config
        self._tokenizer = cached_tokenizer_from_config(model_config)
        self.multimodal_config = mm_config
        # When True the vision encoder is run through
        # `run_dp_sharded_mrope_vision_model`.
        self.use_data_parallel = mm_config.mm_encoder_tp_mode == "data"
        self.video_pruning_rate = mm_config.video_pruning_rate
        self.is_multimodal_pruning_enabled = mm_config.is_multimodal_pruning_enabled()

        # DeepStack is enabled when the vision config lists
        # `deepstack_visual_indexes`; there is one level per listed index.
        self.use_deepstack = hasattr(vision_cfg, "deepstack_visual_indexes")
        if self.use_deepstack:
            self.deepstack_num_level = len(vision_cfg.deepstack_visual_indexes)
        else:
            self.deepstack_num_level = 0
        self.visual_dim = vision_cfg.out_hidden_size
        self.multiscale_dim = self.visual_dim * self.deepstack_num_level

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.visual = Qwen3_VisionTransformer(
                vision_cfg,
                norm_eps=getattr(hf_config, "rms_norm_eps", 1e-6),
                quant_config=vllm_config.quant_config,
                prefix=maybe_prefix(prefix, "visual"),
            )

            # Per-level DeepStack staging tensors sized for the largest
            # scheduler batch. They are kept in a plain Python list (not
            # registered buffers) and may be reallocated by
            # `_set_deepstack_input_embeds`.
            if self.use_deepstack:
                max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
                text_hidden = hf_config.text_config.hidden_size
                self.deepstack_input_embeds = [
                    torch.zeros(max_tokens, text_hidden)
                    for _ in range(self.deepstack_num_level)
                ]

        with self._mark_language_model(vllm_config):
            self.language_model = Qwen3LLMForCausalLM(
                vllm_config=vllm_config.with_hf_config(hf_config.text_config),
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # On non-first pipeline ranks the local layers must start at or after
        # the DeepStack levels (presumably because the first `num_level`
        # decoder layers consume DeepStack features — confirm).
        if not get_pp_group().is_first_rank and self.use_deepstack:
            assert self.language_model.start_layer >= self.deepstack_num_level, (
                "start_layer should be greater than or equal to "
                "len(deepstack_visual_indexes)"
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Store the decoder layer indices whose hidden states are exposed.

        The tuple is assigned to ``aux_hidden_state_layers`` on the inner text
        model (``self.language_model.model``).
        """
        text_model = self.language_model.model
        text_model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default EAGLE-3 auxiliary layer indices.

        These are layer 2, the middle layer (``num_layers // 2``) and the
        third-from-last layer of the text decoder.
        """
        depth = len(self.language_model.model.layers)
        middle = depth // 2
        return 2, middle, depth - 3

    def _get_deepstack_input_embeds(
        self,
        num_tokens: int,
    ) -> IntermediateTensors | None:
        """Return the first ``num_tokens`` rows of every DeepStack buffer.

        Each entry ``deepstack_input_embeds_{idx}`` is a slice (a view) of the
        ``idx``-th staging buffer, so it aliases the buffer's storage. This
        method does NOT clear the buffers; ``_clear_deepstack_input_embeds``
        does that.

        Args:
            num_tokens: Number of leading rows to expose from each buffer.

        Returns:
            ``IntermediateTensors`` with one entry per DeepStack level, or
            ``None`` when no DeepStack buffers exist (e.g. the vision tower
            was skipped).
        """
        if not getattr(self, "deepstack_input_embeds", None):
            return None  # If vision tower is skipped

        # Read-only access: hand out views into the staging buffers.
        return IntermediateTensors(
            {
                f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][
                    :num_tokens
                ]
                for idx in range(self.deepstack_num_level)
            }
        )

    def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
        """Copy per-level DeepStack embeddings into the staging buffers.

        ``deepstack_input_embeds[level]`` is copied into the leading rows of
        the ``level``-th buffer; the row count is
        ``deepstack_input_embeds.size(1)``. If that exceeds the current buffer
        length, all buffers are first replaced by zero tensors with exactly
        that many rows, keeping the old device and dtype. Does nothing when no
        buffers exist.
        """
        buffers = getattr(self, "deepstack_input_embeds", None)
        if not buffers:
            return

        rows = deepstack_input_embeds.size(1)
        if rows > buffers[0].size(0):
            # Grow: allocate larger buffers matching the old ones' placement.
            template = buffers[0]
            buffers = [
                torch.zeros(
                    rows,
                    self.config.text_config.hidden_size,
                    device=template.device,
                    dtype=template.dtype,
                )
                for _ in range(self.deepstack_num_level)
            ]
            self.deepstack_input_embeds = buffers

        for level in range(self.deepstack_num_level):
            buffers[level][:rows].copy_(deepstack_input_embeds[level])

    def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
        """Zero the first ``num_tokens`` rows of every DeepStack buffer.

        A no-op when no buffers exist or ``num_tokens`` is not positive.
        """
        buffers = getattr(self, "deepstack_input_embeds", None)
        if not buffers or num_tokens <= 0:
            return
        for level in range(self.deepstack_num_level):
            buffers[level][:num_tokens].zero_()

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLImageInputs | None:
        """Build an image-input schema from the processor kwargs.

        Raw ``pixel_values`` take precedence over precomputed
        ``image_embeds``; ``image_grid_thw`` is attached to whichever one is
        used. Returns ``None`` when neither is present.
        """
        grid = kwargs.pop("image_grid_thw", None)
        embeds = kwargs.pop("image_embeds", None)
        pixels = kwargs.pop("pixel_values", None)

        if pixels is not None:
            return Qwen2_5_VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixels,
                image_grid_thw=grid,
            )
        if embeds is not None:
            return Qwen2_5_VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=embeds,
                image_grid_thw=grid,
            )
        return None

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLVideoInputs | None:
        """Build a video-input schema from the processor kwargs.

        Raw ``pixel_values_videos`` take precedence over precomputed
        ``video_embeds``. ``video_grid_thw`` and ``timestamps`` are attached
        in both cases; ``second_per_grid_ts`` only for pixel inputs. Returns
        ``None`` when neither video tensor is present.
        """
        ts = kwargs.pop("timestamps", None)
        seconds_per_grid = kwargs.pop("second_per_grid_ts", None)
        grid = kwargs.pop("video_grid_thw", None)
        embeds = kwargs.pop("video_embeds", None)
        pixels = kwargs.pop("pixel_values_videos", None)

        if pixels is not None:
            return Qwen2_5_VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixels,
                video_grid_thw=grid,
                second_per_grid_ts=seconds_per_grid,
                timestamps=ts,
            )
        if embeds is not None:
            return Qwen2_5_VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=embeds,
                video_grid_thw=grid,
                timestamps=ts,
            )
        return None

    def _process_image_input(
        self, image_input: Qwen2_5_VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Turn an image input into one embedding tensor per image.

        Precomputed embeddings are only cast to the vision tower's dtype;
        pixel values are encoded by ``self.visual``. In data-parallel encoder
        mode the result of ``run_dp_sharded_mrope_vision_model`` is returned
        as-is. Otherwise the concatenated embeddings are split per image into
        ``prod(grid_thw) // merge // merge`` rows each.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        target_dtype = self.visual.dtype

        if image_input["type"] != "image_embeds":
            pixels = image_input["pixel_values"].type(target_dtype)
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixels, grid_thw.tolist(), rope_type="rope_3d"
                )
            image_embeds = self.visual(pixels, grid_thw=grid_thw)
        else:
            image_embeds = image_input["image_embeds"].type(target_dtype)

        merge = self.visual.spatial_merge_size
        rows_per_image = (grid_thw.prod(-1) // merge // merge).tolist()
        return image_embeds.split(rows_per_image)

    def _process_video_input(
        self, video_input: Qwen2_5_VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Turn a video input into one embedding tensor per video.

        Precomputed embeddings are only cast to the vision tower's dtype;
        pixel values are encoded by ``self.visual``. In data-parallel encoder
        mode the result of ``run_dp_sharded_mrope_vision_model`` is returned
        as-is. Otherwise the concatenated embeddings are split per video into
        ``prod(grid_thw) // merge // merge`` rows each.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        target_dtype = self.visual.dtype

        if video_input["type"] != "video_embeds":
            frames = video_input["pixel_values_videos"].type(target_dtype)
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, frames, grid_thw.tolist(), rope_type="rope_3d"
                )
            video_embeds = self.visual(frames, grid_thw=grid_thw)
        else:
            video_embeds = video_input["video_embeds"].type(target_dtype)

        merge = self.visual.spatial_merge_size
        rows_per_video = (grid_thw.prod(-1) // merge // merge).tolist()
        return video_embeds.split(rows_per_video)

    def _postprocess_image_embeds_evs(
        self,
        image_embeds_split: tuple[torch.Tensor, ...],
        image_input: Qwen2_5_VLImageInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Append mrope position channels to each image embedding.

        With multimodal pruning enabled, embeddings carry their mrope
        positions as trailing channels so that correct positions can be
        recovered after video pruning; images receive the same layout as
        videos. For each image, the positions from ``compute_mrope_for_media``
        are followed by one all-zero dummy column, and the result is
        concatenated onto the embedding along dim 1 (5 extra channels in
        total, matching video embeddings).

        When pruning is disabled the input tuple is returned unchanged.

        Args:
            image_embeds_split: One embedding tensor per image item.
            image_input: Image input data (provides ``image_grid_thw``).

        Returns:
            One embedding tensor per image item.
        """
        if not self.is_multimodal_pruning_enabled:
            return image_embeds_split

        merge_size = self.visual.spatial_merge_size
        grids = image_input["image_grid_thw"].tolist()

        def _append_positions(emb: torch.Tensor, grid: list[int]) -> torch.Tensor:
            pos = compute_mrope_for_media(grid, merge_size).to(emb.device)
            dummy_fifth = torch.zeros_like(pos[:, 0:1])
            pos = torch.cat([pos, dummy_fifth], dim=1)
            return torch.cat([emb, pos], dim=1)

        return tuple(
            _append_positions(emb, grid)
            for emb, grid in zip(image_embeds_split, grids)
        )

    def _postprocess_video_embeds_evs(
        self,
        video_embeds_split: tuple[torch.Tensor, ...],
        video_input: Qwen2_5_VLVideoInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Build the final per-video embeddings, pruning with EVS if enabled.

        For every video item:

        * With multimodal pruning enabled, Efficient Video Sampling (EVS)
          computes a retention mask over the video's merged tokens. The
          embeddings are filtered by it, and the number of retained tokens per
          frame (possibly 0) is derived from the mask.
        * Otherwise all tokens are kept and split evenly across
          ``len(timestamps)`` frames.

        The resulting embeddings are passed to
        ``_create_final_video_embeddings``. That step interleaves them with
        the text embeddings of the surrounding indicator tokens and appends
        extra channels: DeepStack features when enabled, and 5 position/flag
        channels when pruning is enabled.

        Args:
            video_embeds_split: One embedding tensor per video item.
            video_input: Video input data (``video_grid_thw`` and per-video
                ``timestamps``).

        Returns:
            One final embedding tensor per video item.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()
        merge_size = self.visual.spatial_merge_size

        video_embeds_out = []
        for video_idx, (emb, size) in enumerate(zip(video_embeds_split, grid_thw_list)):
            timestamps = video_input.timestamps[video_idx]
            num_frames = len(timestamps)

            t, h, w = size
            if self.is_multimodal_pruning_enabled:
                rows, cols = h // merge_size, w // merge_size
                # EVS retention mask over this video's merged tokens. It is
                # reshaped to (t, rows, cols) below, so it has t * rows * cols
                # entries; it is used as an index mask (presumably boolean).
                retention_mask = compute_retention_mask(
                    emb,
                    size,
                    spatial_merge_size=merge_size,
                    q=self.video_pruning_rate,
                )
                emb = emb[retention_mask]

                # Actual number of retained tokens in each frame.
                retention_mask_thw = retention_mask.reshape(t, rows, cols)
                num_tokens_per_frame = (
                    retention_mask_thw.sum(dim=(1, 2)).long().tolist()
                )
            else:
                # No pruning: every frame keeps the same number of tokens.
                feature_size = emb.shape[0] // num_frames
                num_tokens_per_frame = [feature_size] * num_frames
                retention_mask = None

            emb = self._create_final_video_embeddings(
                video_embeddings=emb,
                num_tokens_per_frame=num_tokens_per_frame,
                timestamps=timestamps,
                video_grid_thw=size,
                retention_mask=retention_mask,
            )

            video_embeds_out.append(emb)

        return tuple(video_embeds_out)

    def _create_final_video_embeddings(
        self,
        video_embeddings: torch.Tensor,
        num_tokens_per_frame: list[int],
        timestamps: list[float],
        video_grid_thw: list[int],
        retention_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Combine video embeddings with text embeddings of indicator tokens.

        The video's full replacement token sequence is rebuilt with
        ``Qwen3VLMultiModalProcessor.get_video_repl``. It contains the
        vision-start/vision-end markers, frame-separator text derived from
        ``timestamps``, and ``num_tokens_per_frame[i]`` video tokens for
        frame ``i``. The result has one row per token in that sequence:

        - rows at ``video_token_id`` positions hold ``video_embeddings``;
        - every other row holds the language model's text embedding of that
          token.

        Extra channels are concatenated along the last dimension:

        - all DeepStack levels, flattened, when ``use_deepstack`` is set;
        - 5 position/flag channels from ``_get_expanded_positions`` when
          multimodal pruning is enabled.

        These embeddings replace the placeholder embeddings when building
        ``inputs_embeds`` for the LLM.

        Args:
            video_embeddings: (Possibly pruned) video embeddings, one row per
                video token in the replacement sequence.
            num_tokens_per_frame: Number of video tokens in each frame.
            timestamps: Per-frame timestamps used for the separator text.
            video_grid_thw: ``[t, h, w]`` grid of the video before spatial
                merging.
            retention_mask: EVS retention mask, or ``None`` when pruning is
                disabled.

        Returns:
            The combined embedding tensor for this video.
        """
        device = video_embeddings.device

        # Generate video replacement token IDs using get_video_repl.
        # This tokenizes each frame separator independently, then uses
        # pre-tokenized special tokens to ensure consistent tokenization
        # regardless of num_tokens_per_frame values.
        video_repl = Qwen3VLMultiModalProcessor.get_video_repl(
            tokens_per_frame=num_tokens_per_frame,
            tokenizer=self._tokenizer,
            timestamps=timestamps,
            vision_start_token_id=self.config.vision_start_token_id,
            vision_end_token_id=self.config.vision_end_token_id,
            video_token_id=self.config.video_token_id,
            select_token_id=self.is_multimodal_pruning_enabled,
        )

        repl_token_ids = torch.tensor(video_repl.full, device=device)
        embed_token_id = _cached_tensor(self.config.video_token_id, device=device)
        is_video_embed = torch.isin(repl_token_ids, embed_token_id)

        # Text embeddings for every token of the replacement sequence (LM
        # hidden size only; the extra channels are appended further below).
        # Rows at video-token positions are overwritten by the merge.
        text_embeddings = self.get_language_model().embed_input_ids(repl_token_ids)

        if self.use_deepstack:
            (
                deepstack_input_embeds,
                multimodal_embeddings,
            ) = self._compute_deepstack_embeds(
                inputs_embeds=text_embeddings,
                multimodal_embeddings=[video_embeddings],
                is_multimodal=is_video_embed,
            )
        else:
            deepstack_input_embeds = None
            multimodal_embeddings = [video_embeddings]

        merged_embeddings = _merge_multimodal_embeddings(
            inputs_embeds=text_embeddings,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_video_embed,
        )

        to_concat = [merged_embeddings]
        if deepstack_input_embeds is not None:
            # (levels, seq, dim) -> (seq, levels * dim)
            to_concat.append(
                deepstack_input_embeds.permute(1, 0, 2).reshape(
                    deepstack_input_embeds.shape[1], -1
                )
            )

        if self.is_multimodal_pruning_enabled:
            is_vision_start = repl_token_ids.eq(self.config.vision_start_token_id)
            expanded_positions = self._get_expanded_positions(
                device=merged_embeddings.device,
                seq_len=merged_embeddings.shape[0],
                video_grid_thw=video_grid_thw,
                num_tokens_per_frame=num_tokens_per_frame,
                timestamps=timestamps,
                is_video_embed=is_video_embed,
                is_vision_start=is_vision_start,
                retention_mask=retention_mask,
            )
            to_concat.append(expanded_positions)

        return torch.cat(to_concat, dim=-1)

    def _get_expanded_positions(
        self,
        device: torch.device,
        seq_len: int,
        video_grid_thw: list[int],
        num_tokens_per_frame: list[int],
        timestamps: list[float],
        is_video_embed: torch.Tensor,
        is_vision_start: torch.Tensor,
        retention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute per-token mrope positions and flags for a pruned video.

        First the unpruned replacement token sequence of the video is rebuilt,
        with every frame holding ``(h // merge_size) * (w // merge_size)``
        video tokens. Its mrope positions are computed with
        ``get_mrope_input_positions`` through a synthetic
        ``MultiModalFeatureSpec``. Those positions are then scattered onto the
        pruned sequence:

        - rows where ``is_video_embed`` is set take the positions of the
          unpruned video tokens selected by ``retention_mask``;
        - all other rows take the positions of the unpruned non-video tokens,
          in order.

        Args:
            device: Device for the returned tensor.
            seq_len: Length of the pruned replacement sequence.
            video_grid_thw: ``[t, h, w]`` grid of the video before merging.
            num_tokens_per_frame: Retained video tokens per frame; only its
                length (the frame count) is used here.
            timestamps: Per-frame timestamps passed to ``get_video_repl``.
            is_video_embed: Mask over the pruned sequence marking video tokens.
            is_vision_start: Mask over the pruned sequence marking
                vision-start tokens.
            retention_mask: EVS retention mask over the unpruned video tokens.
                The only caller in this file invokes this method only when
                pruning is enabled, where the mask is not ``None``.

        Returns:
            Long tensor of shape ``(seq_len, 5)`` with columns
            ``[t_index, h_index, w_index, is_vision_start, is_video]``.
        """
        embed_token_id = _cached_tensor(self.config.video_token_id, device=device)

        # Expand positions to match the full sequence length
        # (includes both video tokens and indicator tokens)
        # Shape: [full_length, 5] where positions are filled for video tokens
        # and zeros for indicator tokens.
        # Channel 3 flags VISION_START tokens so that
        # recompute_mrope_positions can reliably count timestamp tokens
        # (even when early frames have all video tokens pruned).
        # Channel 4 flags video-embedding tokens.
        expanded_positions = torch.zeros(
            seq_len,
            5,  # [t_index, h_index, w_index, is_vision_start, is_video]
            device=device,
            dtype=torch.long,
        )
        _, h, w = video_grid_thw
        merge_size = self.visual.spatial_merge_size
        num_frames = len(num_tokens_per_frame)
        # Token ids of the same video as if nothing had been pruned.
        unpruned_token_ids = Qwen3VLMultiModalProcessor.get_video_repl(
            tokens_per_frame=[(h // merge_size) * (w // merge_size)] * num_frames,
            tokenizer=self._tokenizer,
            timestamps=timestamps,
            vision_start_token_id=self.config.vision_start_token_id,
            vision_end_token_id=self.config.vision_end_token_id,
            video_token_id=self.config.video_token_id,
        ).full
        unpruned_token_ids_tensor = torch.tensor(unpruned_token_ids, device=device)
        # Synthetic single-video feature spanning the whole unpruned sequence,
        # used only to drive `get_mrope_input_positions`.
        mm_feature = MultiModalFeatureSpec(
            data=MultiModalKwargsItem(
                {
                    "video_grid_thw": MultiModalFieldElem(
                        data=torch.tensor(video_grid_thw),
                        # NOTE(review): no field spec is given; presumably only
                        # `.data` is read downstream (as `_iter_mm_grid_hw`
                        # does) — confirm.
                        field=None,  # HACK.
                    ),
                }
            ),
            modality="video",
            identifier="DUMMY",
            mm_position=PlaceholderRange(offset=0, length=len(unpruned_token_ids)),
        )
        # Presumably (3, num_tokens) positions, transposed to (num_tokens, 3)
        # to match the `[:, :3]` assignments below — confirm.
        original_mrope = (
            self.get_mrope_input_positions(
                input_tokens=unpruned_token_ids,
                mm_features=[mm_feature],
            )[0]
            .to(device)
            .permute(1, 0)
        )
        full_is_video_embed = unpruned_token_ids_tensor == embed_token_id
        # Video rows: positions of the retained (unpruned) video tokens.
        expanded_positions[is_video_embed, :3] = original_mrope[full_is_video_embed][
            retention_mask
        ]
        # Non-video rows: positions of the indicator tokens, in order.
        # NOTE(review): assumes the pruned and unpruned sequences have the same
        # number of non-video tokens, otherwise this assignment fails on a
        # shape mismatch — confirm `select_token_id` in the caller keeps that.
        expanded_positions[~is_video_embed, :3] = original_mrope[~full_is_video_embed]
        expanded_positions[..., 3] = is_vision_start
        expanded_positions[..., 4] = is_video_embed

        return expanded_positions

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse the image and/or video inputs present in ``kwargs``.

        ``kwargs`` is scanned in order. A modality is parsed, from the full
        kwargs, the first time one of its trigger keys appears:
        ``pixel_values``/``image_embeds`` for images and
        ``pixel_values_videos``/``video_embeds`` for videos. The returned
        dict's key order therefore follows first appearance.

        Returns:
            Mapping from ``"image"``/``"video"`` to the parsed input (which
            may be ``None``).
        """
        key_to_modality = {
            "pixel_values": "image",
            "image_embeds": "image",
            "pixel_values_videos": "video",
            "video_embeds": "video",
        }
        parsers = {
            "image": self._parse_and_validate_image_input,
            "video": self._parse_and_validate_video_input,
        }

        parsed: dict = {}
        for key in kwargs:
            modality = key_to_modality.get(key)
            if modality is None or modality in parsed:
                continue
            parsed[modality] = parsers[modality](**kwargs)
        return parsed

    @staticmethod
    def _iter_mm_grid_hw(
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
        video_token_id: int,
        vision_start_token_id: int,
        vision_end_token_id: int,
        spatial_merge_size: int,
    ) -> Iterator[tuple[int, int, int, int]]:
        """Yield position info for every image and every video frame.

        Features are visited in order of placeholder offset. An image yields
        one tuple. A video yields one tuple per temporal grid step; each frame
        is located by its vision-start token (always present, even when EVS
        left the frame with no video tokens) and its following vision-end
        token.

        Args:
            input_tokens: Token IDs of the input sequence.
            mm_features: Multimodal feature specs carrying the grid data and
                placeholder positions.
            video_token_id: Token ID of video tokens.
            vision_start_token_id: Token ID that opens a vision segment.
            vision_end_token_id: Token ID that closes a vision segment.
            spatial_merge_size: Spatial merge factor used to turn the raw
                grid into the logical LLM grid.

        Yields:
            ``(offset, llm_grid_h, llm_grid_w, actual_num_tokens)``:

            - ``offset``: the placeholder offset for images. For a video
              frame it is the index of its first video token, or the index
              just past its vision-start token when the frame has none.
            - ``llm_grid_h`` / ``llm_grid_w``: logical grid size (with EVS
              this may not match the real token count).
            - ``actual_num_tokens``: number of video/image tokens actually
              in the placeholder for this item or frame.

        Raises:
            ValueError: For a modality other than ``"image"`` or ``"video"``.
        """
        for feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            cursor = feature.mm_position.offset

            if feature.modality == "image":
                t, h, w = feature.data["image_grid_thw"].data.tolist()
                assert t == 1, f"Image must have 1 frame, got {t}"
                grid_h = h // spatial_merge_size
                grid_w = w // spatial_merge_size
                yield cursor, grid_h, grid_w, grid_h * grid_w
                continue

            if feature.modality != "video":
                raise ValueError(f"Unsupported modality: {feature.modality}")

            t, h, w = feature.data["video_grid_thw"].data.tolist()
            grid_h = h // spatial_merge_size
            grid_w = w // spatial_merge_size

            for _ in range(t):
                frame_start = input_tokens.index(vision_start_token_id, cursor)
                frame_end = input_tokens.index(vision_end_token_id, frame_start)
                try:
                    first_video = input_tokens.index(
                        video_token_id, frame_start, frame_end
                    )
                except ValueError:
                    # EVS removed every video token of this frame: point just
                    # past the vision-start token and report zero tokens.
                    first_video, num_tokens = frame_start + 1, 0
                else:
                    # Per `get_video_repl`, everything from the first video
                    # token up to vision-end is video tokens, so the count is
                    # a simple difference.
                    num_tokens = frame_end - first_video
                yield first_video, grid_h, grid_w, num_tokens
                cursor = frame_end + 1

    def _get_evs_mask_segments(
        self, mm_position: PlaceholderRange, expected_frames: int
    ) -> list[torch.Tensor] | None:
        """Split the EVS ``is_embed`` mask into contiguous frame segments.

        The EVS (Efficient Video Sampling) mask on ``mm_position`` marks which
        placeholder positions are filled with video embeddings. Each maximal
        run of consecutive ``True`` positions is treated as one retained
        frame.

        The method does not modify ``mm_position`` or any model state, so
        repeated calls with the same input give the same result. It may emit
        a debug log.

        Args:
            mm_position: Placeholder range, optionally carrying an
                ``is_embed`` mask.
            expected_frames: Number of frame segments the caller expects.

        Returns:
            A list of the first ``expected_frames`` segments, each a 1-D
            tensor of indices into the flattened mask. Returns ``None`` if
            there is no mask, the mask has no ``True`` entry, or it has fewer
            than ``expected_frames`` segments. Segments beyond
            ``expected_frames`` are dropped.
        """
        is_embed_mask = getattr(mm_position, "is_embed", None)
        if is_embed_mask is None:
            return None

        # `reshape` (unlike `view`) also accepts non-contiguous masks.
        mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).reshape(-1)
        true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten()
        if true_indices.numel() == 0:
            return None

        # A difference != 1 between consecutive True indices starts a new
        # segment. With a single index `diff` is empty, which yields no split
        # points and therefore one segment.
        gaps = torch.diff(true_indices) != 1
        split_points = torch.nonzero(gaps, as_tuple=False).flatten()
        if split_points.numel() == 0:
            segments = [true_indices]
        else:
            # `tensor_split` returns a tuple; convert so the declared
            # `list[...]` return type holds on every path.
            segments = list(
                torch.tensor_split(true_indices, split_points.add(1).tolist())
            )

        # Too few segments means the mask cannot describe every frame.
        if len(segments) < expected_frames:
            logger.debug(
                "EVS mask segments (%d) do not match expected frames (%d)",
                len(segments),
                expected_frames,
            )
            return None

        return segments[:expected_frames]

    def _extract_frame_offsets_from_mask(
        self, mm_position: PlaceholderRange, expected_frames: int
    ) -> list[int] | None:
        """Locate the first placeholder index of every EVS-retained frame.

        ``mm_position`` may carry a boolean ``is_embed`` mask written by the
        prompt processor. Every contiguous run of ``True`` entries in that
        mask is one retained frame, so the start of each run is that frame's
        offset. This avoids scanning ``input_tokens``.

        Args:
            mm_position: MultiModal position containing the is_embed mask.
            expected_frames: Expected number of frames.

        Returns:
            One offset (relative to ``mm_position``) per frame, or ``None``
            when ``_get_evs_mask_segments`` returns ``None``.
        """
        frame_segments = self._get_evs_mask_segments(mm_position, expected_frames)
        if frame_segments is None:
            return None

        frame_offsets: list[int] = []
        for frame_segment in frame_segments:
            frame_offsets.append(int(frame_segment[0].item()))
        return frame_offsets

    def _get_actual_frame_token_counts(
        self, mm_position: PlaceholderRange, expected_frames: int
    ) -> list[int] | None:
        """Count the tokens kept for each EVS-retained frame.

        The ``is_embed`` mask in ``mm_position`` is split into contiguous
        runs (one per retained frame). Because EVS pruning is content-aware,
        the runs, and therefore the per-frame token counts, can differ.

        Args:
            mm_position: MultiModal position containing the is_embed mask.
            expected_frames: Expected number of frames.

        Returns:
            The length of each frame's run, or ``None`` when
            ``_get_evs_mask_segments`` returns ``None``.
        """
        frame_segments = self._get_evs_mask_segments(mm_position, expected_frames)
        return (
            None
            if frame_segments is None
            else [frame_segment.numel() for frame_segment in frame_segments]
        )

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute M-RoPE positions for ``input_tokens`` with this model's config.

        Thin wrapper that forwards to the static
        ``_get_mrope_input_positions`` using ``self.config``.

        Returns:
            ``(positions, mrope_position_delta)`` as produced by
            ``_get_mrope_input_positions``.
        """
        model_config = self.config
        return self._get_mrope_input_positions(
            input_tokens=input_tokens,
            mm_features=mm_features,
            config=model_config,
        )

    @staticmethod
    def _get_mrope_input_positions(
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
        config: Qwen3VLConfig,
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-axis (t, h, w) M-RoPE position ids for a whole prompt.

        Text tokens get the same position on all three axes, increasing by
        one per token. Each image / video-frame placeholder reported by
        ``_iter_mm_grid_hw`` gets grid indices ``(0, row, col)`` over an
        ``llm_grid_h x llm_grid_w`` grid, shifted to start right after the
        preceding text. A placeholder that holds more tokens than one grid
        (an EVS "lumped" placeholder, see
        ``Qwen3VLMultiModalProcessor.get_video_repl``) is laid out as
        consecutive grids, plus a partial grid for any remainder.

        Args:
            input_tokens: Token ids of the full prompt.
            mm_features: Multimodal feature specs for the prompt.
            config: Model config providing the special token ids and the
                vision ``spatial_merge_size``.

        Returns:
            ``(positions, mrope_position_delta)`` where ``positions`` has one
            column per input token (3 rows) and ``mrope_position_delta`` is
            ``max(positions) + 1 - len(input_tokens)``.
        """
        llm_pos_ids_list: list[np.ndarray] = []
        st = 0
        for (
            offset,
            llm_grid_h,
            llm_grid_w,
            actual_num_tokens,
        ) in Qwen3VLForConditionalGeneration._iter_mm_grid_hw(
            input_tokens,
            mm_features,
            video_token_id=config.video_token_id,
            vision_start_token_id=config.vision_start_token_id,
            vision_end_token_id=config.vision_end_token_id,
            spatial_merge_size=config.vision_config.spatial_merge_size,
        ):
            # Skip frames with 0 tokens (EVS placeholder with tokens lumped
            # elsewhere); their surrounding tokens are counted as text below.
            if actual_num_tokens == 0:
                continue

            # Text tokens between the previous placeholder and this one.
            text_len = offset - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

            # (t, h, w) indices of one frame grid, built once per placeholder
            # rather than once per logical frame.
            grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
            expected_tokens_per_frame = llm_grid_h * llm_grid_w

            if actual_num_tokens > expected_tokens_per_frame:
                # Lumped placeholder (tokens of several frames assigned to one
                # frame, see `Qwen3VLMultiModalProcessor.get_video_repl`):
                # create grid positions for every "logical" frame it holds.
                num_logical_frames, remainder = divmod(
                    actual_num_tokens, expected_tokens_per_frame
                )
                for _ in range(num_logical_frames):
                    llm_pos_ids_list.append(grid_indices + text_len + st_idx)
                    st_idx = llm_pos_ids_list[-1].max() + 1
                    text_len = 0  # No text between frames within the lump.

                # Partial trailing frame: take the first `remainder` grid
                # positions so every token still gets exactly one position.
                if remainder > 0:
                    llm_pos_ids_list.append(
                        grid_indices[:, :remainder] + text_len + st_idx
                    )
            else:
                # Single frame. Slice to `actual_num_tokens` so a frame with
                # fewer tokens than its grid does not emit extra positions
                # (no-op when the counts are equal, the normal case).
                llm_pos_ids_list.append(
                    grid_indices[:, :actual_num_tokens] + text_len + st_idx
                )

            st = offset + actual_num_tokens

        # Trailing text after the last placeholder.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

        if not llm_pos_ids_list:
            # Empty prompt: `np.concatenate` / `.max()` would raise on no data.
            return torch.zeros((3, 0), dtype=torch.long), 0

        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        return torch.from_numpy(llm_positions), mrope_position_delta

    def recompute_mrope_positions(
        self,
        input_ids: list[int],
        multimodal_embeddings: MultiModalEmbeddings,
        mrope_positions: torch.LongTensor,
        num_computed_tokens: int,
    ) -> tuple[MultiModalEmbeddings, torch.Tensor, int]:
        """Refresh M-RoPE positions after multimodal token pruning.

        ``mrope_positions`` is computed for the unpruned sequence and becomes
        wrong once media tokens are pruned. This updates the part of it that
        starts at index ``num_computed_tokens``, so the LLM is fed positions
        that reflect the pruning. The work is delegated to
        ``_recompute_mrope_positions`` with the special token ids from
        ``self.config``.

        Args:
            input_ids: (N,) All input tokens of the prompt containing
                entire sequence.
            multimodal_embeddings: Tuple of multimodal embeddings that
                fits into the prefill chunk that is being processed.
            mrope_positions: Existing mrope positions (3, N) for entire
                sequence.
            num_computed_tokens: A number of computed tokens so far.

        Returns:
            Tuple of (multimodal_embeddings, mrope_positions,
                mrope_position_delta).
        """
        cfg = self.config
        return self._recompute_mrope_positions(
            input_ids=input_ids,
            multimodal_embeddings=multimodal_embeddings,
            mrope_positions=mrope_positions,
            num_computed_tokens=num_computed_tokens,
            vision_start_token_id=cfg.vision_start_token_id,
            image_token_id=cfg.image_token_id,
            video_token_id=cfg.video_token_id,
        )

    @staticmethod
    def _recompute_mrope_positions(
        input_ids: list[int],
        multimodal_embeddings: MultiModalEmbeddings,
        mrope_positions: torch.LongTensor,
        num_computed_tokens: int,
        vision_start_token_id: int,
        image_token_id: int,
        video_token_id: int,
    ) -> tuple[MultiModalEmbeddings, torch.Tensor, int]:
        """Strip position channels from embeddings and recompute positions.

        Each non-empty embedding carries 5 trailing channels of position
        information. They are removed from the returned embeddings and passed
        to the module-level ``recompute_mrope_positions`` helper, together
        with the token ids, transposed to ``(5, num_tokens)`` as ``long``.

        Args:
            input_ids: All prompt token ids.
            multimodal_embeddings: Embeddings with 5 trailing position
                channels; may contain empty frames.
            mrope_positions: Existing ``(3, N)`` mrope positions.
            num_computed_tokens: Number of tokens computed so far.
            vision_start_token_id: Vision start token id.
            image_token_id: Image placeholder token id.
            video_token_id: Video placeholder token id.

        Returns:
            ``(embeddings_without_position_channels, positions,
            mrope_positions_delta)``.
        """
        # Device
        device = (
            multimodal_embeddings[0].device
            if len(multimodal_embeddings)
            else mrope_positions.device
        )

        # Tensors
        input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long)

        mm_embeddings_out = []
        mm_embeddings_pos = []
        # Strip position information from embeddings (last 5 channels).
        # For Qwen3 VL, handle potentially empty frames (from unpacking).
        for mm in multimodal_embeddings:
            if mm.shape[0] > 0:
                mm_embeddings_out.append(mm[:, :-5])
                mm_embeddings_pos.append(mm[:, -5:].permute(1, 0).long())
                continue

            # Empty frame. For 2-D empties, strip the position channels too so
            # every returned embedding has the same feature width; leaving
            # them would make a later `torch.cat` over the embeddings (dim 0)
            # fail on the width mismatch. Other shapes pass through as-is.
            mm_embeddings_out.append(mm[:, :-5] if mm.dim() == 2 else mm)
            # Empty position tensor with the expected (5, 0) shape.
            mm_embeddings_pos.append(
                torch.empty(5, 0, device=device, dtype=torch.long)
            )

        positions, mrope_positions_delta = recompute_mrope_positions(
            input_ids_t,
            mm_embeddings_pos,
            mrope_positions,
            num_computed_tokens,
            vision_start_token_id,
            image_token_id,
            video_token_id,
        )

        return tuple(mm_embeddings_out), positions, mrope_positions_delta

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
        """Encode the image/video inputs found in ``kwargs``.

        Images are always passed through ``_postprocess_image_embeds_evs``.
        Videos are passed through ``_postprocess_video_embeds_evs`` only when
        multimodal pruning is enabled.

        Returns:
            A tuple with one tensor per multimodal item, in the order the
            modalities appear in the parsed inputs, or ``None`` when no
            multimodal input was provided.
        """
        parsed_inputs = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not parsed_inputs:
            return None

        collected: list[torch.Tensor] = []

        # Walk the dictionary keys in insertion order so the output keeps
        # the modalities' order.
        for modality in parsed_inputs:
            modality_input = parsed_inputs[modality]
            if modality == "image":
                embeds = self._process_image_input(modality_input)
                collected.extend(
                    self._postprocess_image_embeds_evs(embeds, modality_input)
                )
            elif modality == "video":
                embeds = self._process_video_input(modality_input)
                if self.is_multimodal_pruning_enabled:
                    embeds = self._postprocess_video_embeds_evs(
                        embeds, modality_input
                    )
                collected.extend(embeds)

        return tuple(collected)

    def _compute_deepstack_embeds(
        self,
        inputs_embeds: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings,
        is_multimodal: torch.Tensor,
    ) -> tuple[torch.Tensor, MultiModalEmbeddings]:
        """Separate the main and multiscale (deepstack) channels.

        Every multimodal embedding is split along its last dim into
        ``visual_dim`` main channels and ``multiscale_dim`` deepstack
        channels. The deepstack channels are merged into a zero buffer of
        ``deepstack_num_level * hidden`` columns at the ``is_multimodal``
        rows, then reshaped to ``(deepstack_num_level, seq_len, visual_dim)``.

        Returns:
            ``(deepstack_embeds, main_embeddings)`` where ``main_embeddings``
            keeps one tensor per original multimodal item.
        """
        seq_len = inputs_embeds.size(0)
        hidden = inputs_embeds.size(1)
        item_lengths = [len(item) for item in multimodal_embeddings]

        stacked = torch.cat(multimodal_embeddings, dim=0)
        main_part, multiscale_part = stacked.split(
            [self.visual_dim, self.multiscale_dim], dim=-1
        )
        per_item_main = main_part.split(item_lengths, dim=0)
        per_item_multiscale = multiscale_part.split(item_lengths, dim=0)

        deepstack_buffer = _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds.new_zeros(
                seq_len, self.deepstack_num_level * hidden
            ),
            multimodal_embeddings=per_item_multiscale,
            is_multimodal=is_multimodal,
        )
        levels_first = deepstack_buffer.view(
            seq_len, self.deepstack_num_level, self.visual_dim
        ).permute(1, 0, 2)

        return levels_first, per_item_main

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge multimodal embeddings into them.

        Without multimodal embeddings this returns the text embeddings only.
        Otherwise, when deepstack is enabled, the deepstack channels are
        split off first and stored with ``_set_deepstack_input_embeds`` after
        the main channels have been merged at the ``is_multimodal`` rows.
        """
        text_embeds = self._embed_text_input_ids(
            input_ids,
            self.language_model.embed_input_ids,
            is_multimodal=is_multimodal,
        )

        has_mm = multimodal_embeddings is not None and len(multimodal_embeddings) != 0
        if not has_mm:
            return text_embeds

        mm_mask = _require_is_multimodal(is_multimodal)

        deepstack_embeds = None
        if self.use_deepstack:
            deepstack_embeds, multimodal_embeddings = self._compute_deepstack_embeds(
                inputs_embeds=text_embeds,
                multimodal_embeddings=multimodal_embeddings,
                is_multimodal=mm_mask,
            )

        merged = _merge_multimodal_embeddings(
            inputs_embeds=text_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=mm_mask,
        )

        if deepstack_embeds is not None:
            self._set_deepstack_input_embeds(deepstack_embeds)

        return merged

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen3VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen3VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from previous pipeline
                stages. When given, ``inputs_embeds`` is ignored.
            inputs_embeds: Pre-computed input embeddings.
            **kwargs: Additional keyword arguments (e.g. ``pixel_values``,
                ``image_grid_thw``, ``pixel_values_videos``,
                ``video_grid_thw``). They are accepted for interface
                compatibility but not used by this method.

        Returns:
            Whatever ``self.language_model.model`` returns (hidden states or
            intermediate tensors).
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # Deepstack embeddings are only fetched (and later cleared) when
        # embeddings are supplied on the first pipeline-parallel rank.
        handles_deepstack = (
            inputs_embeds is not None and get_pp_group().is_first_rank
        )

        deepstack_input_embeds = (
            self._get_deepstack_input_embeds(inputs_embeds.size(0))
            if handles_deepstack
            else None
        )

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            # args for deepstack
            deepstack_input_embeds=deepstack_input_embeds,
        )

        if handles_deepstack:
            self._clear_deepstack_input_embeds(inputs_embeds.size(0))

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint ``weights`` via ``AutoWeightsLoader``.

        ``self.hf_to_vllm_mapper`` is passed as the loader's name mapper.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module prefixes of the multimodal sub-modules.

        The language model lives under ``language_model``, the connector
        under ``visual.merger`` and ``visual.deepstack_merger_list``, and
        the vision tower under ``visual.``.
        """
        connector_prefixes = ["visual.merger", "visual.deepstack_merger_list"]
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector=connector_prefixes,
            tower_model="visual.",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Scale an LLM-side image token count up to vision-encoder tokens.

        The count is multiplied by ``spatial_merge_size ** 2`` taken from
        ``self.config.vision_config``.
        """
        merge = self.config.vision_config.spatial_merge_size
        return num_image_tokens * merge * merge

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Scale a vision-encoder token count down to connector output tokens.

        The count is floor-divided by ``spatial_merge_size ** 2`` taken from
        ``self.config.vision_config``.
        """
        merge = self.config.vision_config.spatial_merge_size
        return num_vision_tokens // (merge * merge)

_create_final_video_embeddings

_create_final_video_embeddings(
    video_embeddings: Tensor,
    num_tokens_per_frame: list[int],
    timestamps: list[float],
    video_grid_thw: list[int],
    retention_mask: Tensor,
) -> Tensor

Create final embeddings that combine video embeddings with text embeddings of indicator tokens.

These final embeddings contain:

- Actual video embeddings in positions corresponding to video content
- Text embeddings for the indicator tokens (the vision start/end tokens and the frame separation text) in their respective positions

These embeddings will replace the placeholder embeddings to create input_embeds for the LLM.

Source code in vllm/model_executor/models/qwen3_vl.py
def _create_final_video_embeddings(
    self,
    video_embeddings: torch.Tensor,
    num_tokens_per_frame: list[int],
    timestamps: list[float],
    video_grid_thw: list[int],
    retention_mask: torch.Tensor,
) -> torch.Tensor:
    """Create final embeddings that combine video embeddings with
    text embeddings of indicator tokens.

    The full video replacement token sequence is rebuilt with
    ``Qwen3VLMultiModalProcessor.get_video_repl`` from the per-frame token
    counts. The resulting embeddings contain:
    - Actual video embeddings in the positions holding ``video_token_id``
    - Text embeddings for the indicator tokens (vision start/end tokens and
      frame separation text) in their respective positions

    These embeddings will replace the placeholder embeddings to create
    input_embeds for the LLM.

    Args:
        video_embeddings: Vision embeddings for the video tokens; presumably
            one row per retained token with ``visual_dim`` (+ multiscale)
            channels - TODO confirm.
        num_tokens_per_frame: Video token count for each frame, passed to
            ``get_video_repl`` as ``tokens_per_frame``.
        timestamps: Per-frame timestamps, passed to ``get_video_repl``.
        video_grid_thw: Video grid; only used for the expanded positions
            when multimodal pruning is enabled.
        retention_mask: Retention mask; only used for the expanded positions
            when multimodal pruning is enabled.

    Returns:
        The last-dim concatenation of the merged embeddings, the flattened
        deepstack embeddings (only when ``self.use_deepstack``) and the
        expanded positions (only when ``self.is_multimodal_pruning_enabled``).
    """
    device = video_embeddings.device

    # Generate video replacement token IDs using get_video_repl
    # This tokenizes each frame separator independently, then uses pre-tokenized
    # special tokens to ensure consistent tokenization regardless of
    # num_tokens_per_frame values.
    video_repl = Qwen3VLMultiModalProcessor.get_video_repl(
        tokens_per_frame=num_tokens_per_frame,
        tokenizer=self._tokenizer,
        timestamps=timestamps,
        vision_start_token_id=self.config.vision_start_token_id,
        vision_end_token_id=self.config.vision_end_token_id,
        video_token_id=self.config.video_token_id,
        select_token_id=self.is_multimodal_pruning_enabled,
    )

    repl_token_ids = torch.tensor(video_repl.full, device=device)
    embed_token_id = _cached_tensor(self.config.video_token_id, device=device)
    # Boolean mask of the replacement positions that hold the video token.
    is_video_embed = torch.isin(repl_token_ids, embed_token_id)

    # Get text embeddings for indicator tokens (has only `visual_dim`).
    text_embeddings = self.get_language_model().embed_input_ids(repl_token_ids)

    if self.use_deepstack:
        # Split the multiscale (deepstack) channels off `video_embeddings`
        # before merging; `multimodal_embeddings` keeps the main channels.
        (
            deepstack_input_embeds,
            multimodal_embeddings,
        ) = self._compute_deepstack_embeds(
            inputs_embeds=text_embeddings,
            multimodal_embeddings=[video_embeddings],
            is_multimodal=is_video_embed,
        )
    else:
        deepstack_input_embeds = None
        multimodal_embeddings = [video_embeddings]

    # Put the video embeddings at the `is_video_embed` rows; presumably the
    # other rows keep their text embeddings - see `_merge_multimodal_embeddings`.
    merged_embeddings = _merge_multimodal_embeddings(
        inputs_embeds=text_embeddings,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_video_embed,
    )

    to_concat = [merged_embeddings]
    if deepstack_input_embeds is not None:
        # Flatten to one row per replacement token; assumes the
        # (levels, seq, dim) layout that `_compute_deepstack_embeds` returns.
        to_concat.append(
            deepstack_input_embeds.permute(1, 0, 2).reshape(
                deepstack_input_embeds.shape[1], -1
            )
        )

    expanded_positions = None
    if self.is_multimodal_pruning_enabled:
        # Append per-token position channels; presumably these are the
        # trailing channels later stripped by `_recompute_mrope_positions`.
        is_vision_start = repl_token_ids.eq(self.config.vision_start_token_id)
        expanded_positions = self._get_expanded_positions(
            device=merged_embeddings.device,
            seq_len=merged_embeddings.shape[0],
            video_grid_thw=video_grid_thw,
            num_tokens_per_frame=num_tokens_per_frame,
            timestamps=timestamps,
            is_video_embed=is_video_embed,
            is_vision_start=is_vision_start,
            retention_mask=retention_mask,
        )
        to_concat.append(expanded_positions)

    final_video_embeddings = torch.cat(to_concat, dim=-1)

    return final_video_embeddings

_extract_frame_offsets_from_mask

_extract_frame_offsets_from_mask(
    mm_position: PlaceholderRange, expected_frames: int
) -> list[int] | None

Return relative offsets for each EVS-retained frame.

The prompt processor stores a boolean mask inside mm_position that marks which placeholder locations should be populated with video embeddings. By splitting that mask into contiguous runs we can recover the start of every retained frame without probing input_tokens.

Parameters:

Name Type Description Default
mm_position PlaceholderRange

MultiModal position containing the is_embed mask

required
expected_frames int

Expected number of frames

required

Returns:

Type Description
list[int] | None

List of starting offsets (relative to mm_position) for each frame, or None if EVS is not enabled.

Source code in vllm/model_executor/models/qwen3_vl.py
def _extract_frame_offsets_from_mask(
    self, mm_position: PlaceholderRange, expected_frames: int
) -> list[int] | None:
    """Locate the first placeholder index of every EVS-retained frame.

    ``mm_position`` may carry a boolean ``is_embed`` mask written by the
    prompt processor. Every contiguous run of ``True`` entries in that mask
    is one retained frame, so the start of each run is that frame's offset.
    This avoids scanning ``input_tokens``.

    Args:
        mm_position: MultiModal position containing the is_embed mask.
        expected_frames: Expected number of frames.

    Returns:
        One offset (relative to ``mm_position``) per frame, or ``None`` when
        ``_get_evs_mask_segments`` returns ``None``.
    """
    frame_segments = self._get_evs_mask_segments(mm_position, expected_frames)
    if frame_segments is None:
        return None

    frame_offsets: list[int] = []
    for frame_segment in frame_segments:
        frame_offsets.append(int(frame_segment[0].item()))
    return frame_offsets

_get_actual_frame_token_counts

_get_actual_frame_token_counts(
    mm_position: PlaceholderRange, expected_frames: int
) -> list[int] | None

Return actual token count for each EVS-retained frame.

This function calculates the actual number of tokens per frame by analyzing the is_embed mask, accounting for EVS pruning. Each frame may have a different token count due to content-aware pruning.

Parameters:

Name Type Description Default
mm_position PlaceholderRange

MultiModal position containing the is_embed mask

required
expected_frames int

Expected number of frames

required

Returns:

Type Description
list[int] | None

List of token counts for each frame, or None if EVS is not enabled.

Source code in vllm/model_executor/models/qwen3_vl.py
def _get_actual_frame_token_counts(
    self, mm_position: PlaceholderRange, expected_frames: int
) -> list[int] | None:
    """Count the tokens kept for each EVS-retained frame.

    The ``is_embed`` mask in ``mm_position`` is split into contiguous runs
    (one per retained frame). Because EVS pruning is content-aware, the runs,
    and therefore the per-frame token counts, can differ.

    Args:
        mm_position: MultiModal position containing the is_embed mask.
        expected_frames: Expected number of frames.

    Returns:
        The length of each frame's run, or ``None`` when
        ``_get_evs_mask_segments`` returns ``None``.
    """
    frame_segments = self._get_evs_mask_segments(mm_position, expected_frames)
    return (
        None
        if frame_segments is None
        else [frame_segment.numel() for frame_segment in frame_segments]
    )

_get_evs_mask_segments

_get_evs_mask_segments(
    mm_position: PlaceholderRange, expected_frames: int
) -> list[Tensor] | None

Extract contiguous segments from EVS is_embed mask.

The EVS (Efficient Video Sampling) mask marks which placeholder positions should be filled with video embeddings. This method splits the mask into contiguous segments, where each segment represents one retained frame.

This is a pure function: it does not modify any state and always returns the same output for the same input (i.e. it is deterministic).

Parameters:

Name Type Description Default
mm_position PlaceholderRange

MultiModal position containing the is_embed mask

required
expected_frames int

Expected number of frame segments

required

Returns:

Type Description
list[Tensor] | None

List of tensors, each containing indices for one frame segment, or None if EVS is not enabled or validation fails.

Source code in vllm/model_executor/models/qwen3_vl.py
def _get_evs_mask_segments(
    self, mm_position: PlaceholderRange, expected_frames: int
) -> list[torch.Tensor] | None:
    """Split the EVS ``is_embed`` mask into per-frame index segments.

    The EVS (Efficient Video Sampling) mask marks which placeholder
    positions should be filled with video embeddings. The indices of the
    ``True`` entries are grouped into maximal runs of consecutive positions;
    each run is treated as one retained frame.

    This function reads no instance state and modifies nothing, so it is
    deterministic: the same input always yields the same output.

    Args:
        mm_position: MultiModal position that may carry an ``is_embed`` mask
            attribute.
        expected_frames: Number of frame segments the caller expects.

    Returns:
        A list holding the first ``expected_frames`` segments, each a 1-D
        tensor of placeholder indices (relative to ``mm_position``), or
        ``None`` when there is no mask, the mask has no ``True`` entry, or it
        yields fewer than ``expected_frames`` segments. Extra trailing
        segments are dropped.
    """
    is_embed_mask = getattr(mm_position, "is_embed", None)
    if is_embed_mask is None:
        return None

    # Indices of all placeholder positions that receive an embedding.
    mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1)
    true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten()
    if true_indices.numel() == 0:
        return None

    # A difference != 1 between consecutive indices starts a new segment.
    # For a single index `torch.diff` is empty, so there are no split points
    # and the whole tensor is one segment.
    diffs = torch.diff(true_indices)
    split_points = torch.nonzero(diffs != 1, as_tuple=False).flatten()
    if split_points.numel() == 0:
        segments = [true_indices]
    else:
        # `tensor_split` returns a tuple; convert so the declared
        # `list[torch.Tensor]` return type actually holds.
        segments = list(
            torch.tensor_split(true_indices, split_points.add(1).tolist())
        )

    # Validate segment count matches expected frames
    if len(segments) < expected_frames:
        logger.debug(
            "EVS mask segments (%d) do not match expected frames (%d)",
            len(segments),
            expected_frames,
        )
        return None

    return segments[:expected_frames]

_iter_mm_grid_hw staticmethod

_iter_mm_grid_hw(
    input_tokens: list[int],
    mm_features: list[MultiModalFeatureSpec],
    video_token_id: int,
    vision_start_token_id: int,
    vision_end_token_id: int,
    spatial_merge_size: int,
) -> Iterator[tuple[int, int, int, int]]

Iterate over multimodal features and yield position info.

Parameters:

Name Type Description Default
input_tokens list[int]

List of token IDs in the input sequence.

required
mm_features list[MultiModalFeatureSpec]

List of multimodal feature specifications containing image/video data and position information.

required
video_token_id int

Token ID used for video tokens.

required
vision_start_token_id int

Token ID marking the start of a vision sequence.

required
vision_end_token_id int

Token ID marking the end of a vision sequence.

required
spatial_merge_size int

Size of the spatial merge operation used to compute logical grid dimensions from the original feature grid.

required

Yields:

Name Type Description
offset int

Position of the first video/image token in the sequence.

llm_grid_h int

Logical grid height (may not match actual token count with EVS).

llm_grid_w int

Logical grid width (may not match actual token count with EVS).

actual_num_tokens int

Actual number of video/image tokens in the placeholder.

Source code in vllm/model_executor/models/qwen3_vl.py
@staticmethod
def _iter_mm_grid_hw(
    input_tokens: list[int],
    mm_features: list[MultiModalFeatureSpec],
    video_token_id: int,
    vision_start_token_id: int,
    vision_end_token_id: int,
    spatial_merge_size: int,
) -> Iterator[tuple[int, int, int, int]]:
    """Walk the multimodal features in prompt order and yield grid/position info.

    Features are visited in ascending ``mm_position.offset`` order. An image
    yields exactly one tuple. A video yields one tuple per frame (``t`` tuples
    for a ``video_grid_thw`` of ``(t, h, w)``).

    Args:
        input_tokens: Token IDs of the whole input sequence.
        mm_features: Multimodal feature specs carrying the grid sizes
            (``image_grid_thw`` / ``video_grid_thw``) and placeholder offsets.
        video_token_id: Token ID used for video tokens.
        vision_start_token_id: Token ID that opens every video frame.
        vision_end_token_id: Token ID that closes every video frame.
        spatial_merge_size: Spatial merge factor; the raw grid height/width
            are floor-divided by it to obtain the logical (LLM) grid.

    Yields:
        Tuples ``(offset, llm_grid_h, llm_grid_w, actual_num_tokens)``:
        offset: Position of the first image/video token. For a video frame
            that contains no video tokens (possible with EVS), this is the
            position right after its ``vision_start_token_id``.
        llm_grid_h: Logical grid height (may not match the token count with EVS).
        llm_grid_w: Logical grid width (may not match the token count with EVS).
        actual_num_tokens: Number of image/video tokens actually present.

    Raises:
        ValueError: If a feature has a modality other than image or video.
    """
    ordered_features = sorted(mm_features, key=lambda feat: feat.mm_position.offset)
    for feature in ordered_features:
        cursor = feature.mm_position.offset
        modality = feature.modality

        if modality not in ("image", "video"):
            raise ValueError(f"Unsupported modality: {feature.modality}")

        if modality == "image":
            t, h, w = feature.data["image_grid_thw"].data.tolist()
            assert t == 1, f"Image must have 1 frame, got {t}"
            grid_h, grid_w = h // spatial_merge_size, w // spatial_merge_size
            yield cursor, grid_h, grid_w, grid_h * grid_w
            continue

        # Video: every frame is wrapped in vision_start ... vision_end. With EVS
        # a frame may hold zero video tokens, so frames are located through the
        # markers rather than through the video tokens themselves.
        num_frames, h, w = feature.data["video_grid_thw"].data.tolist()
        grid_h, grid_w = h // spatial_merge_size, w // spatial_merge_size

        for _ in range(num_frames):
            frame_start = input_tokens.index(vision_start_token_id, cursor)
            frame_end = input_tokens.index(vision_end_token_id, frame_start)
            try:
                first_video_pos = input_tokens.index(
                    video_token_id, frame_start, frame_end
                )
            except ValueError:
                # Empty frame: report the slot just past vision_start.
                first_video_pos = frame_start + 1
                num_tokens = 0
            else:
                # Everything from the first video token up to vision_end is
                # video tokens (see `Qwen3VLMultiModalProcessor.get_video_repl`),
                # so the count follows from the two positions.
                num_tokens = frame_end - first_video_pos

            yield first_video_pos, grid_h, grid_w, num_tokens
            # Resume the search after this frame's vision_end.
            cursor = frame_end + 1

_postprocess_image_embeds_evs

_postprocess_image_embeds_evs(
    image_embeds_split: tuple[Tensor, ...],
    image_input: Qwen2_5_VLImageInputs,
) -> tuple[Tensor, ...]

Append mrope positions to each image embedding. This is necessary to recover correct mrope positions after video pruning.

Parameters:

Name Type Description Default
image_embeds_split tuple[Tensor, ...]

Tuple of image embeddings for each image item.

required
image_input Qwen2_5_VLImageInputs

Image input data.

required

Returns:

Type Description
tuple[Tensor, ...]

Tuple of image embeddings for each image item. Resulting embeddings will have extra 5 channels for computed mrope positions, consistent with video embeddings.

Source code in vllm/model_executor/models/qwen3_vl.py
def _postprocess_image_embeds_evs(
    self,
    image_embeds_split: tuple[torch.Tensor, ...],
    image_input: Qwen2_5_VLImageInputs,
) -> tuple[torch.Tensor, ...]:
    """Attach mrope positions to every image embedding when pruning is on.

    Video pruning changes the token layout, so the correct mrope positions
    must be recoverable later. Image embeddings therefore carry their
    positions in extra channels, using the same layout as video embeddings.

    Args:
        image_embeds_split: One embedding tensor per image item.
        image_input: Image input data; its ``image_grid_thw`` provides the
            per-image grid sizes.

    Returns:
        If multimodal pruning is disabled, ``image_embeds_split`` unchanged
        (the same object). Otherwise, a new tuple where each embedding has
        extra channels appended along dim 1. These are the positions from
        ``compute_mrope_for_media`` followed by one all-zero column (5 extra
        channels in total, per this method's contract).
    """
    if not self.is_multimodal_pruning_enabled:
        return image_embeds_split

    merge_size = self.visual.spatial_merge_size
    grid_sizes = image_input["image_grid_thw"].tolist()

    def _append_positions(emb: torch.Tensor, size) -> torch.Tensor:
        pos = compute_mrope_for_media(size, merge_size).to(emb.device)
        # Dummy extra fifth channel so images match the video embedding layout.
        filler = torch.zeros_like(pos[:, 0:1])
        return torch.cat([emb, torch.cat([pos, filler], dim=1)], dim=1)

    return tuple(
        _append_positions(emb, size)
        for emb, size in zip(image_embeds_split, grid_sizes)
    )

_postprocess_video_embeds_evs

_postprocess_video_embeds_evs(
    video_embeds_split: tuple[Tensor, ...],
    video_input: Qwen2_5_VLVideoInputs,
) -> tuple[Tensor, ...]

Prunes video embeddings via Efficient Video Sampling (EVS) and then appends mrope positions to each retained embedding.

Parameters:

Name Type Description Default
video_embeds_split tuple[Tensor, ...]

Tuple of video embeddings for each video item.

required
video_input Qwen2_5_VLVideoInputs

Video input data.

required

Returns:

Type Description
tuple[Tensor, ...]

Tuple of video embeddings for each video item. Resulting embeddings will have extra 5 channels for computed mrope positions, and whether the index corresponds to a video embedding.

Source code in vllm/model_executor/models/qwen3_vl.py
def _postprocess_video_embeds_evs(
    self,
    video_embeds_split: tuple[torch.Tensor, ...],
    video_input: Qwen2_5_VLVideoInputs,
) -> tuple[torch.Tensor, ...]:
    """Optionally prune each video with EVS, then build its final embeddings.

    For every video item, a per-frame token count is derived and passed, with
    the (possibly pruned) embeddings and the frame timestamps, to
    ``self._create_final_video_embeddings``.

    - When multimodal pruning is enabled, ``compute_retention_mask`` selects
      the tokens to keep. The embeddings are filtered by that mask, and the
      per-frame counts are the number of retained tokens in each frame.
    - Otherwise every frame is assigned ``len(embeddings) // num_timestamps``
      tokens and no mask is passed.

    Args:
        video_embeds_split: One embedding tensor per video item.
        video_input: Video input data. ``video_grid_thw`` must be a 2-D tensor
            of per-video ``(t, h, w)``, and ``timestamps`` holds per-video
            frame timestamps.

    Returns:
        Tuple with the output of ``_create_final_video_embeddings`` per video.
        Per this method's contract, these carry 5 extra channels for computed
        mrope positions and a video-embedding indicator.
    """
    grid_thw = video_input["video_grid_thw"]
    assert grid_thw.ndim == 2
    grid_sizes = grid_thw.tolist()
    merge_size = self.visual.spatial_merge_size

    finalized = []
    for idx, (emb, size) in enumerate(zip(video_embeds_split, grid_sizes)):
        frame_timestamps = video_input.timestamps[idx]
        t, h, w = size

        if self.is_multimodal_pruning_enabled:
            # Boolean mask over this video's tokens: True means "keep".
            retention_mask = compute_retention_mask(
                emb,
                size,
                spatial_merge_size=merge_size,
                q=self.video_pruning_rate,
            )
            emb = emb[retention_mask]
            # View the mask as (frames, rows, cols) to count survivors per frame.
            per_frame_mask = retention_mask.reshape(
                t, h // merge_size, w // merge_size
            )
            num_tokens_per_frame = (
                per_frame_mask.sum(dim=(1, 2)).long().tolist()
            )
        else:
            retention_mask = None
            frame_count = len(frame_timestamps)
            num_tokens_per_frame = [emb.shape[0] // frame_count] * frame_count

        finalized.append(
            self._create_final_video_embeddings(
                video_embeddings=emb,
                num_tokens_per_frame=num_tokens_per_frame,
                timestamps=frame_timestamps,
                video_grid_thw=size,
                retention_mask=retention_mask,
            )
        )

    return tuple(finalized)

forward

forward(
    input_ids: Tensor | None,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
    **kwargs: object,
) -> Tensor | IntermediateTensors

Run forward pass for Qwen3VL.

Parameters:

Name Type Description Default
input_ids Tensor | None

Flattened (concatenated) input_ids corresponding to a batch.

required
positions Tensor

Flattened (concatenated) position ids corresponding to a batch. NOTE: If mrope is enabled (default setting for Qwen3VL opensource models), the shape will be (3, seq_len), otherwise it will be (seq_len,).

required
intermediate_tensors IntermediateTensors | None

Intermediate tensors from previous pipeline stages.

None
inputs_embeds Tensor | None

Pre-computed input embeddings.

None
**kwargs object

Additional keyword arguments including: - pixel_values: Pixel values to be fed to a model. None if no images are passed. - image_grid_thw: Tensor (n_images, 3) of image 3D grid in LLM. None if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. None if no videos are passed. - video_grid_thw: Tensor (n_videos, 3) of video 3D grid in LLM. None if no videos are passed.

{}
Source code in vllm/model_executor/models/qwen3_vl.py
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor | IntermediateTensors:
    """Run the Qwen3VL language model forward pass.

    Args:
        input_ids: Flattened (concatenated) input_ids corresponding to a
            batch.
        positions: Flattened (concatenated) position ids corresponding to a
            batch.
            **NOTE**: If mrope is enabled (default setting for Qwen3VL
            opensource models), the shape will be `(3, seq_len)`,
            otherwise it will be `(seq_len,)`.
        intermediate_tensors: Intermediate tensors from previous pipeline
            stages. When given, ``inputs_embeds`` is ignored.
        inputs_embeds: Pre-computed input embeddings.
        **kwargs: Additional keyword arguments (not used here), such as:
            - pixel_values: Pixel values to be fed to a model.
                `None` if no images are passed.
            - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in
                LLM. `None` if no images are passed.
            - pixel_values_videos: Pixel values of videos to be fed to a
                model. `None` if no videos are passed.
            - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in
                LLM. `None` if no videos are passed.

    Returns:
        Whatever ``self.language_model.model`` returns: hidden states, or
        intermediate tensors on non-final pipeline ranks.
    """
    # Non-first pipeline stages receive activations through
    # intermediate_tensors, so provided embeddings are discarded there.
    if intermediate_tensors is not None:
        inputs_embeds = None

    # Deepstack embeddings are only fetched (and afterwards cleared) on the
    # first pipeline rank, when embeddings were supplied.
    use_deepstack = inputs_embeds is not None and get_pp_group().is_first_rank
    deepstack_input_embeds = (
        self._get_deepstack_input_embeds(inputs_embeds.size(0))
        if use_deepstack
        else None
    )

    hidden_states = self.language_model.model(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
        # args for deepstack
        deepstack_input_embeds=deepstack_input_embeds,
    )

    if use_deepstack:
        self._clear_deepstack_input_embeds(inputs_embeds.size(0))

    return hidden_states

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix in multimodal models

Source code in vllm/model_executor/models/qwen3_vl.py
def get_mm_mapping(self) -> MultiModelKeys:
    """Describe which module prefixes form each part of the multimodal model.

    Returns:
        ``MultiModelKeys`` naming the language model (``language_model``),
        the connector modules (the visual merger and the deepstack merger
        list), and the vision tower (``visual.``).
    """
    connector_prefixes = ["visual.merger", "visual.deepstack_merger_list"]
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector=connector_prefixes,
        tower_model="visual.",
    )

recompute_mrope_positions

recompute_mrope_positions(
    input_ids: list[int],
    multimodal_embeddings: MultiModalEmbeddings,
    mrope_positions: LongTensor,
    num_computed_tokens: int,
) -> tuple[MultiModalEmbeddings, Tensor, int]

Update part of the input mrope positions (starting at index num_computed_tokens). The original mrope_positions are computed for the unpruned sequence and become incorrect once pruning occurs, so after pruning media tokens we must reflect this in mrope_positions before feeding them to the LLM.

Parameters:

Name Type Description Default
input_ids list[int]

(N,) All input tokens of the prompt containing entire sequence.

required
multimodal_embeddings MultiModalEmbeddings

Tuple of multimodal embeddings that fits into the prefill chunk that is being processed.

required
mrope_positions LongTensor

Existing mrope positions (3, N) for entire sequence

required
num_computed_tokens int

Number of tokens computed so far.

required

Returns:

Type Description
tuple[MultiModalEmbeddings, Tensor, int]

Tuple of (multimodal_embeddings, mrope_positions, mrope_position_delta).

Source code in vllm/model_executor/models/qwen3_vl.py
def recompute_mrope_positions(
    self,
    input_ids: list[int],
    multimodal_embeddings: MultiModalEmbeddings,
    mrope_positions: torch.LongTensor,
    num_computed_tokens: int,
) -> tuple[MultiModalEmbeddings, torch.Tensor, int]:
    """Refresh mrope positions from ``num_computed_tokens`` onward.

    The stored ``mrope_positions`` were computed for the unpruned sequence and
    become stale once media tokens are pruned. This delegates to
    ``self._recompute_mrope_positions``, supplying the image, video and
    vision-start token IDs from ``self.config``.

    Args:
        input_ids: (N,) All input tokens of the prompt, covering the entire
            sequence.
        multimodal_embeddings: Multimodal embeddings that fit into the
            prefill chunk currently being processed.
        mrope_positions: Existing (3, N) mrope positions for the entire
            sequence.
        num_computed_tokens: Number of tokens computed so far.

    Returns:
        Tuple of (multimodal_embeddings, mrope_positions,
        mrope_position_delta), as returned by the delegate.
    """
    cfg = self.config
    special_token_ids = dict(
        image_token_id=cfg.image_token_id,
        video_token_id=cfg.video_token_id,
        vision_start_token_id=cfg.vision_start_token_id,
    )
    return self._recompute_mrope_positions(
        input_ids=input_ids,
        multimodal_embeddings=multimodal_embeddings,
        mrope_positions=mrope_positions,
        num_computed_tokens=num_computed_tokens,
        **special_token_ids,
    )

Qwen3VLMultiModalProcessor

Bases: BaseMultiModalProcessor[Qwen3VLProcessingInfo]

Source code in vllm/model_executor/models/qwen3_vl.py
class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]):
    """Multimodal processor for Qwen3-VL images and videos.

    Videos are run through the HF processor one item at a time. Each video's
    placeholder in the prompt is then rewritten with the expansion built by
    ``get_video_repl``: per-frame timestamp tokens plus vision start/end
    markers around the video tokens. The same expansion is produced by
    ``_get_prompt_updates``, including the reduced token counts used when
    Efficient Video Sampling (EVS) pruning is enabled.
    """

    # Prompt text for one not-yet-expanded video.
    _VIDEO_PLACEHOLDER_TEXT = "<|vision_start|><|video_pad|><|vision_end|>"

    def _get_video_tokens_per_frame(
        self,
        num_frames: int,
        tokens_per_frame_base: int,
    ) -> tuple[list[int], bool]:
        """Return the per-frame video token counts and the `select_token_id` flag.

        Args:
            num_frames: Number of frames (temporal grid size) of the video.
            tokens_per_frame_base: Unpruned number of video tokens per frame.

        Returns:
            ``(tokens_per_frame, select_token_id)``.

            - With EVS enabled (``video_pruning_rate > 0``), all retained
              tokens (``compute_retained_tokens_count``) go to the first frame
              and every other frame gets 0. These are placeholders that only
              need the correct total. ``select_token_id`` is False so the
              whole placeholder is initially an embedding candidate.
            - Otherwise every frame gets ``tokens_per_frame_base`` tokens and
              ``select_token_id`` is True.
        """
        video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
        if video_pruning_rate is not None and video_pruning_rate > 0.0:
            num_tokens = compute_retained_tokens_count(
                tokens_per_frame=tokens_per_frame_base,
                num_frames=num_frames,
                q=video_pruning_rate,
            )
            return [num_tokens] + [0] * (num_frames - 1), False
        return [tokens_per_frame_base] * num_frames, True

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, handling every video separately.

        Each video is processed with a minimal single-video prompt to get its
        pixel values, grid and timestamps. The first remaining video
        placeholder in ``prompt`` is then replaced by the text of its
        ``get_video_repl`` expansion (EVS-adjusted when pruning is enabled).
        Images and the rewritten prompt go through one final HF call.
        """
        mm_data = dict(mm_data)
        processor = self.info.get_hf_processor(**mm_kwargs)

        # Separate video processing from image processing. Because the videos
        # are processed into several image patches
        if videos := mm_data.pop("videos", []):
            tokenizer = self.info.get_tokenizer()
            hf_config = self.info.get_hf_config()
            merge_size = processor.video_processor.merge_size

            video_grid_thw_lst = []
            pixel_values_videos_lst = []
            timestamps_per_video = []

            for item in videos:
                video_array, metadata = item

                # NOTE: @JJJYmmm new attr metadata.frames_indices indicates
                # the sampled frames indices of pre-sampled videos, which is
                # used to calculate the timestamps. Make sure that
                # do_sample_frames in mm_kwargs is false for presampled videos.

                # NOTE: a copy of is created to update do_sample_frames,
                # otherwise mm_hash for the object will be incorrect.
                video_mm_kwargs = dict(**mm_kwargs)
                if "do_sample_frames" not in video_mm_kwargs:
                    # qwen_vl_utils already has "do_sample_frames" in
                    # mm_kwargs, don't overwrite it.
                    video_mm_kwargs["do_sample_frames"] = metadata.get(
                        "do_sample_frames", False
                    )

                metadata = VideoMetadata(
                    **{k: metadata[k] for k in metadata if k != "do_sample_frames"}
                )

                # Compute timestamps here where we have access to metadata
                timestamps = self.info._get_video_second_idx(
                    metadata=metadata,
                    do_sample_frames=video_mm_kwargs["do_sample_frames"],
                    sampled_fps=video_mm_kwargs.get("fps"),
                )
                timestamps_per_video.append(timestamps)

                video_outputs = super()._call_hf_processor(
                    prompt=self._VIDEO_PLACEHOLDER_TEXT,
                    mm_data={
                        "videos": [[video_array]],
                        "video_metadata": [[metadata]],
                    },
                    mm_kwargs=video_mm_kwargs,
                    tok_kwargs=tok_kwargs,
                )

                # Get video grid info for EVS calculation.
                video_grid_thw = video_outputs["video_grid_thw"]
                num_frames = int(video_grid_thw[0, 0])
                tokens_per_frame_base = int(video_grid_thw[0, 1:].prod()) // (
                    merge_size**2
                )
                tokens_per_frame, select_token_id = (
                    self._get_video_tokens_per_frame(num_frames, tokens_per_frame_base)
                )

                # Generate the video replacement with EVS-adjusted token counts
                video_repl = Qwen3VLMultiModalProcessor.get_video_repl(
                    tokens_per_frame=tokens_per_frame,
                    timestamps=timestamps,
                    tokenizer=tokenizer,
                    vision_start_token_id=hf_config.vision_start_token_id,
                    vision_end_token_id=hf_config.vision_end_token_id,
                    video_token_id=hf_config.video_token_id,
                    select_token_id=select_token_id,
                )

                # The HF processor's own expansion (its `input_ids`) always uses
                # the unpruned token count. The prompt must instead contain the
                # same expansion that `_get_prompt_updates` expects, so it is
                # rebuilt from `video_repl` and converted back to text for the
                # final HF call.
                video_placeholder = tokenizer.decode(
                    video_repl.full, skip_special_tokens=False
                )
                prompt = prompt.replace(
                    self._VIDEO_PLACEHOLDER_TEXT,
                    video_placeholder,
                    1,
                )

                video_grid_thw_lst.append(video_grid_thw)
                pixel_values_videos_lst.append(video_outputs["pixel_values_videos"])

            video_outputs = dict(
                pixel_values_videos=torch.cat(pixel_values_videos_lst),
                video_grid_thw=torch.cat(video_grid_thw_lst),
                timestamps=timestamps_per_video,
            )
        else:
            video_outputs = dict()

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        combined_outputs = dict(
            processed_outputs,
            **video_outputs,
        )
        return BatchFeature(combined_outputs)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Build the field config using the vision config's spatial merge size."""
        return _create_qwen2vl_field_factory(
            self.info.get_hf_config().vision_config.spatial_merge_size
        )(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the prompt replacements for image and video placeholders.

        - Images expand to ``prod(image_grid_thw) // merge_size**2`` image
          tokens.
        - Videos expand via ``get_video_repl``, with the same per-frame token
          counts (EVS-aware) that ``_call_hf_processor`` uses.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        hf_config = self.info.get_hf_config()

        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        vision_end_token_id = hf_config.vision_end_token_id

        merge_length = image_processor.merge_size**2

        def get_image_replacement_qwen3vl(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            grid_thw = out_item["image_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length
            return [hf_processor.image_token_id] * num_tokens

        def get_video_replacement_qwen3vl(item_idx: int):
            out_item = out_mm_kwargs["video"][item_idx]
            grid_thw = out_item["video_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_frames = int(grid_thw[0])
            timestamps = out_item["timestamps"].data
            assert len(timestamps) == num_frames, (
                f"The timestamps length({len(timestamps)}) should be equal "
                f"video length ({num_frames})."
            )

            # Use the video processor's merge size, matching the count computed
            # in `_call_hf_processor`, so both expansions agree.
            video_merge_length = hf_processor.video_processor.merge_size**2
            tokens_per_frame_base = int(grid_thw[1:].prod()) // video_merge_length
            tokens_per_frame, select_token_id = self._get_video_tokens_per_frame(
                num_frames, tokens_per_frame_base
            )

            return Qwen3VLMultiModalProcessor.get_video_repl(
                tokens_per_frame=tokens_per_frame,
                timestamps=timestamps,
                tokenizer=tokenizer,
                vision_start_token_id=vision_start_token_id,
                vision_end_token_id=vision_end_token_id,
                video_token_id=video_token_id,
                select_token_id=select_token_id,
            )

        return [
            PromptReplacement(
                modality="image",
                target=hf_processor.image_token,
                replacement=get_image_replacement_qwen3vl,
            ),
            # NOTE: We match string on purpose since searching sequence of
            # token ids takes more time.
            PromptReplacement(
                modality="video",
                target=self._VIDEO_PLACEHOLDER_TEXT,
                replacement=get_video_replacement_qwen3vl,
            ),
        ]

    @staticmethod
    def get_video_repl(
        *,
        tokens_per_frame: list[int],
        timestamps: list[float | int],
        tokenizer: TokenizerLike,
        vision_start_token_id: int,
        vision_end_token_id: int,
        video_token_id: int,
        select_token_id: bool = False,
    ) -> PromptUpdateDetails[list[int]]:
        """Build the prompt replacement for a video in Qwen3VL format.

        The replacement structure for each frame is:
        timestamp_tokens + vision_start_token + video_tokens + vision_end_token

        Args:
            tokens_per_frame: Number of video tokens per frame (can vary per
                frame for EVS).
            timestamps: Timestamps in seconds, one per frame.
            tokenizer: Tokenizer used to encode the timestamp strings.
            vision_start_token_id: Token ID for the vision start marker.
            vision_end_token_id: Token ID for the vision end marker.
            video_token_id: Token ID for video content.
            select_token_id: If True, only ``video_token_id`` positions are
                marked for embedding replacement. If False, every token of the
                placeholder is initially marked as a candidate.

        Returns:
            PromptUpdateDetails with the full token sequence.
        """
        assert len(timestamps) == len(tokens_per_frame), (
            "timestamps and tokens_per_frame must have the same length"
        )

        # Tokenize timestamp strings independently to avoid tokenizer merging
        # tokens across boundaries.
        # TODO: switch to `_seq2tokens` which has some caching.
        timestamp_token_ids = [
            tokenizer.encode(f"<{timestamp:.1f} seconds>", add_special_tokens=False)
            for timestamp in timestamps
        ]

        # Build the full token sequence
        all_token_ids = []
        for frame_timestamp_ids, num_tokens in zip(
            timestamp_token_ids, tokens_per_frame
        ):
            # Add timestamp tokens
            all_token_ids.extend(frame_timestamp_ids)

            # Add vision tokens: vision_start + video_tokens + vision_end
            all_token_ids.append(vision_start_token_id)
            all_token_ids.extend([video_token_id] * num_tokens)
            all_token_ids.append(vision_end_token_id)

        if select_token_id:
            return PromptUpdateDetails.select_token_id(all_token_ids, video_token_id)

        # NOTE: we use `from_seq` instead of `select_token_id` because we want all
        # tokens in the placeholder to be initially marked as candidates. Then
        # in `get_input_embeddings`, we refine the mask to only replace
        # `video_token_id` / `image_token_id` positions with video/image
        # embeddings, keeping text embeddings for timestamps and structural
        # tokens.
        return PromptUpdateDetails.from_seq(all_token_ids)

get_video_repl staticmethod

get_video_repl(
    *,
    tokens_per_frame: list[int],
    timestamps: list[float | int],
    tokenizer: TokenizerLike,
    vision_start_token_id: int,
    vision_end_token_id: int,
    video_token_id: int,
    select_token_id: bool = False,
) -> PromptUpdateDetails[list[int]]

Build prompt replacement for a video in Qwen3VL format.

The replacement structure for each frame is: timestamp_tokens + vision_start_token + video_tokens + vision_end_token

Parameters:

Name Type Description Default
tokens_per_frame list[int]

Number of video tokens per frame (can vary per frame for EVS).

required
timestamps list[float | int]

List of timestamps in seconds for each frame

required
tokenizer TokenizerLike

Tokenizer to encode timestamp strings

required
vision_start_token_id int

Token ID for vision start marker

required
vision_end_token_id int

Token ID for vision end marker

required
video_token_id int

Token ID for video content

required

Returns:

Type Description
PromptUpdateDetails[list[int]]

PromptUpdateDetails with full token sequence

Source code in vllm/model_executor/models/qwen3_vl.py
@staticmethod
def get_video_repl(
    *,
    tokens_per_frame: list[int],
    timestamps: list[float | int],
    tokenizer: TokenizerLike,
    vision_start_token_id: int,
    vision_end_token_id: int,
    video_token_id: int,
    select_token_id: bool = False,
) -> PromptUpdateDetails[list[int]]:
    """Assemble the Qwen3VL prompt replacement for one video.

    Each frame contributes, in order:
    timestamp_tokens + vision_start_token + video_tokens + vision_end_token

    Args:
        tokens_per_frame: Video token count for each frame (may differ per
            frame when EVS is used).
        timestamps: Timestamp in seconds for each frame. Each is rendered as
            ``"<{t:.1f} seconds>"``.
        tokenizer: Tokenizer that encodes the timestamp strings.
        vision_start_token_id: Token ID opening each frame.
        vision_end_token_id: Token ID closing each frame.
        video_token_id: Token ID for video content.
        select_token_id: If True, only ``video_token_id`` positions are marked
            for embedding replacement. If False (default), every token of the
            placeholder starts out as a candidate.

    Returns:
        PromptUpdateDetails wrapping the complete token sequence.
    """
    assert len(timestamps) == len(tokens_per_frame), (
        "timestamps and tokens_per_frame must have the same length"
    )

    # Encode each timestamp on its own so the tokenizer cannot merge tokens
    # across neighbouring timestamp strings.
    # TODO: switch to `_seq2tokens` which has some caching.
    encoded_timestamps = [
        tokenizer.encode(f"<{ts:.1f} seconds>", add_special_tokens=False)
        for ts in timestamps
    ]

    token_ids = []
    for ts_ids, frame_len in zip(encoded_timestamps, tokens_per_frame):
        token_ids += ts_ids
        token_ids += [
            vision_start_token_id,
            *([video_token_id] * frame_len),
            vision_end_token_id,
        ]

    if not select_token_id:
        # `from_seq` marks the whole placeholder as candidates. The mask is
        # later narrowed in `get_input_embeddings` to the video/image token
        # positions, so timestamps and structural tokens keep their text
        # embeddings.
        return PromptUpdateDetails.from_seq(token_ids)
    return PromptUpdateDetails.select_token_id(token_ids, video_token_id)