Source code for jaxgarden.models.generation_utils

from functools import partial
from typing import TYPE_CHECKING, Any, TypeVar, cast

import jax
import jax.numpy as jnp
from flax import nnx

from jaxgarden.models.base import BaseModel

# TYPE_CHECKING to avoid circular imports at runtime
if TYPE_CHECKING:
    # Define a TypeVar for documenting the intended usage context,
    # i.e., classes using this mixin should inherit from BaseModel.
    T_BaseModel = TypeVar("T_BaseModel", bound="BaseModel")
    # Define a base class for type checking
    _Base = BaseModel
else:
    # Runtime doesn't strictly need the bound, but define the TypeVar anyway
    T_BaseModel = TypeVar("T_BaseModel")
    # Use object at runtime
    _Base = object

# Constants for numerical stability
EPSILON = 1e-9


def temperature_scale(logits: jnp.ndarray, temperature: float) -> jnp.ndarray:
    """Scales logits by temperature.

    Args:
        logits: Logits to scale. Shape: (..., vocab_size)
        temperature: Temperature value. Higher values make the distribution flatter (more random),
                     lower values make it peakier (more deterministic). Must be positive.

    Returns:
        Scaled logits.
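
    Example (illustrative; a temperature of 2.0 simply halves each logit)::

        temperature_scale(jnp.array([1.0, 2.0]), temperature=2.0)
        # -> Array([0.5, 1.0], dtype=float32)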
    """
    if temperature <= 0:
        raise ValueError(f"Temperature must be positive, got {temperature}")
    # Temperature is already validated to be positive above; clamping to a small
    # epsilon is a purely defensive guard against division by zero
    safe_temperature = max(temperature, EPSILON)
    return logits / safe_temperature


def top_k_logits(logits: jnp.ndarray, k: int) -> jnp.ndarray:
    """Masks logits outside the top k values.

    Sets logits not in the top k to negative infinity.

    Args:
        logits: Logits to filter. Shape: (..., vocab_size)
        k: Number of top logits to keep.

    Returns:
        Filtered logits.
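
    Example (illustrative; only the two largest logits survive)::

        top_k_logits(jnp.array([1.0, 3.0, 2.0]), k=2)
        # -> Array([-inf, 3.0, 2.0], dtype=float32)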
    """
    if k <= 0:
        # If k is 0 or negative, mask all logits
        return jnp.full_like(logits, -jnp.inf)

    # Ensure k is not larger than the vocabulary size
    k = min(k, logits.shape[-1])

    # Get top-k values
    top_k_values = jax.lax.top_k(logits, k=k)[0]
    kth_value = top_k_values[..., -1:]

    # Create a mask where logits >= kth_value are True
    mask = logits >= kth_value

    # Set logits below the threshold to -inf
    return jnp.where(mask, logits, -jnp.inf)


def top_p_logits(logits: jnp.ndarray, p: float) -> jnp.ndarray:
    """Filter logits using nucleus (top-p) sampling.

    Args:
        logits: Shape (..., vocab_size)
        p: Probability threshold (0 < p <= 1)

    Returns:
        Filtered logits with -inf for tokens outside the top-p nucleus
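
    Example (illustrative; probabilities chosen to sum to 1 for clarity)::

        logits = jnp.log(jnp.array([0.5, 0.3, 0.2]))
        top_p_logits(logits, p=0.9)
        # keeps the 0.5 and 0.3 tokens (cumulative mass 0.8 <= 0.9);
        # the 0.2 token is set to -inf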
    """
    if not 0 < p <= 1.0:
        raise ValueError(f"p must be in (0, 1], got {p}")
    if p == 1.0:
        return logits

    # Convert to probabilities
    probs = nnx.softmax(logits, axis=-1)

    # Sort probabilities in descending order
    sorted_probs = jnp.sort(probs, axis=-1)[..., ::-1]

    # Calculate cumulative probabilities and create mask
    cumulative_probs = jnp.cumsum(sorted_probs, axis=-1)
    sorted_mask = cumulative_probs <= p

    # Always include at least the top token
    sorted_mask = sorted_mask.at[..., 0].set(True)

    # Find the minimum probability within the nucleus
    threshold = jnp.min(
        jnp.where(sorted_mask, sorted_probs, jnp.ones_like(sorted_probs)), axis=-1, keepdims=True
    )

    # Apply threshold to original probabilities
    # Keep tokens whose probability is >= threshold
    mask = probs >= threshold

    # Apply mask to logits
    return jnp.where(mask, logits, -jnp.inf)


def min_p_logits(logits: jnp.ndarray, p: float) -> jnp.ndarray:
    """Masks logits below a probability threshold derived from the max probability (min_p sampling).

    Filters out tokens with probability less than p * max_probability.

    Args:
        logits: Logits to filter. Shape: (..., vocab_size)
        p: Probability threshold factor (0 < p <= 1).

    Returns:
        Filtered logits.
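
    Example (illustrative)::

        logits = jnp.log(jnp.array([0.5, 0.3, 0.2]))
        min_p_logits(logits, p=0.5)
        # threshold = 0.5 * 0.5 = 0.25: the 0.5 and 0.3 tokens are kept,
        # the 0.2 token is set to -inf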
    """
    if not 0 < p <= 1.0:
        raise ValueError(f"p must be in (0, 1], got {p}")

    probs = nnx.softmax(logits, axis=-1)
    max_prob = jnp.max(probs, axis=-1, keepdims=True)
    threshold = max_prob * p

    # Identify indices corresponding to max probability
    max_prob_indices = probs >= (max_prob - EPSILON)

    if p == 1.0:
        # When p=1.0, keep just the max probability tokens
        mask = ~max_prob_indices
    else:
        # Otherwise, keep max prob tokens and tokens above the threshold
        mask_below_threshold = probs < threshold
        mask = jnp.where(max_prob_indices, False, mask_below_threshold)

    # Apply the mask to the original logits
    return jnp.where(mask, -jnp.inf, logits)


def sample_logits(
    logits: jnp.ndarray,
    rng_key: jax.Array,
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
    min_p: float | None = None,
    do_sample: bool = True,
) -> jnp.ndarray:
    """Samples a token index from logits using specified filtering and temperature.

    Applies filtering methods (top_k, top_p, min_p) and temperature scaling,
    then samples from the resulting distribution or takes the argmax.

    Args:
        logits: Raw logits from the model. Shape: (..., vocab_size)
        rng_key: JAX PRNG key for sampling.
        temperature: Temperature scaling factor.
        top_k: If set, keep only top k logits.
        top_p: If set, keep smallest set of logits whose cumulative probability exceeds p.
        min_p: If set, keep logits with probability >= max_prob * p.
        do_sample: If True, sample using the categorical distribution.
            If False, take the argmax (greedy decoding).

    Returns:
        Sampled token indices. Shape: (...)
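
    Example (illustrative; the stochastic result depends on the key)::

        key = jax.random.PRNGKey(0)
        logits = jnp.array([[0.1, 2.0, 0.3]])
        sample_logits(logits, key, do_sample=False)  # greedy -> Array([1], dtype=int32)
        sample_logits(logits, key, temperature=0.7, top_k=2)  # samples from the top 2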
    """
    if not do_sample:
        # Greedy decoding
        return jnp.argmax(logits, axis=-1)

    # 1. Apply temperature scaling
    if temperature != 1.0 and temperature > 0:
        scaled_logits = temperature_scale(logits, temperature)
    else:
        scaled_logits = logits

    # Store the scaled logits as the potential fallback
    logits_for_fallback = scaled_logits

    # 2. Apply filtering
    filtered_logits = scaled_logits
    # Apply filtering in a specific order (min_p -> top_k -> top_p is one common order)
    # Note: The order can matter. Min_p focuses on dynamic range,
    # while top_k/top_p on absolute ranks/mass.
    if min_p is not None and 0 < min_p < 1.0:
        filtered_logits = min_p_logits(filtered_logits, min_p)
    if top_k is not None and top_k > 0:
        filtered_logits = top_k_logits(filtered_logits, top_k)
    if top_p is not None and 0 < top_p < 1.0:  # top_p=1 means no filtering
        filtered_logits = top_p_logits(filtered_logits, top_p)

    # 3. Sample or take argmax, handling the edge case for sampling

    all_filtered_infinite = jnp.all(filtered_logits == -jnp.inf, axis=-1, keepdims=True)

    # Determine the logits to actually sample from:
    # Use the fallback (scaled, unfiltered) if all filtered are -inf
    final_logits_for_sampling = jnp.where(
        all_filtered_infinite,
        logits_for_fallback,  # Fallback to pre-filter (but post-temp) logits
        filtered_logits,  # Otherwise, use the filtered logits
    )

    # Sample using the chosen logits
    sampled_indices = jax.random.categorical(rng_key, final_logits_for_sampling, axis=-1)

    return sampled_indices


def create_causal_mask(seq_len: int) -> jnp.ndarray:
    """Creates a causal attention mask for a given sequence length.

    Args:
        seq_len: The length of the sequence.

    Returns:
        A causal attention mask of shape [seq_len, seq_len].
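
    Example::

        create_causal_mask(3)
        # [[1., 0., 0.],
        #  [1., 1., 0.],
        #  [1., 1., 1.]]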
    """
    mask = jnp.tril(jnp.ones((seq_len, seq_len)))
    return mask


class GenerationMixin(_Base):
    """Mixin that adds text generation capabilities for causal language models,
    including sampling with temperature, top-k, top-p, and min-p filtering."""

    def _generate_scan_logic(
        self: "GenerationMixin",  # self is needed for the model call
        initial_input_ids: jnp.ndarray,
        initial_finished_sequences: jnp.ndarray,
        initial_rng: jax.Array,
        initial_seq_len: int,
        # --- Static arguments (used by JIT, must be passed regardless) ---
        max_length: int,
        temperature: float,
        top_k: int | None,
        top_p: float | None,
        min_p: float | None,
        do_sample: bool,
        pad_token_id: int,
        eos_token_id: int | None,
    ) -> jnp.ndarray:
        """The core autoregressive generation logic using lax.scan.

        This function itself is NOT jitted here."""
        batch_size = initial_input_ids.shape[0]
        output_ids = jnp.full(
            (batch_size, max_length), pad_token_id, dtype=initial_input_ids.dtype
        )
        output_ids = output_ids.at[:, :initial_seq_len].set(initial_input_ids)

        def scan_step(carry: dict, _: Any) -> tuple[dict, None]:
            current_output_ids = carry["output_ids"]
            current_length = carry["current_length"]
            step_rng = carry["rng"]
            current_finished = carry["finished"]

            next_rng = step_rng
            sampling_rng = step_rng
            if do_sample:  # Relies on do_sample being static *when jitted*
                sampling_rng, next_rng = jax.random.split(step_rng)

            # This mask tells the model which tokens are valid.
            attention_mask = (jnp.arange(max_length) < current_length).astype(jnp.int32)[None, :]

            # Call the model, passing the attention mask. The model is assumed to
            # return the logits array directly and to run deterministically
            # (no dropout) during generation.
            logits = self(  # type: ignore[operator]
                input_ids=current_output_ids,
                attention_mask=attention_mask,
                deterministic=True,
            )

            # Get logits for the *next* token prediction (at index current_length - 1)
            next_token_logits = jax.lax.dynamic_slice_in_dim(
                logits, current_length - 1, 1, axis=1
            ).squeeze(axis=1)

            # Sample the next token
            next_token = sample_logits(
                logits=next_token_logits,
                rng_key=sampling_rng,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
                do_sample=do_sample,
            )
            next_token = next_token.astype(current_output_ids.dtype)

            # Determine the token to actually write (handling EOS and padding)
            output_token = next_token  # Default
            next_finished = current_finished
            if eos_token_id is not None:  # Relies on eos_token_id being static *when jitted*
                newly_finished = (next_token == eos_token_id) & (~current_finished)
                next_finished = current_finished | newly_finished
                # If already finished, write pad_token_id; otherwise write the sampled token
                output_token = jnp.where(current_finished, pad_token_id, next_token)

            # Update the output sequence
            updated_output_ids = jax.lax.dynamic_update_slice_in_dim(
                current_output_ids, output_token[:, None], current_length, axis=1
            )

            # Prepare carry for the next step
            next_carry = {
                "output_ids": updated_output_ids,
                "current_length": current_length + 1,
                "rng": next_rng,
                "finished": next_finished,
            }
            return next_carry, None

        initial_carry = {
            "output_ids": output_ids,
            "current_length": jnp.array(initial_seq_len),
            "rng": initial_rng,
            "finished": initial_finished_sequences,
        }

        num_steps_to_generate = max_length - initial_seq_len

        # Run the scan only if needed
        if num_steps_to_generate > 0:
            final_carry, _ = jax.lax.scan(
                scan_step, initial_carry, None, length=num_steps_to_generate
            )
            final_output_ids = final_carry["output_ids"]
        else:
            # No steps needed (initial_seq_len == max_length): return the initial output_ids
            final_output_ids = output_ids

        return cast(jnp.ndarray, final_output_ids)

    # Define the compiled version of the scan logic.
    # This uses partial to pre-apply jax.jit with static arguments.
    # Note: compilation happens when this method definition is executed.
    _generate_compiled = partial(
        jax.jit,
        # Arguments that control the computation graph structure must be static
        static_argnames=(
            "self",  # Need self for the model call inside scan_step
            "max_length",
            "temperature",
            "top_k",
            "top_p",
            "min_p",
            "do_sample",
            "pad_token_id",
            "eos_token_id",
            "initial_seq_len",
        ),
    )(_generate_scan_logic)  # Apply JIT to the core logic function

    def generate(
        self: "GenerationMixin",
        input_ids: jnp.ndarray,
        attention_mask: jnp.ndarray | None = None,
        max_length: int = 20,
        temperature: float = 1.0,
        top_k: int | None = None,
        top_p: float | None = None,
        min_p: float | None = None,
        do_sample: bool = True,
        pad_token_id: int | None = None,
        eos_token_id: int | None = None,
        rng: jax.Array | None = None,
        use_jit: bool = False,
    ) -> jnp.ndarray:
        """Generate tokens autoregressively with various sampling methods.

        Args:
            input_ids: Initial token IDs of shape [batch_size, seq_len].
            attention_mask: Currently unused; the generation loop builds its own mask.
            max_length: Maximum length for generated sequences.
            temperature: Temperature for sampling.
            top_k: If specified, only sample from the top-k logits.
            top_p: If specified, only sample from the smallest set of logits whose
                cumulative probability exceeds p.
            min_p: If specified, only consider logits with prob >= min_p * prob(max logit).
            do_sample: If True, use sampling; otherwise use greedy decoding (argmax).
            pad_token_id: Token ID to use for padding.
            eos_token_id: Token ID that signals the end of generation.
            rng: Optional PRNG key for sampling.
            use_jit: If True, use the jax.jit-compiled generation loop.

        Returns:
            Generated token IDs of shape [batch_size, max_length].
        """
        if not isinstance(max_length, int) or max_length <= 0:
            raise ValueError(f"max_length must be a positive integer, got {max_length}")
        if not isinstance(temperature, (float, int)) or temperature <= 0:  # noqa: UP038
            raise ValueError(f"temperature must be positive, got {temperature}")
        if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
            raise ValueError(f"top_k must be a positive integer, got {top_k}")
        if top_p is not None and (not isinstance(top_p, float) or not 0 < top_p <= 1.0):
            raise ValueError(f"top_p must be in (0, 1], got {top_p}")
        _top_p = top_p if top_p != 1.0 else None  # top_p=1.0 means no filtering
        if min_p is not None and (not isinstance(min_p, float) or not 0 < min_p <= 1.0):
            raise ValueError(f"min_p must be in (0, 1], got {min_p}")
        if input_ids.ndim != 2:
            raise ValueError(
                f"input_ids must be 2D [batch_size, seq_len], got shape {input_ids.shape}"
            )

        # Handle the RNG key
        _rng = rng
        if do_sample:
            if _rng is None:
                print("Warning: No RNG key provided for sampling, using default key 0.")
                _rng = jax.random.PRNGKey(0)  # Use seed 0
            # Ensure rng is a JAX key
            if isinstance(_rng, int):
                _rng = jax.random.PRNGKey(_rng)
            elif not isinstance(_rng, jax.Array):
                raise ValueError(f"Invalid rng provided: {_rng}. Expected JAX PRNGKey or seed.")
        elif _rng is None:
            # Provide a dummy key if not sampling and None was passed
            _rng = jax.random.PRNGKey(0)
        elif isinstance(_rng, int):
            # Ensure a key even if not sampling but a seed was provided
            _rng = jax.random.PRNGKey(_rng)

        # Resolve pad_token_id
        _pad_token_id = pad_token_id
        if _pad_token_id is None:
            # Safely access the config attribute
            config = getattr(self, "config", None)
            _pad_token_id = getattr(config, "pad_token_id", 0) if config else 0
        if not isinstance(_pad_token_id, int):
            raise ValueError(f"pad_token_id must be an integer, got {_pad_token_id}")

        # Resolve eos_token_id
        _eos_token_id = eos_token_id
        if _eos_token_id is not None and not isinstance(_eos_token_id, int):
            raise ValueError(f"eos_token_id must be an integer or None, got {_eos_token_id}")

        # Get the initial sequence length and batch size
        batch_size, initial_seq_len = input_ids.shape

        # Handle cases where the input is already long enough
        if initial_seq_len >= max_length:
            print(
                f"Warning: Initial sequence length ({initial_seq_len}) is >= "
                f"max_length ({max_length}). Returning truncated input."
            )
            return input_ids[:, :max_length]

        # Track whether each sequence is finished
        finished_sequences = jnp.zeros((batch_size,), dtype=jnp.bool_)
        if _eos_token_id is not None:
            # Check if the *last* token of the input is EOS
            finished_sequences = jnp.where(
                initial_seq_len > 0,
                input_ids[:, -1] == _eos_token_id,
                jnp.zeros_like(finished_sequences),  # Ensure correct shape if seq_len is 0
            )

        # --- Conditionally call the jitted or non-jitted core logic ---
        if use_jit:
            # Call the pre-compiled method
            final_output_ids = self._generate_compiled(
                initial_input_ids=input_ids,
                initial_finished_sequences=finished_sequences,
                initial_rng=_rng,
                initial_seq_len=initial_seq_len,  # Static arg for JIT
                max_length=max_length,
                temperature=float(temperature),  # Ensure float
                top_k=top_k,
                top_p=_top_p,  # Use the resolved top_p
                min_p=min_p,
                do_sample=do_sample,
                pad_token_id=_pad_token_id,  # Use the resolved pad id
                eos_token_id=_eos_token_id,  # Use the resolved eos id
            )
        else:
            # Call the raw logic method directly
            final_output_ids = self._generate_scan_logic(
                initial_input_ids=input_ids,
                initial_finished_sequences=finished_sequences,
                initial_rng=_rng,
                initial_seq_len=initial_seq_len,
                max_length=max_length,
                temperature=float(temperature),
                top_k=top_k,
                top_p=_top_p,
                min_p=min_p,
                do_sample=do_sample,
                pad_token_id=_pad_token_id,
                eos_token_id=_eos_token_id,
            )

        return cast(jnp.ndarray, final_output_ids)
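

# Usage sketch (hypothetical): ``MyModel`` below is an illustrative stand-in for
# any jaxgarden causal LM that inherits from both BaseModel and GenerationMixin;
# its constructor signature is assumed, not part of this module.
#
#     model = MyModel(config, rngs=nnx.Rngs(0))
#     output_ids = model.generate(
#         input_ids,                  # [batch_size, seq_len] int array
#         max_length=32,
#         temperature=0.8,
#         top_k=50,
#         do_sample=True,
#         rng=jax.random.PRNGKey(42),
#         use_jit=False,
#     )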