Skip to content

vllm.beam_search

BeamSearchOutput dataclass

The output of beam search. It contains the list of the best beam search sequences. The length of the list is equal to the beam width.

Source code in vllm/beam_search.py
@dataclass
class BeamSearchOutput:
    """The output of beam search.
    It contains the list of the best beam search sequences.
    The length of the list is equal to the beam width.
    """

    # Best candidate sequences; len(sequences) equals the beam width.
    sequences: list[BeamSearchSequence]

BeamSearchSequence dataclass

A sequence for beam search. It keeps track of the tokens and the log probability of the sequence. The text field is optional and will only be filled when the sequence is about to be returned to the user.

Source code in vllm/beam_search.py
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked during beam search.

    Accumulates the generated tokens and their per-step log probabilities;
    ``text`` stays unset until the sequence is about to be returned to the
    user.
    """

    orig_prompt: TokenInputs | MultiModalInputs | EncoderDecoderInputs

    # NOTE: Tokens represents decoder tokens in the encoder / decoder case
    tokens: list[int]
    logprobs: list[dict[int, Logprob]]
    lora_request: LoRARequest | None = None
    cum_logprob: float = 0.0
    text: str | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None

    def get_prompt(self):
        """Rebuild the original prompt around this beam's current tokens."""
        source = self.orig_prompt

        # Encoder-decoder prompts need only their decoder half rebuilt.
        if source["type"] == "enc_dec":
            return self._build_encoder_decoder_inputs(source)

        # Decoder-only case: both variants share these optional fields.
        text = source.get("prompt")
        salt = source.get("cache_salt")

        if source["type"] == "token":
            return token_inputs(
                self.tokens,
                prompt=text,
                cache_salt=salt,
            )

        # Multimodal decoder-only prompt: forward the mm payload unchanged.
        return mm_inputs(
            prompt_token_ids=self.tokens,
            mm_kwargs=source["mm_kwargs"],
            mm_hashes=source["mm_hashes"],
            mm_placeholders=source["mm_placeholders"],
            prompt=text,
            cache_salt=salt,
        )

    def _build_encoder_decoder_inputs(
        self, prompt: EncoderDecoderInputs
    ) -> EncoderDecoderInputs:
        """Recreate the encoder-decoder inputs with this beam's tokens.

        Only the decoder prompt is rewritten; the encoder prompt is passed
        through untouched.

        FIXME (alex) - the encoder multimodal cache is not properly wired up
        yet, which means that currently we are running the encoder on every
        new beam because num_computed_tokens is 0 on each new request. This
        will be fixed once the cache is correctly implemented.
        """
        old_dec = prompt["decoder_prompt"]

        # Fields common to both decoder-prompt variants.
        common = dict(
            prompt=old_dec.get("prompt"),
            cache_salt=old_dec.get("cache_salt"),
        )

        rebuilt: DecoderInputs
        if old_dec["type"] == "multimodal":
            rebuilt = mm_inputs(
                self.tokens,
                mm_kwargs=old_dec["mm_kwargs"],
                mm_hashes=old_dec["mm_hashes"],
                mm_placeholders=old_dec["mm_placeholders"],
                **common,
            )
        else:
            rebuilt = token_inputs(self.tokens, **common)

        return EncoderDecoderInputs(
            type="enc_dec",
            encoder_prompt=prompt["encoder_prompt"],
            decoder_prompt=rebuilt,
        )

_build_encoder_decoder_inputs

_build_encoder_decoder_inputs(
    prompt: EncoderDecoderInputs,
) -> EncoderDecoderInputs

Rebuild the encoder-decoder inputs with the current beam search sequence's tokens.

FIXME (alex) - the encoder multimodal cache is not properly wired up yet, which means that currently we are running the encoder on every new beam because num_computed_tokens is 0 on each new request. This will be fixed once the cache is correctly implemented.

Source code in vllm/beam_search.py
def _build_encoder_decoder_inputs(
    self, prompt: EncoderDecoderInputs
) -> EncoderDecoderInputs:
    """Recreate the encoder-decoder inputs using the beam's current tokens.

    The encoder prompt is reused as-is; only the decoder prompt is rebuilt
    around ``self.tokens``.

    FIXME (alex) - the encoder multimodal cache is not properly wired up
    yet, which means that currently we are running the encoder on every
    new beam because num_computed_tokens is 0 on each new request. This
    will be fixed once the cache is correctly implemented.
    """
    old_dec = prompt["decoder_prompt"]
    text = old_dec.get("prompt")
    salt = old_dec.get("cache_salt")

    # Swap the token list, preserving every other decoder-prompt field.
    rebuilt: DecoderInputs
    if old_dec["type"] != "multimodal":
        rebuilt = token_inputs(
            self.tokens,
            prompt=text,
            cache_salt=salt,
        )
    else:
        rebuilt = mm_inputs(
            self.tokens,
            mm_kwargs=old_dec["mm_kwargs"],
            mm_hashes=old_dec["mm_hashes"],
            mm_placeholders=old_dec["mm_placeholders"],
            prompt=text,
            cache_salt=salt,
        )

    return EncoderDecoderInputs(
        type="enc_dec",
        encoder_prompt=prompt["encoder_prompt"],
        decoder_prompt=rebuilt,
    )

get_beam_search_score

get_beam_search_score(
    tokens: list[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float

Calculate the beam search score with length penalty.

Adapted from

https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938

Source code in vllm/beam_search.py
def get_beam_search_score(
    tokens: list[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    """Calculate the beam search score with length penalty.

    The effective length excludes a trailing EOS token, and the cumulative
    log probability is divided by ``length ** length_penalty``.

    Adapted from

    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938

    Args:
        tokens: Generated token ids of the sequence.
        cumulative_logprob: Sum of per-token log probabilities.
        eos_token_id: Token id that marks end of sequence.
        length_penalty: Exponent applied to the sequence length; values > 1
            favor longer sequences, values < 1 favor shorter ones.

    Returns:
        The length-normalized beam score.
    """
    seq_len = len(tokens)
    # Guard the empty case before indexing; a trailing EOS does not count
    # toward the penalized length.
    if tokens and tokens[-1] == eos_token_id:
        seq_len -= 1

    # An empty (or EOS-only) sequence would otherwise divide by
    # 0 ** length_penalty; treat its effective length as 1.
    if seq_len <= 0:
        seq_len = 1

    return cumulative_logprob / (seq_len**length_penalty)