class WorkerProc:
    """Wrapper that runs one Worker in a separate process."""

    # Status string sent over the ready pipe (see worker_main) once the
    # worker has fully initialized and exported its response-queue handle.
    READY_STR = "READY"

    # Queue on which SchedulerOutput broadcasts from the executor are
    # received; set in _init_message_queues(), reset to None by shutdown().
    rpc_broadcast_mq: MessageQueue | None
    # Queue on which this worker publishes model output back to the
    # executor/driver; set in _init_message_queues(), None after shutdown().
    worker_response_mq: MessageQueue | None
def _init_message_queues(
    self, input_shm_handle: Handle, vllm_config: VllmConfig
) -> None:
    """Create the input (broadcast) and output (response) message queues.

    Single-node DP groups (nnodes_within_dp == 1) use plain shared-memory
    queues; the multi-node path builds queues over the inner-DP distributed
    group, so this must only run after distributed groups are initialized
    (i.e. after init_device() — see the call site in __init__).

    Sets self.rpc_broadcast_mq, self.worker_response_mq and
    self.peer_response_handles.
    """
    if vllm_config.parallel_config.nnodes_within_dp == 1:
        # Initialize MessageQueue for receiving SchedulerOutput
        self.rpc_broadcast_mq = MessageQueue.create_from_handle(
            input_shm_handle, self.worker.rank
        )

        # Initializes a message queue for sending the model output
        self.worker_response_mq = MessageQueue(1, 1)
        # No cross-node peers in the single-node case.
        self.peer_response_handles = []
    else:
        # Initialize remote MessageQueue for receiving SchedulerOutput across nodes
        self.rpc_broadcast_mq = get_inner_dp_world_group().create_mq_broadcaster(
            external_writer_handle=input_shm_handle,
            # Since there is external_writer_handle from executor proc,
            # where the ready signal from actual writer is sent out of the
            # create_mq_broadcaster method and after this setup, we make it
            # non blocking. The handshake will be triggered when
            # worker.rpc_broadcast_mq.wait_until_ready() is called
            blocking=False,
        )
        # Initializes remote message queue for sending the model output to the
        # driver worker, exposing peer_response_handles for driver worker
        # that include handles for all ranks
        self.worker_response_mq, self.peer_response_handles = (
            get_inner_dp_world_group().create_single_reader_mq_broadcasters(
                reader_rank_in_group=0
            )
        )
@instrument(span_name="Worker init")
def __init__(
    self,
    vllm_config: VllmConfig,
    local_rank: int,
    rank: int,
    distributed_init_method: str,
    input_shm_handle: Handle,
    shared_worker_lock: LockType,
    is_driver_worker: bool,
):
    """Initialize the wrapped worker inside the child process.

    Constructs the WorkerWrapperBase, optionally starts the async-output
    copy thread, initializes the device and loads the model (unless this
    is a new elastic-EP scale-up worker), then creates the message queues.

    Args:
        vllm_config: Full vLLM configuration for this engine.
        local_rank: Rank of this worker on the local node.
        rank: Global rank of this worker.
        distributed_init_method: URL/method used to initialize torch
            distributed.
        input_shm_handle: Shared-memory handle for the executor's
            SchedulerOutput broadcast queue.
        shared_worker_lock: Lock shared across worker processes.
        is_driver_worker: Whether this worker drives collective output.
    """
    self.rank = rank
    wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank)
    # TODO: move `init_worker` to executor level as a collective rpc call
    # Only this worker's slot is populated; other ranks init their own.
    all_kwargs: list[dict] = [
        {} for _ in range(vllm_config.parallel_config.world_size)
    ]
    all_kwargs[local_rank] = {
        "vllm_config": vllm_config,
        "local_rank": local_rank,
        "rank": rank,
        "distributed_init_method": distributed_init_method,
        "is_driver_worker": is_driver_worker,
        "shared_worker_lock": shared_worker_lock,
    }
    wrapper.init_worker(all_kwargs)
    self.worker = wrapper

    scheduler_config = vllm_config.scheduler_config
    self.use_async_scheduling = scheduler_config.async_scheduling
    if self.use_async_scheduling:
        # Outputs are handed to a dedicated thread so the busy loop
        # is not blocked on response-queue enqueues.
        self.async_output_queue: queue.Queue = queue.Queue()
        self.async_output_copy_thread = Thread(
            target=self.async_output_busy_loop,
            daemon=True,
            name="WorkerAsyncOutputCopy",
        )
        self.async_output_copy_thread.start()

    # Initial title before parallel groups exist (falls back to "Worker").
    self.setup_proc_title_and_log_prefix(
        enable_ep=vllm_config.parallel_config.enable_expert_parallel
    )

    # Load model
    is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
    if not is_eep_new_worker:
        self.worker.init_device()
        # Update process title now that parallel groups are initialized
        self.setup_proc_title_and_log_prefix(
            enable_ep=vllm_config.parallel_config.enable_expert_parallel
        )
        self.worker.load_model()

    # Set block size based on the attention backends
    current_platform.update_block_size_for_backend(vllm_config)

    # Initialize message queues after init_device() since multi-node setups
    # (nnodes_within_dp > 1) require distributed groups to be initialized
    self._init_message_queues(input_shm_handle, vllm_config)

    # Enable environment variable cache (e.g. assume no more
    # environment variable overrides after this point)
    enable_envs_cache()
@staticmethod
def make_worker_process(
    vllm_config: VllmConfig,
    local_rank: int,
    rank: int,
    distributed_init_method: str,
    input_shm_handle,  # Receive SchedulerOutput
    shared_worker_lock: LockType,
    is_driver_worker: bool,
    inherited_fds: list[int] | None = None,
) -> UnreadyWorkerProcHandle:
    """Spawn a WorkerProc child process and return its unready handle.

    Two one-way pipes are created: a ready pipe (child -> parent) used to
    report initialization status, and a death pipe (parent -> child) whose
    EOF tells the child the parent has exited. The returned handle keeps
    the parent-side ends of both pipes.
    """
    mp_ctx = get_mp_context()

    # Ready pipe to communicate readiness from child to parent.
    ready_recv, ready_send = mp_ctx.Pipe(duplex=False)
    # Death pipe to let the child detect parent process exit.
    death_recv, death_send = mp_ctx.Pipe(duplex=False)

    if inherited_fds is not None:
        # Also have the child close the parent-side ends of this
        # worker's own pipes.
        inherited_fds = [
            *inherited_fds,
            ready_recv.fileno(),
            death_send.fileno(),
        ]

    worker_kwargs = {
        "vllm_config": vllm_config,
        "local_rank": local_rank,
        "rank": rank,
        "distributed_init_method": distributed_init_method,
        "input_shm_handle": input_shm_handle,
        "ready_pipe": ready_send,
        "death_pipe": death_recv,
        "shared_worker_lock": shared_worker_lock,
        "is_driver_worker": is_driver_worker,
        # Have the worker close parent end of this worker's pipes too
        "inherited_fds": inherited_fds if inherited_fds is not None else [],
    }

    # Run the worker busy loop in a background process.
    child = mp_ctx.Process(
        target=WorkerProc.worker_main,
        kwargs=worker_kwargs,
        name=f"VllmWorker-{rank}",
        daemon=True,
    )
    child.start()

    # Parent keeps only its own ends: close the child-side ones here.
    ready_send.close()
    death_recv.close()
    # death_send stays open in the parent - when the parent exits, the
    # child's death_recv end observes EOF.
    return UnreadyWorkerProcHandle(child, rank, ready_recv, death_send)
@staticmethod
def wait_for_response_handle_ready(
    handles: dict[str, Any], proc_handle: UnreadyWorkerProcHandle
) -> WorkerProcHandle:
    """Build a ready WorkerProcHandle from the handles a child reported.

    Attaches a local reader to the worker's response queue when this
    process is listed as a local reader, and materializes remote readers
    for any peer response handles that expose a remote subscribe address.
    """
    local_handle = handles["handle"]

    response_mq: MessageQueue | None = None
    if local_handle.local_reader_ranks:
        response_mq = MessageQueue.create_from_handle(local_handle, 0)

    peer_mqs: list[MessageQueue | None] = []
    for peer_handle in handles["peer_response_handles"]:
        if peer_handle.remote_subscribe_addr is None:
            # No remote endpoint exposed for this rank.
            peer_mqs.append(None)
        else:
            peer_mqs.append(MessageQueue.create_from_handle(peer_handle, -1))

    return WorkerProcHandle.from_unready_handle(
        proc_handle,
        response_mq,
        peer_worker_response_mqs=peer_mqs,
    )
@staticmethod
def wait_for_ready(
    unready_proc_handles: list[UnreadyWorkerProcHandle],
) -> list[WorkerProcHandle]:
    """Block until every worker process reports readiness over its pipe.

    Args:
        unready_proc_handles: Handles returned by make_worker_process().

    Returns:
        Ready handles, slotted by worker rank.

    Raises:
        Exception: if any worker reports a non-ready status or its ready
            pipe closes before a response arrives.
    """
    e = Exception(
        "WorkerProc initialization failed due to an exception in a "
        "background process. See stack trace for root cause."
    )

    pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
    ready_proc_handles: list[WorkerProcHandle | None] = [None] * len(
        unready_proc_handles
    )
    while pipes:
        ready = multiprocessing.connection.wait(pipes.keys())
        for pipe in ready:
            assert isinstance(pipe, Connection)
            try:
                # Wait until the WorkerProc is ready.
                unready_proc_handle = pipes.pop(pipe)
                response: dict[str, Any] = pipe.recv()
                # Compare against the shared class constant so sender
                # (worker_main) and receiver cannot drift apart.
                if response["status"] != WorkerProc.READY_STR:
                    raise e
                # Slot by rank so the result is rank-ordered regardless
                # of the order in which workers become ready.
                idx = unready_proc_handle.rank % len(ready_proc_handles)
                ready_proc_handles[idx] = WorkerProc.wait_for_response_handle_ready(
                    response, unready_proc_handle
                )
            except EOFError:
                # `from None` already suppresses the EOFError context
                # (PEP 409), so no explicit __suppress_context__ needed.
                raise e from None
            finally:
                # Close connection.
                pipe.close()

    return cast(list[WorkerProcHandle], ready_proc_handles)
def shutdown(self):
    """Tear down both message queues, the wrapped worker, and the
    distributed state (model-parallel groups and process group)."""
    # Shut the queues down first so peers unblock, then drop references.
    for mq in (self.rpc_broadcast_mq, self.worker_response_mq):
        if mq is not None:
            mq.shutdown()
    self.worker.shutdown()
    self.rpc_broadcast_mq = None
    self.worker_response_mq = None

    destroy_model_parallel()
    destroy_distributed_environment()
def monitor_death_pipe(self, death_pipe, shutdown_requested: threading.Event):
    """Start a daemon thread that watches the parent-liveness pipe.

    When the parent process exits, its end of the pipe closes, recv()
    raises EOFError, and the thread shuts down both message queues so
    the busy loop unblocks. No-op when death_pipe is None.
    """
    if death_pipe is None:
        return

    # The closure receives the queues as an argument (not via self) to
    # avoid gc issues from a hidden reference to this WorkerProc.
    def watch_parent(mqs: list[MessageQueue]):
        try:
            # Blocks until the parent exits and the pipe closes.
            death_pipe.recv()
        except EOFError:
            logger.info_once("Parent process exited, terminating worker queues")
            shutdown_requested.set()
            for mq in mqs:
                if mq is not None:
                    mq.shutdown()
        except Exception as e:
            logger.warning("Death monitoring error: %s", e)

    monitor = Thread(
        target=watch_parent,
        args=([self.rpc_broadcast_mq, self.worker_response_mq],),
        daemon=True,
        name="DeathPipeMonitor",
    )
    monitor.start()
@staticmethod
def worker_main(*args, **kwargs):
    """Worker initialization and execution loops.
    This runs in a background process.

    Entry point of the child process: installs signal handlers, closes
    inherited fds, constructs the WorkerProc, performs the ready/queue
    handshake with the parent, then runs the busy loop until shutdown.
    """
    # Signal handler used for graceful termination.
    # SystemExit exception is only raised once to allow this and worker
    # processes to terminate without error
    shutdown_requested = threading.Event()

    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        if not shutdown_requested.is_set():
            shutdown_requested.set()
            logger.debug(
                "WorkerProc handling signal %d, raising SystemExit", signum
            )
            raise SystemExit()

    # Either SIGTERM or SIGINT will terminate the worker
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    worker = None
    # Pipes handed over from make_worker_process().
    ready_writer = kwargs.pop("ready_pipe")
    death_pipe = kwargs.pop("death_pipe", None)

    # Close inherited pipes from parent (incl. other worker pipes)
    # Explicitly passing in existing pipes and closing them makes the pipe
    # behave when using fork. Otherwise, a hidden reference to the pipes
    # exist in the child process and prevents EOF closure.
    for fd in kwargs.pop("inherited_fds", []):
        try:
            os.close(fd)
        except Exception as e:
            logger.warning("Error closing inherited connection: %s: %s", type(e), e)

    try:
        # Initialize tracer
        rank = kwargs.get("rank", 0)
        maybe_init_worker_tracer(
            instrumenting_module_name="vllm.worker",
            process_kind="worker",
            process_name=f"Worker_{rank}",
        )
        worker = WorkerProc(*args, **kwargs)
        assert worker.worker_response_mq is not None
        worker.monitor_death_pipe(death_pipe, shutdown_requested)

        # Send READY once we know everything is loaded
        ready_writer.send(
            {
                "status": WorkerProc.READY_STR,
                "handle": worker.worker_response_mq.export_handle(),
                "peer_response_handles": worker.peer_response_handles,
            }
        )

        # Ensure message queues are ready. Will deadlock if re-ordered.
        # Must be kept consistent with the Executor
        if worker.rpc_broadcast_mq is not None:
            worker.rpc_broadcast_mq.wait_until_ready()
        worker.worker_response_mq.wait_until_ready()
        # ready_writer = None marks the handshake as complete; the except
        # block below uses it to distinguish startup vs runtime failures.
        ready_writer.close()
        ready_writer = None

        worker.worker_busy_loop()
    except Exception:
        # NOTE: if an Exception arises in busy_loop, we send
        # a FAILURE message over the MQ RPC to notify the Executor,
        # which triggers system shutdown.
        # TODO(rob): handle case where the MQ itself breaks.
        if ready_writer is not None:
            logger.exception("WorkerProc failed to start.")
        elif shutdown_requested.is_set():
            logger.info("WorkerProc shutting down.")
        else:
            logger.exception("WorkerProc failed.")

        # The parent sends a SIGTERM to all worker processes if
        # any worker dies. Set this value so we don't re-throw
        # SystemExit() to avoid zmq exceptions in __del__.
        shutdown_requested.set()

    except SystemExit as e:
        # SystemExit is raised on SIGTERM or SIGKILL, which usually indicates that
        # the graceful shutdown process did not succeed
        logger.warning("WorkerProc was terminated")
        # SystemExit must never be ignored
        raise e

    finally:
        if ready_writer is not None:
            ready_writer.close()
        if death_pipe is not None:
            death_pipe.close()
        # Clean up once worker exits busy loop
        if worker is not None:
            worker.shutdown()
class ResponseStatus(Enum):
    # Outcome tag: first element of every tuple sent on worker_response_mq
    # (see enqueue_output).
    SUCCESS = auto()
    FAILURE = auto()
def enqueue_output(self, output: Any):
    """Prepares output from the worker and enqueues it to the
    worker_response_mq. If the output is an Exception, it is
    converted to a FAILURE response.
    """
    # Resolve async wrappers to their concrete output first.
    if isinstance(output, AsyncModelRunnerOutput):
        output = output.get_output()

    # Exceptions may not be serializable, so ship them as strings.
    payload = (
        (WorkerProc.ResponseStatus.FAILURE, str(output))
        if isinstance(output, Exception)
        else (WorkerProc.ResponseStatus.SUCCESS, output)
    )

    response_mq = self.worker_response_mq
    if response_mq is not None:
        response_mq.enqueue(payload)
def handle_output(self, output: Any):
    """Handles output from the worker. If async scheduling is enabled,
    it is passed to the async_output_busy_loop thread. Otherwise, it is
    enqueued directly to the worker_response_mq.
    """
    if not self.use_async_scheduling:
        self.enqueue_output(output)
        return
    # Hand off to the copy thread so this caller is not blocked.
    self.async_output_queue.put(output)
def async_output_busy_loop(self):
    """Entrypoint for the thread which handles outputs asynchronously:
    drains async_output_queue forever, forwarding each item."""
    pending = self.async_output_queue
    while True:
        self.enqueue_output(pending.get())
def worker_busy_loop(self):
    """Main busy loop for Multiprocessing Workers.

    Dequeues (method, args, kwargs, output_rank) tuples from the
    broadcast queue, invokes the method on the wrapped worker, and sends
    the result back when this rank is the designated output rank (or
    output_rank is None, meaning all ranks respond).
    """
    assert self.rpc_broadcast_mq is not None
    while True:
        method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
            indefinite=True
        )

        try:
            if isinstance(method, str):
                # Named method on the wrapped worker.
                func = getattr(self.worker, method)
            elif isinstance(method, bytes):
                # Cloudpickled callable; bind the worker as first arg.
                func = partial(cloudpickle.loads(method), self.worker)
            output = func(*args, **kwargs)
        except Exception as e:
            # Notes have been introduced in python 3.11
            if hasattr(e, "add_note"):
                e.add_note(traceback.format_exc())
            logger.exception("WorkerProc hit an exception.")
            # exception might not be serializable, so we convert it to
            # string, only for logging purpose.
            if output_rank is None or self.rank == output_rank:
                self.handle_output(e)
            continue

        if output_rank is None or self.rank == output_rank:
            self.handle_output(output)
@staticmethod
def setup_proc_title_and_log_prefix(enable_ep: bool) -> None:
    """Set process title and log prefix from the parallel-group ranks.

    Before model-parallel groups exist, falls back to the plain "Worker"
    name. Otherwise appends a suffix per parallel dimension whose world
    size exceeds 1, plus an EP suffix when expert parallelism is enabled.
    """
    # Check if parallel groups are initialized first
    if not model_parallel_is_initialized():
        # Parallel groups not yet initialized, use default process name
        set_process_title(name="Worker")
        decorate_logs("Worker")
        return

    # (label, group) pairs in the order the suffixes should appear.
    dimensions = [
        ("DP", get_dp_group()),
        ("PP", get_pp_group()),
        ("PCP", get_pcp_group()),
        ("TP", get_tp_group()),
        ("DCP", get_dcp_group()),
    ]

    process_name = "Worker"
    for label, group in dimensions:
        if group.world_size > 1:
            process_name += f"_{label}{group.rank_in_group}"
    if enable_ep:
        process_name += f"_EP{get_ep_group().rank_in_group}"

    set_process_title(name=process_name)
    decorate_logs(process_name)