vllm.distributed.device_communicators.shm_broadcast

MessageQueue

Source code in vllm/distributed/device_communicators/shm_broadcast.py
class MessageQueue:
    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: list[int] | None = None,
        # Default of 24MiB chosen to be large enough to accommodate grammar
        # bitmask tensors for large batches (1024 requests).
        max_chunk_bytes: int = 1024 * 1024 * 24,
        max_chunks: int = 10,
        connect_ip: str | None = None,
    ):
        if local_reader_ranks is None:
            local_reader_ranks = list(range(n_local_reader))
        else:
            assert len(local_reader_ranks) == n_local_reader
        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader
        self.shutting_down = False
        context = Context()

        if n_local_reader > 0:
            # for local readers, we will:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks)

            # XPUB is very similar to PUB,
            # except that it can receive subscription messages
            # to confirm the number of subscribers
            self.local_socket = context.socket(XPUB)
            # set the verbose option so that we can receive every subscription
            # message. otherwise, we will only receive the first subscription
            # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
            self.local_socket.setsockopt(XPUB_VERBOSE, True)
            local_subscribe_addr = get_open_zmq_ipc_path()
            logger.debug("Binding to %s", local_subscribe_addr)
            self.local_socket.bind(local_subscribe_addr)

            self.current_idx = 0

            # Create the notification side of the SpinCondition
            local_notify_addr = get_open_zmq_ipc_path()
            self._spin_condition = SpinCondition(
                is_reader=False, context=context, notify_address=local_notify_addr
            )
        else:
            self.buffer = None  # type: ignore
            local_subscribe_addr = None
            self.local_socket = None
            self.current_idx = -1
            local_notify_addr = None
            self._spin_condition = None  # type: ignore

        remote_addr_ipv6 = False
        if n_remote_reader > 0:
            # for remote readers, we will:
            # create a publish-subscribe socket to communicate large data
            if not connect_ip:
                connect_ip = get_ip()
            self.remote_socket = context.socket(XPUB)
            self.remote_socket.setsockopt(XPUB_VERBOSE, True)
            remote_subscribe_port = get_open_port()
            if is_valid_ipv6_address(connect_ip):
                self.remote_socket.setsockopt(IPV6, 1)
                remote_addr_ipv6 = True
                connect_ip = f"[{connect_ip}]"
            socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
            self.remote_socket.bind(socket_addr)
            remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
        else:
            remote_subscribe_addr = None
            self.remote_socket = None

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False

        self.handle = Handle(
            local_reader_ranks=local_reader_ranks,
            buffer_handle=self.buffer.handle() if self.buffer is not None else None,
            local_subscribe_addr=local_subscribe_addr,
            local_notify_addr=local_notify_addr,
            remote_subscribe_addr=remote_subscribe_addr,
            remote_addr_ipv6=remote_addr_ipv6,
        )

        logger.debug("vLLM message queue communication handle: %s", self.handle)

    def export_handle(self) -> Handle:
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        self = MessageQueue.__new__(MessageQueue)
        self.handle = handle
        self._is_writer = False

        context = Context()

        if rank in handle.local_reader_ranks:
            assert handle.buffer_handle is not None
            self.buffer = ShmRingBuffer(*handle.buffer_handle)
            self.current_idx = 0
            self.local_reader_rank = handle.local_reader_ranks.index(rank)
            self._is_local_reader = True
            self._is_remote_reader = False

            self.local_socket = context.socket(SUB)
            self.local_socket.setsockopt_string(SUBSCRIBE, "")
            socket_addr = handle.local_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            self.local_socket.connect(socket_addr)

            self.remote_socket = None
            assert isinstance(handle.local_notify_addr, str)
            self._spin_condition = SpinCondition(
                is_reader=True, context=context, notify_address=handle.local_notify_addr
            )
        else:
            self.buffer = None  # type: ignore
            self.current_idx = -1
            self.local_reader_rank = -1
            self._is_local_reader = False
            self._is_remote_reader = True

            self.local_socket = None

            self.remote_socket = context.socket(SUB)
            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
            if handle.remote_addr_ipv6:
                self.remote_socket.setsockopt(IPV6, 1)
            socket_addr = handle.remote_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            self.remote_socket.connect(socket_addr)
            self._spin_condition = None  # type: ignore

        self.shutting_down = False
        return self

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.
        """
        if self._is_writer:
            # wait for all readers to connect

            # local readers
            for i in range(self.n_local_reader):
                # wait for subscription messages from all local readers
                self.local_socket.recv()
            if self.n_local_reader > 0:
                # send a message to all local readers
                # to make sure the publish channel is working
                self.local_socket.send(b"READY")

            # remote readers
            for i in range(self.n_remote_reader):
                # wait for subscription messages from all remote readers
                self.remote_socket.recv()
            if self.n_remote_reader > 0:
                # send a message to all remote readers
                # to make sure the publish channel is working
                self.remote_socket.send(b"READY")
        elif self._is_local_reader:
            # wait for the writer to send a message
            recv = self.local_socket.recv()
            assert recv == b"READY"
        elif self._is_remote_reader:
            # wait for the writer to send a message
            recv = self.remote_socket.recv()
            assert recv == b"READY"

    def shutdown(self):
        """If this is an idle reader, wakes it up so it can clean up and shut
        down"""
        self.shutting_down = True
        if self._spin_condition is not None:
            self._spin_condition.cancel()

    @contextmanager
    def acquire_write(self, timeout: float | None = None):
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                # Memory fence ensures we see the latest read flags from readers.
                # Without this, we may read stale flags from our CPU cache and
                # spin indefinitely even though readers have completed.
                memory_fence()
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to write
                    # if this block is not ready to write,
                    # we need to wait until it is read by all readers

                    # Release the processor to other threads
                    sched_yield()

                    # if we time out, raise an exception
                    elapsed = time.monotonic() - start_time
                    if timeout is not None and elapsed > timeout:
                        raise TimeoutError

                    # if we wait for a long time, log a message
                    if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
                        logger.info(
                            LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL
                        )
                        n_warning += 1

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # Memory fence here ensures the order of the buffer and flag
                # writes. This guarantees that when `metadata_buffer[0] = 1` is
                # visible to readers, `buf` can be completely ready. Without
                # this, some CPU architectures with weak ordering may incur
                # memory inconsistency.
                memory_fence()
                # mark the block as written
                metadata_buffer[0] = 1
                # Memory fence ensures the write is visible to readers on other cores
                # before we proceed. Without this, readers may spin indefinitely
                # waiting for a write that's stuck in our CPU's store buffer.
                memory_fence()
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
                break

    class ReadTimeoutWithWarnings:
        def __init__(self, timeout: float | None, should_warn: bool) -> None:
            self.started = time.monotonic()
            self.deadline = sys.maxsize if timeout is None else self.started + timeout

            # if should_warn, we need to wake up periodically to log
            self.warning_wait_time_ms: int | None = (
                VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 if should_warn else None
            )

            self._should_warn = should_warn
            self.n_warning = 1
            self.timeout = timeout

        def timeout_ms(self) -> int | None:
            """Returns a timeout that is:
            - min(time to deadline, time to next warning) if we're logging warnings
            - time to deadline, if we're not logging warnings
            - None if the timeout is None and we're not logging warnings
            - raise TimeoutError if we are past the deadline
            """
            warning_wait_time = self.warning_wait_time_ms
            if self.timeout is None:
                return warning_wait_time

            time_left_ms = int((self.deadline - time.monotonic()) * 1000)
            if time_left_ms <= 0:
                raise TimeoutError

            if warning_wait_time and warning_wait_time < time_left_ms:
                return warning_wait_time

            return time_left_ms

        def should_warn(self) -> bool:
            """Returns true if it's time to log a warning for a timeout that is not
            indefinite"""
            if self._should_warn:
                elapsed = time.monotonic() - self.started
                if elapsed >= VLLM_RINGBUFFER_WARNING_INTERVAL * self.n_warning:
                    self.n_warning += 1
                    return True
            return False

    @contextmanager
    def acquire_read(
        self,
        timeout: float | None = None,
        indefinite: bool = False,
    ):
        assert self._is_local_reader, "Only readers can acquire read"
        read_timeout = self.ReadTimeoutWithWarnings(
            timeout=timeout, should_warn=not indefinite
        )
        with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
            while True:
                # Memory fence ensures we see the latest writes from the writer.
                # Without this, we may read stale flags from our CPU cache
                # and spin indefinitely even though writer has updated them.
                memory_fence()
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader

                    # for readers, `self.current_idx` is the next block to read
                    # if this block is not ready,
                    # we need to wait until it is written
                    self._spin_condition.wait(timeout_ms=read_timeout.timeout_ms())

                    if self.shutting_down:
                        raise RuntimeError("cancelled")

                    # if we wait for a long time, log a message
                    if read_timeout.should_warn():
                        logger.info(
                            LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL
                        )

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                # Memory fence ensures the read flag is visible to the writer.
                # Without this, writer may not see our read completion and
                # could wait indefinitely for all readers to finish.
                memory_fence()
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks

                self._spin_condition.record_read()
                break

    def enqueue(self, obj, timeout: float | None = None):
        """Write to message queue with optional timeout (in seconds)"""
        assert self._is_writer, "Only writers can enqueue"
        all_buffers: list[SizedBuffer] = [b""]
        total_bytes = 6  # 2 bytes for oob buffer count, 4 for main buffer size

        def oob_callback(buf: PickleBuffer) -> bool:
            raw_buf = buf.raw()
            if len(raw_buf) < 1024 * 1024:
                # In-line buffers smaller than 1MiB.
                return True
            all_buffers.append(raw_buf)
            nonlocal total_bytes
            total_bytes += len(raw_buf) + 4
            return False

        all_buffers[0] = pickle.dumps(
            obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback
        )
        if self.n_local_reader > 0:
            if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes:
                with self.acquire_write(timeout) as buf:
                    buf[0] = 1  # overflow
                self.local_socket.send_multipart(all_buffers, copy=False)
            else:
                # Byte 0: 0
                # Bytes 1-2: Count of buffers
                # Then each buffer follows, preceded by 4 bytes containing its length:
                # [4 byte int L][L bytes of buffer content] ...
                with self.acquire_write(timeout) as buf:
                    buf[0] = 0  # not overflow
                    offset = 3
                    buf[1:offset] = to_bytes_big(len(all_buffers), 2)  # oob buf count
                    for buffer in all_buffers:
                        buf_len = len(buffer)
                        # prepend each buffer with 4 bytes containing its size.
                        buf_offset = offset + 4
                        buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
                        buf[buf_offset : (offset := buf_offset + buf_len)] = buffer

            self._spin_condition.notify()

        if self.n_remote_reader > 0:
            self.remote_socket.send_multipart(all_buffers, copy=False)

    def dequeue(
        self,
        timeout: float | None = None,
        indefinite: bool = False,
    ):
        """Read from message queue with optional timeout (in seconds)"""
        if self._is_local_reader:
            with self.acquire_read(timeout, indefinite) as buf:
                overflow = buf[0] == 1
                if not overflow:
                    offset = 3
                    buf_count = from_bytes_big(buf[1:offset])
                    all_buffers = []
                    for i in range(buf_count):
                        buf_offset = offset + 4
                        buf_len = from_bytes_big(buf[offset:buf_offset])
                        offset = buf_offset + buf_len
                        all_buffers.append(buf[buf_offset:offset])
                    obj = pickle.loads(all_buffers[0], buffers=all_buffers[1:])
            if overflow:
                obj = MessageQueue.recv(self.local_socket, timeout)
        elif self._is_remote_reader:
            obj = MessageQueue.recv(self.remote_socket, timeout)
        else:
            raise RuntimeError("Only readers can dequeue")
        return obj

    @staticmethod
    def recv(socket: zmq.Socket, timeout: float | None) -> Any:
        timeout_ms = None if timeout is None else int(timeout * 1000)
        if not socket.poll(timeout=timeout_ms):
            raise TimeoutError
        recv, *recv_oob = socket.recv_multipart(copy=False)
        return pickle.loads(recv, buffers=recv_oob)

    def broadcast_object(self, obj=None):
        if self._is_writer:
            self.enqueue(obj)
            return obj
        return self.dequeue()

    @staticmethod
    def create_from_process_group_single_reader(
        pg: ProcessGroup,
        max_chunk_bytes,
        max_chunks,
        reader_rank: int = 0,
        blocking: bool = False,
    ) -> tuple["MessageQueue", list[Handle]]:
        """
        Creates a MessageQueue for a process group with a single reader.

        This method is designed for scenarios where only one process (the reader)
        will consume messages, and all other processes are writers. It sets up
        the shared memory buffer and communication handles accordingly, and
        gathers the handles from all processes to the reader.

        Args:
            pg (ProcessGroup): The torch distributed process group.
            max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
            max_chunks (int): Maximum number of chunks in the buffer.
            reader_rank (int, optional): The global rank that will act as the reader.
                Defaults to 0.
            blocking (bool, optional): If True, blocks until all processes are ready.
                Defaults to False.

        Returns:
            tuple[MessageQueue, list[Handle]]:
            The MessageQueue instance for the calling process,
            and a list of handles (only non-empty for the reader process).
        """
        local_size = current_platform.device_count()
        rank = dist.get_rank()
        same_node = rank // local_size == reader_rank // local_size
        buffer_io = MessageQueue(
            n_reader=1,
            n_local_reader=1 if same_node else 0,
            max_chunk_bytes=max_chunk_bytes,
            max_chunks=max_chunks,
        )
        handle = buffer_io.export_handle()
        handles = [None] * dist.get_world_size(pg) if rank == reader_rank else None
        dist.gather_object(handle, handles, dst=reader_rank, group=pg)
        if blocking:
            buffer_io.wait_until_ready()
        return buffer_io, cast(list[Handle], handles or [])

    @staticmethod
    def create_from_process_group(
        pg: ProcessGroup | StatelessProcessGroup,
        max_chunk_bytes,
        max_chunks,
        writer_rank: int = 0,
        external_writer_handle=None,
        blocking: bool = True,
    ) -> "MessageQueue":
        """
        Creates a MessageQueue for a distributed process group with one writer and
        multiple readers.

        This method is designed for scenarios where one process (the writer) sends
        messages, and all other processes (the readers) receive messages. It sets up
        the shared memory buffer and socket communication handles accordingly, and
        broadcasts the handle from the writer to all readers.

        Args:
            pg (ProcessGroup | StatelessProcessGroup): The torch distributed process
                group.
            max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
            max_chunks (int): Maximum number of chunks in the buffer.
            writer_rank (int, optional): The global rank that will act as the writer.
                Defaults to 0.
            external_writer_handle (Handle, optional): Used when there is a handle
                from an external Message Queue. If provided, use this handle to init
                PG writer message queue instead of creating a new one. Defaults to None.
            blocking (bool, optional): If True, blocks until all processes are ready.
                Defaults to True.

        Returns:
            MessageQueue: The MessageQueue instance for the calling process.

        """
        if isinstance(pg, ProcessGroup):
            group_rank = dist.get_rank(pg)
            group_world_size = dist.get_world_size(pg)
            global_ranks = dist.get_process_group_ranks(pg)
        else:
            group_rank = pg.rank
            group_world_size = pg.world_size
            global_ranks = list(range(pg.world_size))
        from vllm.distributed.parallel_state import in_the_same_node_as

        status = in_the_same_node_as(pg, source_rank=writer_rank)
        if group_rank == writer_rank:
            if external_writer_handle is not None:
                buffer_io = MessageQueue.create_from_handle(
                    external_writer_handle, group_rank
                )
            else:
                same_node_ranks = [i for i, s in enumerate(status) if s]
                n_reader = group_world_size - 1
                n_local_reader = len(same_node_ranks) - 1
                local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
                buffer_io = MessageQueue(
                    n_reader=n_reader,
                    n_local_reader=n_local_reader,
                    local_reader_ranks=local_reader_ranks,
                    max_chunk_bytes=max_chunk_bytes,
                    max_chunks=max_chunks,
                )
            handle = buffer_io.export_handle()
            if isinstance(pg, ProcessGroup):
                dist.broadcast_object_list(
                    [handle], src=global_ranks[writer_rank], group=pg
                )
            else:
                pg.broadcast_obj(handle, writer_rank)
        else:
            if isinstance(pg, ProcessGroup):
                recv = [None]
                dist.broadcast_object_list(
                    recv, src=global_ranks[writer_rank], group=pg
                )
                handle = recv[0]  # type: ignore
            else:
                handle = pg.broadcast_obj(None, writer_rank)
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        if blocking:
            buffer_io.wait_until_ready()
        return buffer_io

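Putting the pieces together, here is a minimal two-process sketch (not vLLM's own usage): the writer builds the queue, a reader reconstructs it from the pickled handle, both sides perform the collective handshake, then objects flow through the shm ring buffer.

import pickle
from multiprocessing import get_context

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

def reader_main(handle_blob: bytes) -> None:
    # Rank 0 is the single local reader configured below.
    mq = MessageQueue.create_from_handle(pickle.loads(handle_blob), rank=0)
    mq.wait_until_ready()            # collective handshake, reader side
    print(mq.dequeue())              # blocks until the writer enqueues

if __name__ == "__main__":
    # One reader on the same node, so traffic uses the shm ring buffer.
    writer = MessageQueue(n_reader=1, n_local_reader=1)
    handle_blob = pickle.dumps(writer.export_handle())
    # "spawn" keeps the writer-side sockets and buffer out of the child.
    p = get_context("spawn").Process(target=reader_main, args=(handle_blob,))
    p.start()
    writer.wait_until_ready()        # collective handshake, writer side
    writer.enqueue({"step": 1})      # small objects travel through shared memory
    p.join()
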
ReadTimeoutWithWarnings

A nested helper used by MessageQueue.acquire_read: it tracks the read deadline and decides when to wake up and emit periodic long-wait warnings. Its source appears in the MessageQueue listing above.

should_warn

should_warn() -> bool

Returns true if it's time to log a warning for a timeout that is not indefinite


timeout_ms

timeout_ms() -> int | None

Returns a timeout that is:

- min(time to deadline, time to next warning) if we're logging warnings
- time to deadline, if we're not logging warnings
- None if the timeout is None and we're not logging warnings

Raises TimeoutError if the deadline has already passed.

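For example, with timeout=120 and a 60-second warning interval (the default VLLM_RINGBUFFER_WARNING_INTERVAL, if unchanged), the first call returns 60000 ms, the time to the first warning; with 100 seconds already elapsed, it returns 20000 ms, the time left to the deadline.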

create_from_process_group (staticmethod)

create_from_process_group(
    pg: ProcessGroup | StatelessProcessGroup,
    max_chunk_bytes,
    max_chunks,
    writer_rank: int = 0,
    external_writer_handle=None,
    blocking: bool = True,
) -> MessageQueue

Creates a MessageQueue for a distributed process group with one writer and multiple readers.

This method is designed for scenarios where one process (the writer) sends messages, and all other processes (the readers) receive messages. It sets up the shared memory buffer and socket communication handles accordingly, and broadcasts the handle from the writer to all readers.

Parameters:

    pg (ProcessGroup | StatelessProcessGroup): The torch distributed process
        group. Required.
    max_chunk_bytes (int): Maximum size in bytes for each chunk in the
        buffer. Required.
    max_chunks (int): Maximum number of chunks in the buffer. Required.
    writer_rank (int): The global rank that will act as the writer.
        Defaults to 0.
    external_writer_handle (Handle): Handle from an external MessageQueue.
        If provided, it is used to initialize the writer's message queue
        instead of creating a new one. Defaults to None.
    blocking (bool): If True, blocks until all processes are ready.
        Defaults to True.

Returns:

    MessageQueue: The MessageQueue instance for the calling process.

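A hedged usage sketch: it assumes torch.distributed has already been initialized on every rank (for example via torchrun), so a world process group exists.

import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

# Every rank calls this collectively; blocking=True performs the handshake.
mq = MessageQueue.create_from_process_group(
    dist.group.WORLD,
    max_chunk_bytes=1024 * 1024,
    max_chunks=10,
    writer_rank=0,
)
if dist.get_rank() == 0:
    mq.enqueue({"cmd": "step"})      # rank 0 is the writer
else:
    msg = mq.dequeue()               # all other ranks read
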

create_from_process_group_single_reader (staticmethod)

create_from_process_group_single_reader(
    pg: ProcessGroup,
    max_chunk_bytes,
    max_chunks,
    reader_rank: int = 0,
    blocking: bool = False,
) -> tuple[MessageQueue, list[Handle]]

Creates a MessageQueue for a process group with a single reader.

This method is designed for scenarios where only one process (the reader) will consume messages, and all other processes are writers. It sets up the shared memory buffer and communication handles accordingly, and gathers the handles from all processes to the reader.

Parameters:

    pg (ProcessGroup): The torch distributed process group. Required.
    max_chunk_bytes (int): Maximum size in bytes for each chunk in the
        buffer. Required.
    max_chunks (int): Maximum number of chunks in the buffer. Required.
    reader_rank (int): The global rank that will act as the reader.
        Defaults to 0.
    blocking (bool): If True, blocks until all processes are ready.
        Defaults to False.

Returns:

    tuple[MessageQueue, list[Handle]]: The MessageQueue instance for the
        calling process, and a list of handles (only non-empty for the
        reader process).

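Conversely, this variant makes every rank a writer and gathers the handles to one reader. The sketch continues the torch.distributed assumptions above; how the reader consumes the gathered handles is this sketch's assumption, not prescribed by the source.

# Collective call; rank 0 is the reader in this sketch.
mq, handles = MessageQueue.create_from_process_group_single_reader(
    dist.group.WORLD,
    max_chunk_bytes=1024 * 1024,
    max_chunks=10,
    reader_rank=0,
)
# Every rank now holds a writer-side queue. Only the reader rank received
# the gathered handles and can open one reader-side queue per writer:
if dist.get_rank() == 0:
    readers = [MessageQueue.create_from_handle(h, rank=0) for h in handles]
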

dequeue

dequeue(
    timeout: float | None = None, indefinite: bool = False
)

Read from message queue with optional timeout (in seconds)

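A reader usually bounds its wait; a short sketch, assuming mq is a reader-side queue like the one built earlier:

try:
    obj = mq.dequeue(timeout=5.0)    # seconds; raises TimeoutError on expiry
except TimeoutError:
    obj = None                       # caller decides whether to retry
# Whether the payload arrived in-band or overflowed to the zmq socket is
# handled internally; the caller always receives the unpickled object.
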

enqueue

enqueue(obj, timeout: float | None = None)

Write to message queue with optional timeout (in seconds)

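The in-band chunk layout used on the non-overflow path (byte 0 as the overflow flag, bytes 1-2 as a big-endian buffer count, then each buffer preceded by its 4-byte big-endian length) can be exercised on its own. A minimal sketch; pack_chunk and unpack_chunk are hypothetical names, not part of the module:

def pack_chunk(buffers: list[bytes]) -> bytearray:
    chunk = bytearray([0])                        # byte 0: not overflow
    chunk += len(buffers).to_bytes(2, "big")      # bytes 1-2: buffer count
    for b in buffers:
        chunk += len(b).to_bytes(4, "big") + b    # [4-byte length][payload]
    return chunk

def unpack_chunk(chunk: bytes) -> list[bytes]:
    assert chunk[0] == 0, "overflow chunks carry no in-band payload"
    count = int.from_bytes(chunk[1:3], "big")
    offset, out = 3, []
    for _ in range(count):
        length = int.from_bytes(chunk[offset : offset + 4], "big")
        offset += 4
        out.append(bytes(chunk[offset : offset + length]))
        offset += length
    return out

payload = [b"pickled-object", b"large-oob-buffer"]
assert unpack_chunk(pack_chunk(payload)) == payload
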

shutdown

shutdown()

If this is an idle reader, wakes it up so it can clean up and shut down


wait_until_ready

wait_until_ready()

This is a collective operation. All processes (including the readers and the writer) should call this function.

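The readiness handshake is plain ZeroMQ XPUB/SUB; the same pattern can be shown standalone (the inproc address is hypothetical):

import zmq

ctx = zmq.Context()
pub = ctx.socket(zmq.XPUB)
pub.setsockopt(zmq.XPUB_VERBOSE, True)   # deliver every subscription event
pub.bind("inproc://demo-handshake")      # hypothetical in-process address

sub = ctx.socket(zmq.SUB)
sub.setsockopt_string(zmq.SUBSCRIBE, "")
sub.connect("inproc://demo-handshake")

pub.recv()                               # blocks until the subscription arrives
pub.send(b"READY")                       # prove the channel works end to end
assert sub.recv() == b"READY"
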

ShmRingBuffer

Source code in vllm/distributed/device_communicators/shm_broadcast.py
class ShmRingBuffer:
    def __init__(
        self,
        n_reader: int,
        max_chunk_bytes: int,
        max_chunks: int,
        name: str | None = None,
    ):
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer are known in advance.
        In this case, we don't need to synchronize the access to
         the buffer.

        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the written
        flag, and the rest are reader flags. The flags are set to 0 by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        The state of metadata is as follows:

        (case 1) 0???...???: the block is not written yet, cannot read, can write
        (case 2) 1000...000: the block is just written, can read, cannot write
        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write

        State transition for readers:

        When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
        Only after the caller finishes reading the block, the reader can mark the block as read.
        Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).

        State transition for writer:

        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
        to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
        can reset the reader flags to 0, and mark the block as written (from 0 to 1).
        NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.

        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.
        """  # noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (
            self.max_chunk_bytes + self.metadata_size
        ) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer
            )
            # initialize the metadata section to 0
            with self.shared_memory.buf[self.metadata_offset :] as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch(
                "multiprocessing.resource_tracker.register",
                lambda *args, **kwargs: None,
            ):
                try:
                    self.shared_memory = shared_memory.SharedMemory(name=name)
                    # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
                    # Some platforms allocate memory based on page size,
                    # so the shared memory block size may be larger or equal
                    # to the requested size. The size parameter is ignored
                    # when attaching to an existing block.
                    assert self.shared_memory.size >= self.total_bytes_of_buffer
                except FileNotFoundError:
                    # we might deserialize the object in a different node
                    # in this case, this object is not used,
                    # and we should suppress the error
                    pass

    def handle(self):
        return (
            self.n_reader,
            self.max_chunk_bytes,
            self.max_chunks,
            self.shared_memory.name,
        )

    def __reduce__(self):
        return (
            self.__class__,
            self.handle(),
        )

    def __del__(self):
        if hasattr(self, "shared_memory"):
            self.shared_memory.close()
            if self.is_creator:
                self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with self.shared_memory.buf[start:end] as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with self.shared_memory.buf[start:end] as buf:
            yield buf

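Because __reduce__ serializes just the handle, the creator can share the buffer with another process by pickling the object itself; a minimal sketch:

import pickle
from multiprocessing import get_context

from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer

def child(blob: bytes) -> None:
    ring = pickle.loads(blob)            # reopens the shared memory by name
    with ring.get_metadata(0) as meta:
        print(bytes(meta))               # b'\x00\x00': not written, not read

if __name__ == "__main__":
    ring = ShmRingBuffer(n_reader=1, max_chunk_bytes=64, max_chunks=2)
    # "spawn" so the child only sees the pickled handle, not the creator object.
    p = get_context("spawn").Process(target=child, args=(pickle.dumps(ring),))
    p.start()
    p.join()
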
__init__

__init__(
    n_reader: int,
    max_chunk_bytes: int,
    max_chunks: int,
    name: str | None = None,
)

A shared memory ring buffer implementation for broadcast communication. Essentially, it is a queue where only one will enqueue and multiple will dequeue. The max size of each item, together with the max number of items that can be stored in the buffer are known in advance. In this case, we don't need to synchronize the access to the buffer.

Buffer memory layout:

          data                                 metadata
            |                                      |
            | (current_idx)                        | (current_idx)
            v                                      v
+-------------------------------+----------------------------------------+
| chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
+-------------------------------+----------------------------------------+
| max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

metadata memory layout: each byte is a flag, the first byte is the written
flag, and the rest are reader flags. The flags are set to 0 by default.

+--------------+--------------+--------------+-----+--------------+
| written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
+--------------+--------------+--------------+-----+--------------+

The state of metadata is as follows:

(case 1) 0???...???: the block is not written yet; cannot read, can write
(case 2) 1000...000: the block is just written; can read, cannot write
(case 3) 1???...???: the block is written and read by some readers; can read if not yet read, cannot write
(case 4) 1111...111: the block is written and read by all readers; cannot read, can write

State transition for readers:

When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read. Only after the caller finishes reading the block, the reader can mark the block as read. Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).

State transition for writer:

When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer can reset the reader flags to 0, and mark the block as written (from 0 to 1). NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.

During creation, name is None and the buffer is created. We can pass the created object to other processes by pickling it. The other processes will get the name of the shared memory and open it, so that they can access the same shared memory buffer.

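To make the four metadata cases concrete, here is a toy, pure-Python walk-through of one block with two readers; it mimics the flag protocol but is independent of the real buffer:

# Toy metadata for one block with 2 readers: [written_flag, r0_flag, r1_flag]
meta = [0, 0, 0]                       # case 1: not written, writable

def writer_publish(meta: list[int]) -> None:
    meta[0] = 0                        # claim the block (back to case 1)
    # ... caller fills the data chunk here ...
    meta[1:] = [0] * (len(meta) - 1)   # reset reader flags first
    meta[0] = 1                        # then mark written: case 2

def reader_consume(meta: list[int], reader_idx: int) -> None:
    assert meta[0] == 1 and meta[reader_idx + 1] == 0
    # ... caller reads the data chunk here ...
    meta[reader_idx + 1] = 1           # case 3; case 4 once all have read

writer_publish(meta)                   # [1, 0, 0] -> case 2
reader_consume(meta, 0)                # [1, 1, 0] -> case 3
reader_consume(meta, 1)                # [1, 1, 1] -> case 4: writable again
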
Source code in vllm/distributed/device_communicators/shm_broadcast.py
def __init__(
    self,
    n_reader: int,
    max_chunk_bytes: int,
    max_chunks: int,
    name: str | None = None,
):
    """
    A shared memory ring buffer implementation for broadcast communication.
    Essentially, it is a queue where only one will `enqueue` and multiple
    will `dequeue`. The max size of each item, together with the max number
    of items that can be stored in the buffer are known in advance.
    In this case, we don't need to synchronize the access to
     the buffer.

    Buffer memory layout:
              data                                 metadata
                |                                      |
                | (current_idx)                        | (current_idx)
                v                                      v
    +-------------------------------+----------------------------------------+
    | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
    +-------------------------------+----------------------------------------+
    | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

    metadata memory layout: each byte is a flag, the first byte is the written
    flag, and the rest are reader flags. The flags are set to 0 by default.
    +--------------+--------------+--------------+-----+--------------+
    | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
    +--------------+--------------+--------------+-----+--------------+

    The state of metadata is as follows:

    (case 1) 0???...???: the block is not written yet, cannot read, can write
    (case 2) 1000...000: the block is just written, can read, cannot write
    (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
    (case 4) 1111...111: the block is written and read by all readers, cannot read, can write

    State transition for readers:

    When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
    Only after the caller finishes reading the block, the reader can mark the block as read.
    Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).

    State transition for writer:

    When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
    to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
    can reset the reader flags to 0, and mark the block as written (from 0 to 1).
    NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.

    During creation, `name` is None and the buffer is created. We can pass the
    created object to other processes by pickling it. The other processes will
    get the name of the shared memory and open it, so that they can access the
    same shared memory buffer.
    """  # noqa
    self.n_reader = n_reader
    self.metadata_size = 1 + n_reader
    self.max_chunk_bytes = max_chunk_bytes
    self.max_chunks = max_chunks
    self.total_bytes_of_buffer = (
        self.max_chunk_bytes + self.metadata_size
    ) * self.max_chunks
    self.data_offset = 0
    self.metadata_offset = self.max_chunk_bytes * self.max_chunks

    if name is None:
        # we are creating a buffer
        self.is_creator = True
        self.shared_memory = shared_memory.SharedMemory(
            create=True, size=self.total_bytes_of_buffer
        )
        # initialize the metadata section to 0
        with self.shared_memory.buf[self.metadata_offset :] as metadata_buffer:
            torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
    else:
        # we are opening an existing buffer
        self.is_creator = False
        # fix to https://stackoverflow.com/q/62748654/9191338
        # Python incorrectly tracks shared memory even if it is not
        # created by the process. The following patch is a workaround.
        with patch(
            "multiprocessing.resource_tracker.register",
            lambda *args, **kwargs: None,
        ):
            try:
                self.shared_memory = shared_memory.SharedMemory(name=name)
                # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
                # Some platforms allocate memory based on page size,
                # so the shared memory block size may be larger or equal
                # to the requested size. The size parameter is ignored
                # when attaching to an existing block.
                assert self.shared_memory.size >= self.total_bytes_of_buffer
            except FileNotFoundError:
                # we might deserialize the object in a different node
                # in this case, this object is not used,
                # and we should suppress the error
                pass
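
The layout arithmetic and flag protocol described in the docstring above can be
illustrated with a small, self-contained sketch. This is an illustrative model,
not the vLLM implementation: the helpers (block_ranges, can_write, can_read)
are hypothetical, and a plain bytearray stands in for one block's metadata.

def block_ranges(i: int, max_chunk_bytes: int, max_chunks: int, n_reader: int):
    """Byte ranges of chunk i and its metadata inside the flat buffer."""
    metadata_size = 1 + n_reader
    data_start = i * max_chunk_bytes
    metadata_start = max_chunk_bytes * max_chunks + i * metadata_size
    return (
        (data_start, data_start + max_chunk_bytes),
        (metadata_start, metadata_start + metadata_size),
    )

def can_write(metadata: bytearray) -> bool:
    # case 1 (not written yet) or case 4 (written and read by all readers)
    return metadata[0] == 0 or all(flag == 1 for flag in metadata[1:])

def can_read(metadata: bytearray, reader_rank: int) -> bool:
    # case 2 or 3: written, and this reader has not read it yet
    return metadata[0] == 1 and metadata[1 + reader_rank] == 0

# With max_chunk_bytes=64, max_chunks=10, n_reader=2, chunk 2 occupies
# bytes [128, 192) and its metadata occupies bytes [646, 649).
assert block_ranges(2, 64, 10, 2) == ((128, 192), (646, 649))

# Walk one block through the four states.
n_reader = 2
metadata = bytearray(1 + n_reader)  # 000: case 1, writable
assert can_write(metadata) and not can_read(metadata, 0)
metadata[0] = 1                     # writer publishes: 100, case 2
assert can_read(metadata, 0) and not can_write(metadata)
metadata[1] = 1                     # reader 0 marks read: 110, case 3
assert not can_read(metadata, 0) and can_read(metadata, 1)
metadata[2] = 1                     # reader 1 marks read: 111, case 4
assert can_write(metadata)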

SpinCondition

This class implements an interface similar to a threading.Condition. It allows a writer to notify readers to wake up and read from the shared memory buffer. This notification is done over a zmq socket.

For optimal performance under load we don't want the readers to need to poll the zmq socket for every read. So the wait method here will return immediately when reads are frequent, and will only enter "idle mode" and await a notification on the zmq socket after a period of inactivity. This allows the readers to spin quickly, hence "SpinCondition".

To support clean shutdown, a separate thread in the reader's process must be able to wake the reader so that it can exit. A separate cancel() method is implemented with an in-process socket to allow this interruption.

Source code in vllm/distributed/device_communicators/shm_broadcast.py
class SpinCondition:
    """
    This class implements an interface similar to a threading.Condition. It
    allows a writer to notify readers to wake up and read from the shared memory
    buffer. This notification is done over a zmq socket.

    For optimal performance under load we don't want the readers to need to poll
    the zmq socket for every read. So the `wait` method here will return
    immediately when reads are frequent, and will only enter "idle mode" and
    await a notification on the zmq socket after a period of inactivity. This
    allows the readers to spin quickly, hence "SpinCondition".

    To support clean shutdown, a separate thread in the reader's process must be
    able to wake the reader so that it can exit. A separate cancel() method is
    implemented with an in-process socket to allow this interruption.
    """

    def __init__(
        self,
        is_reader: bool,
        context: zmq.Context,
        notify_address: str,
        busy_loop_s: float = 1,
    ):
        self.is_reader = is_reader

        if is_reader:
            # Time of last shm buffer read
            self.last_read = time.monotonic()

            # Time to keep busy-looping on the shm buffer before going idle
            self.busy_loop_s = busy_loop_s

            # Readers subscribe to write notifications
            self.local_notify_socket: zmq.Socket = context.socket(SUB)
            # Set zmq.CONFLATE to only keep the last message that the socket
            # receives. This prevents us from piling up notification messages
            # under high load when we aren't polling the socket.
            self.local_notify_socket.setsockopt(zmq.CONFLATE, 1)
            # Subscribe to all messages on the socket
            self.local_notify_socket.setsockopt_string(SUBSCRIBE, "")
            self.local_notify_socket.connect(notify_address)

            # Readers require a process-local socket to poll for cancellation
            cancel_path = get_open_zmq_inproc_path()
            self.write_cancel_socket: zmq.Socket = context.socket(zmq.PAIR)
            self.write_cancel_socket.bind(cancel_path)
            self.read_cancel_socket: zmq.Socket = context.socket(zmq.PAIR)
            self.read_cancel_socket.connect(cancel_path)

            # Poller allows waiting on either `.notify()` or `.cancel()`
            self.poller = zmq.Poller()
            self.poller.register(self.read_cancel_socket, zmq.POLLIN)
            self.poller.register(self.local_notify_socket, zmq.POLLIN)
        else:
            # Writer side publishes write notifications
            self.local_notify_socket: zmq.Socket = context.socket(PUB)  # type: ignore
            # Set high water mark to 1 - we don't need to send a massive amount of
            # pings during busy operation. PUB sockets will silently drop subsequent
            # messages after the high water mark is reached.
            self.local_notify_socket.setsockopt(zmq.SNDHWM, 1)
            self.local_notify_socket.bind(notify_address)

            self.last_read = 0
            self.busy_loop_s = 0
            self.read_cancel_socket = None
            self.write_cancel_socket = None
            self.poller = None

    def record_read(self):
        self.last_read = time.monotonic()

    def cancel(self):
        # Sends cancellation ping that will cause the reader to wake up.
        # This is done from a monitor thread in the same process as the reader.
        if self.is_reader:
            logger.debug("Canceling waiting reads on SHM Buffer")
            self.write_cancel_socket.send(b"\x00")

    def wait(self, timeout_ms: int | None = None) -> None:
        """Wait for data on the shared memory buffer.

        Yields the scheduler then returns immediately if it has been less than
        self.busy_loop_s since the last read.

        Otherwise, enters idle mode and awaits a socket ping for at most
        `timeout_ms` milliseconds, or indefinitely if timeout_ms is None.
        """
        assert self.is_reader, "Only readers can wait"

        current_time = time.monotonic()
        if current_time <= self.last_read + self.busy_loop_s:
            sched_yield()
        else:
            events = dict(self.poller.poll(timeout=timeout_ms))

            if self.read_cancel_socket in events:
                logger.debug("Poller received cancel event")
            elif self.local_notify_socket in events:
                logger.debug("Poller received notify event")
                # Since zmq.CONFLATE is set, there will only be one notification
                # to read from the socket
                self.local_notify_socket.recv(flags=zmq.NOBLOCK, copy=False)
            else:
                logger.debug("Poller timed out")

    def notify(self):
        """Notifies all readers to wake up"""
        assert not self.is_reader, "Only writers can notify"
        self.local_notify_socket.send(b"\x00")
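
The following sketch shows how a writer/reader pair might use SpinCondition
over an inproc zmq transport. It is an assumed setup for illustration, not
code from the vLLM source; the inproc address is hypothetical.

import threading
import zmq

from vllm.distributed.device_communicators.shm_broadcast import SpinCondition

ctx = zmq.Context()
addr = "inproc://spin-demo"  # hypothetical notify address

# The writer binds the PUB socket, so create it before the reader connects.
writer_cond = SpinCondition(is_reader=False, context=ctx, notify_address=addr)
reader_cond = SpinCondition(is_reader=True, context=ctx, notify_address=addr)

def reader_loop():
    for _ in range(3):
        # Spins (sched_yield) while reads are recent; otherwise polls the
        # notify/cancel sockets for up to timeout_ms before returning.
        reader_cond.wait(timeout_ms=100)
        # ... check the shm ring buffer for a readable block here ...
        reader_cond.record_read()  # keeps the reader in busy-loop mode

t = threading.Thread(target=reader_loop)
t.start()
writer_cond.notify()  # wakes the reader if it has gone idle
t.join()

# From a monitor thread in the reader's process, cancel() unblocks a
# reader stuck in wait() so the process can exit cleanly.
reader_cond.cancel()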

notify

notify()

Notifies all readers to wake up

Source code in vllm/distributed/device_communicators/shm_broadcast.py
def notify(self):
    """Notifies all readers to wake up"""
    assert not self.is_reader, "Only writers can notify"
    self.local_notify_socket.send(b"\x00")

wait

wait(timeout_ms: int | None = None) -> None

Wait for data on the shared memory buffer.

Yields the scheduler then returns immediately if it has been less than self.busy_loop_s since the last read.

Otherwise, enters idle mode and awaits a socket ping for at most timeout_ms milliseconds, or indefinitely if timeout_ms is None.

Source code in vllm/distributed/device_communicators/shm_broadcast.py
def wait(self, timeout_ms: int | None = None) -> None:
    """Wait for data on the shared memory buffer.

    Yields the scheduler then returns immediately if it has been less than
    self.busy_loop_s since the last read.

    Otherwise, enters idle mode and awaits a socket ping for at most
    `timeout_ms` milliseconds, or indefinitely if timeout_ms is None.
    """
    assert self.is_reader, "Only readers can wait"

    current_time = time.monotonic()
    if current_time <= self.last_read + self.busy_loop_s:
        sched_yield()
    else:
        events = dict(self.poller.poll(timeout=timeout_ms))

        if self.read_cancel_socket in events:
            logger.debug("Poller received cancel event")
        elif self.local_notify_socket in events:
            logger.debug("Poller received notify event")
            # Since zmq.CONFLATE is set, there will only be one notification
            # to read from the socket
            self.local_notify_socket.recv(flags=zmq.NOBLOCK, copy=False)
        else:
            logger.debug("Poller timed out")

memory_fence

memory_fence()

Full memory barrier for shared memory synchronization.

Ensures all prior memory writes are visible to other processes before any subsequent reads. This is critical for lock-free producer-consumer patterns using shared memory.

Implementation acquires and immediately releases a lock. Python's threading.Lock provides sequentially consistent memory barrier semantics across all major platforms (POSIX, Windows). This is a lightweight operation (~20ns) that guarantees:

- All stores before the barrier are visible to other threads/processes
- All loads after the barrier see the latest values

Source code in vllm/distributed/device_communicators/shm_broadcast.py
def memory_fence():
    """
    Full memory barrier for shared memory synchronization.

    Ensures all prior memory writes are visible to other processes before
    any subsequent reads. This is critical for lock-free producer-consumer
    patterns using shared memory.

    Implementation acquires and immediately releases a lock. Python's
    threading.Lock provides sequentially consistent memory barrier semantics
    across all major platforms (POSIX, Windows). This is a lightweight
    operation (~20ns) that guarantees:
    - All stores before the barrier are visible to other threads/processes
    - All loads after the barrier see the latest values
    """
    # Lock acquire/release provides full memory barrier semantics.
    # Using context manager ensures lock release even on exceptions.
    with _memory_fence_lock:
        pass
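
For illustration, a producer can use this fence to order its stores so that
readers never observe the written flag before the payload. A minimal sketch
under assumed names (buf and metadata stand in for views into the shared
memory block):

# Hypothetical producer-side ordering (a sketch, not the vLLM code).
buf = bytearray(64)       # stand-in for the chunk's data region
metadata = bytearray(4)   # written_flag + 3 reader flags

buf[0:4] = b"\x2a" * 4    # 1. write the payload into the chunk
memory_fence()            # 2. make the payload stores visible first
metadata[0] = 1           # 3. publish: set the written flag last

# Reader side mirrors this: observe the flag, fence, then read the payload.
if metadata[0] == 1:
    memory_fence()
    payload = bytes(buf[0:4])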