API Documentation
Attention Modules
Attention modules for JAXgarden.
- class jaxgarden.attention.MultiHeadAttention(*args: Any, **kwargs: Any)[source]
Bases:
MultiHeadAttention
Multi-head attention with support for Flash Attention.
This class extends Flax NNX’s MultiHeadAttention to support Flash Attention through JAX’s dot_product_attention implementation parameter.
Example usage:
```python
import jax
import jax.numpy as jnp
import flax.nnx as nnx

from jaxgarden.attention import MultiHeadAttention

# Create a MultiHeadAttention module with Flash Attention support
attention = MultiHeadAttention(
    num_heads=8,
    in_features=512,
    implementation="cudnn",  # Use cuDNN's Flash Attention if available
    rngs=nnx.Rngs(0),
)

# Create example inputs
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 128, 512))  # (batch, seq_length, hidden_dim)

# Create a causal attention mask
mask = jnp.tril(jnp.ones((2, 1, 128, 128)))  # (batch, 1, q_len, kv_len)

# Apply the module
output = attention(x, mask=mask)
```
- __init__(num_heads: int, in_features: int, qkv_features: int | None = None, out_features: int | None = None, *, dtype: ~numpy.dtype | None = None, param_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, broadcast_dropout: bool = True, dropout_rate: float = 0.0, deterministic: bool | None = None, precision: ~jax._src.lax.lax.Precision | str | None = None, kernel_init: ~collections.abc.Callable = <function variance_scaling.<locals>.init>, out_kernel_init: ~collections.abc.Callable | None = None, bias_init: ~collections.abc.Callable = <function zeros>, out_bias_init: ~collections.abc.Callable | None = None, use_bias: bool = True, attention_fn: ~collections.abc.Callable | None = None, decode: bool | None = None, normalize_qk: bool = False, qkv_dot_general: ~collections.abc.Callable | None = None, out_dot_general: ~collections.abc.Callable | None = None, qkv_dot_general_cls: type | None = None, out_dot_general_cls: type | None = None, implementation: ~typing.Literal['xla', 'cudnn', 'flash'] | None = None, rngs: ~flax.nnx.rnglib.Rngs)[source]
Initialize the MultiHeadAttention module.
- Parameters:
num_heads – number of attention heads.
in_features – int or tuple with number of input features.
qkv_features – dimension of the key, query, and value.
out_features – dimension of the last projection.
dtype – the dtype of the computation.
param_dtype – the dtype passed to parameter initializers.
broadcast_dropout – bool: use a broadcasted dropout along batch dims.
dropout_rate – dropout rate.
deterministic – if false, the attention weight is masked randomly using dropout.
precision – numerical precision of the computation.
kernel_init – initializer for the kernel of the Dense layers.
out_kernel_init – initializer for the kernel of the output Dense layer.
bias_init – initializer for the bias of the Dense layers.
out_bias_init – initializer for the bias of the output Dense layer.
use_bias – bool: whether pointwise QKVO dense transforms use bias.
attention_fn – dot_product_attention or compatible function.
decode – whether to prepare and use an autoregressive cache.
normalize_qk – should QK normalization be applied.
qkv_dot_general – dot_general function for QKV projection.
out_dot_general – dot_general function for output projection.
qkv_dot_general_cls – dot_general class for QKV projection.
out_dot_general_cls – dot_general class for output projection.
implementation – which implementation to use for attention. Options are:
- “xla”: Use XLA’s default implementation
- “cudnn”: Use cuDNN’s Flash Attention implementation (if available)
- “flash”: Alias for “cudnn”
- None: Automatically select the best available implementation
rngs – random number generator keys.
Functional Interfaces
Functional implementations for JAXgarden.
- jaxgarden.functional.dot_product_attention(query: Array, key: Array, value: Array, bias: Array | None = None, mask: Array | None = None, broadcast_dropout: bool = True, dropout_rng: Array | None = None, dropout_rate: float = 0.0, deterministic: bool = False, dtype: dtype | None = None, precision: Precision | str | None = None, implementation: Literal['xla', 'cudnn', 'flash'] | None = None, module: object | None = None) Array [source]
Computes dot-product attention with optional Flash Attention support.
This function provides a wrapper around JAX’s dot_product_attention with the option to use Flash Attention when available. It follows the Flax NNX interface while allowing the use of different implementations through the implementation parameter.
- Parameters:
query – queries for calculating attention with shape of [batch…, q_length, num_heads, qk_depth_per_head].
key – keys for calculating attention with shape of [batch…, kv_length, num_heads, qk_depth_per_head].
value – values to be used in attention with shape of [batch…, kv_length, num_heads, v_depth_per_head].
bias – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length].
mask – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length].
broadcast_dropout – bool: use a broadcasted dropout along batch dims.
dropout_rng – JAX PRNGKey: to be used for dropout.
dropout_rate – dropout rate.
deterministic – bool, deterministic or not (to apply dropout).
dtype – the dtype of the computation (default: infer from inputs).
precision – numerical precision of the computation.
implementation – which implementation to use. Options are:
- “xla”: Use XLA’s default implementation
- “cudnn”: Use cuDNN’s Flash Attention implementation (if available)
- “flash”: Alias for “cudnn”
- None: Automatically select the best available implementation
module – the Module that will sow the attention weights.
- Returns:
Output of shape [batch…, q_length, num_heads, v_depth_per_head].
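A minimal usage sketch, assuming the shapes documented above (the random inputs and sizes are purely illustrative):
```python
import jax
import jax.numpy as jnp
from jaxgarden.functional import dot_product_attention

rng = jax.random.PRNGKey(0)
q_key, k_key, v_key = jax.random.split(rng, 3)

batch, q_len, kv_len, num_heads, head_dim = 2, 128, 128, 8, 64
query = jax.random.normal(q_key, (batch, q_len, num_heads, head_dim))
key = jax.random.normal(k_key, (batch, kv_len, num_heads, head_dim))
value = jax.random.normal(v_key, (batch, kv_len, num_heads, head_dim))

# Causal mask broadcastable to [batch, num_heads, q_len, kv_len]
mask = jnp.tril(jnp.ones((1, 1, q_len, kv_len), dtype=bool))

# implementation=None lets the wrapper pick the best available backend
out = dot_product_attention(query, key, value, mask=mask, implementation=None)
print(out.shape)  # (2, 128, 8, 64)
```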
Models
- class jaxgarden.models.BaseConfig(seed: int = 42, log_level: str = 'info', extra: dict[str, ~typing.Any] = <factory>)[source]
Bases:
object
Base configuration for all the models implemented in the JAXgarden library.
Each model implemented in JAXgarden should subclass this class for configuration management.
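A hypothetical subclass sketch, assuming BaseConfig behaves as the standard dataclass suggested by its signature; the fields hidden_size and num_layers below are illustrative, not part of jaxgarden:
```python
from dataclasses import dataclass

from jaxgarden.models import BaseConfig

# Illustrative config subclass; hidden_size and num_layers are made-up fields.
@dataclass
class MyModelConfig(BaseConfig):
    hidden_size: int = 256
    num_layers: int = 4

config = MyModelConfig(seed=0, hidden_size=512)
print(config.seed, config.log_level, config.hidden_size)  # 0 info 512
```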
- class jaxgarden.models.BaseModel(*args: Any, **kwargs: Any)[source]
Bases:
Module
Base class for all the models implemented in the JAXgarden library.
- __init__(config: ~jaxgarden.models.base.BaseConfig, *, dtype: ~numpy.dtype | None = None, param_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, precision: ~jax._src.lax.lax.Precision | str | None = None, rngs: ~flax.nnx.rnglib.Rngs)[source]
Initialize the model.
- Parameters:
config – config class for this model.
dtype – Data type in which computation is performed.
param_dtype – Data type in which params are stored.
precision – Numerical precision.
rngs – Random number generators for param initialization etc.
- convert_weights_from_hf(state: State, weights: Iterator[tuple[Any, Any]]) None [source]
Convert weights from Hugging Face Hub to the model’s state.
This method should be implemented in downstream classes to support conversion from HuggingFace format.
- static download_from_hf(repo_id: str, local_dir: str, token: str | None = None, force_download: bool = False) None [source]
Downloads the model from the Hugging Face Hub.
- Parameters:
repo_id – The repository ID of the model to download.
local_dir – The local directory to save the model to.
- from_hf(model_repo_or_id: str, token: str | None = None, force_download: bool = False, save_in_orbax: bool = True, remove_hf_after_conversion: bool = True) None [source]
Downloads the model from the Hugging Face Hub and returns a new instance of the model.
It can also save the converted weights in an Orbax checkpoint and remove the original HuggingFace checkpoint after conversion.
- Parameters:
model_repo_or_id – The repository ID or name of the model to download.
token – The token to use for authentication with the Hugging Face Hub.
save_in_orbax – Whether to save the converted weights in an Orbax checkpoint.
remove_hf_after_conversion – Whether to remove the downloaded HuggingFace checkpoint after conversion.
- static iter_safetensors(path_to_model_weights: str) Iterator[tuple[Any, Any]] [source]
Helper function to lazily load params from safetensors file.
Use this static method to iterate over weights for conversion tasks.
- Parameters:
path_to_model_weights – Path to the directory containing .safetensors files.
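A short sketch of iterating over a downloaded checkpoint; the directory path is a placeholder, and treating the yielded pairs as (parameter name, tensor) is an assumption based on the description above:
```python
from jaxgarden.models import BaseModel

# Placeholder path to a directory that holds .safetensors files.
for name, tensor in BaseModel.iter_safetensors("/path/to/hf_checkpoint"):
    print(name, getattr(tensor, "shape", None))
```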
- load(path: str) Module [source]
Loads the model state from a directory.
- Parameters:
path – The directory path to load the model state from.
- class jaxgarden.models.GenerationMixin[source]
Bases:
object
Mixin that adds text generation capabilities, including sampling with temperature, top-k, top-p, and min-probability filtering, for CausalLMs.
- generate(input_ids: Array, attention_mask: Array | None = None, max_length: int = 20, temperature: float = 1.0, top_k: int | None = None, top_p: float | None = None, min_p: float | None = None, do_sample: bool = True, pad_token_id: int | None = None, eos_token_id: int | None = None, rng: Array | None = None, use_jit: bool = False) Array [source]
Generate tokens autoregressively with various sampling methods.
- Parameters:
input_ids – Initial token IDs of shape [batch_size, seq_len].
attention_mask – Optional attention mask of shape [batch_size, seq_len].
max_length – Maximum length for generated sequences.
temperature – Temperature for sampling.
top_k – If specified, only sample from the top-k logits.
top_p – If specified, only sample from the smallest set of logits whose cumulative probability exceeds p.
min_p – If specified, only consider logits with prob >= min_p * prob(max logit).
do_sample – If True, use sampling; otherwise use greedy/beam search.
pad_token_id – Token ID to use for padding.
eos_token_id – Token ID that signals the end of generation.
rng – Optional PRNG key for sampling.
use_jit – If True, use jax.jit to compile the generate function.
- Returns:
Generated token IDs of shape [batch_size, max_length].
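A sampling sketch using the parameters documented above; `model` is assumed to be an already-constructed CausalLM that mixes in GenerationMixin (for example LlamaForCausalLM), and the token IDs are illustrative:
```python
import jax
import jax.numpy as jnp

# `model` is assumed to be a CausalLM instance with GenerationMixin.
input_ids = jnp.array([[1, 15043, 29892]])   # illustrative token ids, [batch, seq_len]
attention_mask = jnp.ones_like(input_ids)

tokens = model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_length=32,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    do_sample=True,
    eos_token_id=2,                # illustrative; use the tokenizer's actual id
    rng=jax.random.PRNGKey(0),
)
print(tokens.shape)  # (1, 32) = [batch_size, max_length]
```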
- class jaxgarden.models.LlamaAttention(*args: Any, **kwargs: Any)[source]
Bases:
Module
Multi-headed attention with support for Group Query Attention (GQA).
This implements the Llama attention mechanism with rotary position embeddings (RoPE) and support for fewer key-value heads than query heads (GQA).
- q_proj
Linear projection for queries
- k_proj
Linear projection for keys
- v_proj
Linear projection for values
- o_proj
Linear projection for output
- rotary_emb
Rotary position embeddings
- head_dim
Dimension of each attention head
- n_heads
Number of attention heads
- n_kv_heads
Number of key/value heads
- __call__(x: Array, position_ids: Array, attention_mask: Array) Array [source]
Apply self-attention using queries, keys, and values derived from input x.
- Parameters:
x – Input tensor of shape [batch_size, seq_len, dim]
position_ids – Position indices of shape [batch_size, seq_len]
attention_mask – Additive attention mask of shape (batch_size, 1, q_len, kv_len) or a shape that can be broadcasted to this shape.
- Returns:
Output tensor of shape [batch_size, seq_len, dim]
- __init__(layer_idx: int, dim: int, n_heads: int, n_kv_heads: int, head_dim: int, rope_theta: float, *, rngs: Rngs)[source]
Initialize attention module.
- Parameters:
layer_idx – Index of the layer
dim – Size of hidden states
n_heads – Number of attention heads
n_kv_heads – Number of key/value heads
head_dim – Dimension of each attention head
rope_theta – Base for rotary position embeddings
rngs – PRNG key collection
- apply_rotary_pos_emb(q: Array, k: Array, cos: Array, sin: Array, unsqueeze_dim: int = 1) tuple[Array, Array] [source]
Apply rotary position embeddings to query and key tensors.
- Parameters:
q – Query tensor
k – Key tensor
cos – Cosine component of rotary embeddings
sin – Sine component of rotary embeddings
unsqueeze_dim – Dimension to unsqueeze cosine and sine components
- Returns:
Tuple of (rotated_q, rotated_k)
- repeat_kv(hidden_states: Array, n_repeat: int) Array [source]
Repeat key/value heads to match the number of query heads.
When using GQA, we need to repeat each key/value head to match the number of query heads.
- Parameters:
hidden_states – Key or value tensor of shape [batch, n_kv_heads, seq_len, head_dim]
n_repeat – Number of times to repeat each key/value head
- Returns:
Tensor with repeated key/value heads
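A minimal sketch of the repetition described above (tile each key/value head n_repeat times along the head axis); it mirrors the documented behaviour, not the library's exact code:
```python
import jax.numpy as jnp

def repeat_kv_sketch(hidden_states, n_repeat):
    # hidden_states: [batch, n_kv_heads, seq_len, head_dim]
    batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_repeat == 1:
        return hidden_states
    expanded = jnp.broadcast_to(
        hidden_states[:, :, None, :, :],
        (batch, n_kv_heads, n_repeat, seq_len, head_dim),
    )
    return expanded.reshape(batch, n_kv_heads * n_repeat, seq_len, head_dim)

k = jnp.ones((2, 8, 16, 64))          # 8 KV heads
print(repeat_kv_sketch(k, 4).shape)   # (2, 32, 16, 64), matching 32 query heads
```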
- class jaxgarden.models.LlamaConfig(seed: int = 42, log_level: str = 'info', extra: dict[str, ~typing.Any] = <factory>, dim: int = 2048, n_layers: int = 16, n_heads: int = 32, n_kv_heads: int = 8, head_dim: int = 64, intermediate_size: int = 14336, vocab_size: int = 128256, multiple_of: int = 256, norm_eps: float = 1e-05, rope_theta: float = 500000.0)[source]
Bases:
BaseConfig
Configuration for the Llama model.
This configuration class extends BaseConfig and contains all the parameters required to initialize a Llama model. It includes settings for the model architecture, attention mechanisms, and other hyperparameters.
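For example, the defaults above can be overridden to build a smaller configuration (the values below are illustrative):
```python
from jaxgarden.models import LlamaConfig

config = LlamaConfig(
    dim=512,
    n_layers=4,
    n_heads=8,
    n_kv_heads=2,
    head_dim=64,
    intermediate_size=2048,
    vocab_size=32000,
)
print(config.n_heads, config.rope_theta)  # 8 500000.0
```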
- class jaxgarden.models.LlamaForCausalLM(*args: Any, **kwargs: Any)[source]
Bases:
BaseModel, GenerationMixin
Llama model for causal language modeling.
This implements the full Llama model for generating text. It consists of token embeddings, transformer layers, and a language modeling head.
- token_embed
Token embedding layer
- layers
List of transformer blocks
- lm_head
Linear layer for language modeling
- norm
Final layer normalization
- __call__(input_ids: Array, attention_mask: Array | None = None, deterministic: bool = True) Array [source]
Forward pass of the Llama model.
- Parameters:
input_ids – Input token ids of shape [batch_size, seq_len]
attention_mask – Boolean attention mask of shape [batch_size, seq_len]
- Returns:
Logits for next token prediction of shape [batch_size, seq_len, vocab_size]
- __init__(config: ~jaxgarden.models.llama.LlamaConfig, *, dtype: ~numpy.dtype | None = None, param_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, precision: ~jax._src.lax.lax.Precision | str | None = None, rngs: ~flax.nnx.rnglib.Rngs)[source]
Initialize LlamaForCausalLM.
- Parameters:
config – Model configuration
dtype – Data type for computation
param_dtype – Data type for parameters
precision – Precision for matrix multiplication
rngs – PRNG key collection
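A forward-pass sketch with randomly initialized weights, reusing the small illustrative configuration from above; it only demonstrates the documented input and output shapes:
```python
import jax.numpy as jnp
import flax.nnx as nnx

from jaxgarden.models import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(dim=512, n_layers=4, n_heads=8, n_kv_heads=2,
                     head_dim=64, intermediate_size=2048, vocab_size=32000)
model = LlamaForCausalLM(config, rngs=nnx.Rngs(0))

input_ids = jnp.ones((1, 16), dtype=jnp.int32)   # [batch_size, seq_len]
attention_mask = jnp.ones((1, 16), dtype=bool)   # boolean mask, same shape

logits = model(input_ids, attention_mask=attention_mask)
print(logits.shape)  # (1, 16, 32000) = [batch_size, seq_len, vocab_size]
```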
- class jaxgarden.models.LlamaMLP(*args: Any, **kwargs: Any)[source]
Bases:
Module
Llama’s MLP implementation with SwiGLU activation.
This implements the SwiGLU MLP used in Llama models: down_proj(silu(gate_proj(x)) * up_proj(x))
- gate_proj
Linear projection for gate
- up_proj
Linear projection for up-projection
- down_proj
Linear projection for down-projection
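The SwiGLU form quoted above, written out as a functional sketch with nnx.Linear projections named after the attributes listed here (an illustration, not the module's source):
```python
import jax
import jax.numpy as jnp
import flax.nnx as nnx

dim, intermediate_size = 512, 2048
rngs = nnx.Rngs(0)
# use_bias=False mirrors typical Llama MLPs; treat it as an assumption here.
gate_proj = nnx.Linear(dim, intermediate_size, use_bias=False, rngs=rngs)
up_proj = nnx.Linear(dim, intermediate_size, use_bias=False, rngs=rngs)
down_proj = nnx.Linear(intermediate_size, dim, use_bias=False, rngs=rngs)

x = jnp.ones((1, 16, dim))
# down_proj(silu(gate_proj(x)) * up_proj(x))
y = down_proj(jax.nn.silu(gate_proj(x)) * up_proj(x))
print(y.shape)  # (1, 16, 512)
```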
- class jaxgarden.models.LlamaRMSNorm(*args: Any, **kwargs: Any)[source]
Bases:
Module
Root Mean Square Layer Normalization.
This implementation follows the RMSNorm paper (https://arxiv.org/abs/1910.07467). Instead of using mean and variance like traditional LayerNorm, RMSNorm uses only the root mean square of the inputs for normalization.
- norm_weights
The learned scale parameters
- norm_eps
Small constant for numerical stability
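A reference sketch of the normalization described above, where weight plays the role of norm_weights and eps of norm_eps (not the LlamaRMSNorm source):
```python
import jax.numpy as jnp

def rms_norm_sketch(x, weight, eps=1e-5):
    # Normalize by the root mean square over the last axis, then scale.
    mean_square = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(mean_square + eps) * weight

x = jnp.arange(8.0).reshape(1, 8)
print(rms_norm_sketch(x, jnp.ones(8)).shape)  # (1, 8)
```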
- class jaxgarden.models.LlamaRotaryEmbedding(*args: Any, **kwargs: Any)[source]
Bases:
Module
Rotary Position Embedding (RoPE) implementation for Llama.
Based on: https://arxiv.org/abs/2104.09864
- dim
Dimension of the embeddings
- base
Base for the sinusoidal functions
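A sketch of how rotary angles are typically derived from dim and base in RoPE (https://arxiv.org/abs/2104.09864); the exact cos/sin layout used by LlamaRotaryEmbedding may differ:
```python
import jax.numpy as jnp

def rope_cos_sin(position_ids, dim, base=500000.0):
    # Inverse frequencies 1 / base^(2i/dim) for i = 0 .. dim/2 - 1.
    inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2) / dim))
    angles = position_ids[..., None] * inv_freq              # [..., seq, dim/2]
    angles = jnp.concatenate([angles, angles], axis=-1)      # [..., seq, dim]
    return jnp.cos(angles), jnp.sin(angles)

cos, sin = rope_cos_sin(jnp.arange(16)[None, :], dim=64)
print(cos.shape)  # (1, 16, 64)
```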
- class jaxgarden.models.LlamaTransformerBlock(*args: Any, **kwargs: Any)[source]
Bases:
Module
Llama transformer block implementation.
This implements a single layer of the Llama transformer, consisting of attention, layer normalization, and MLP components.
- input_layernorm
Layer normalization before attention
- attention
Multi-headed attention
- post_attention_layernorm
Layer normalization after attention
- mlp
MLP block
- __call__(x: Array, position_ids: Array, attention_mask: Array) Array [source]
Apply transformer block to input tensor.
- Parameters:
x – Input tensor of shape [batch_size, seq_len, dim]
position_ids – Position indices of shape [batch_size, seq_len]
attention_mask – Additive attention mask of shape (batch_size, 1, q_len, kv_len) or another shape that can be broadcasted to this shape.
- Returns:
Output tensor of shape [batch_size, seq_len, dim]
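Schematically, the pre-norm ordering implied by the attributes above looks like the sketch below; the attribute names mirror this page, but the real forward pass is defined by the library, not this sketch:
```python
# Pseudo-forward pass for illustration only; `block` stands for a
# LlamaTransformerBlock-like object with the attributes documented above.
def transformer_block_sketch(block, x, position_ids, attention_mask):
    # Attention sub-layer with pre-normalization and a residual connection.
    h = x + block.attention(block.input_layernorm(x), position_ids, attention_mask)
    # MLP sub-layer, again pre-normalized with a residual connection.
    return h + block.mlp(block.post_attention_layernorm(h))
```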
- __init__(layer_idx: int, dim: int, n_heads: int, n_kv_heads: int, head_dim: int, rope_theta: float, intermediate_size: int, *, norm_eps: float = 1e-05, rngs: Rngs)[source]
Initialize transformer block.
- Parameters:
layer_idx – Index of the layer
dim – Size of hidden states
n_heads – Number of attention heads
n_kv_heads – Number of key/value heads
head_dim – Dimension of each attention head
rope_theta – Base for rotary position embeddings
intermediate_size – Size of MLP intermediate layer
norm_eps – Epsilon for layer normalization
rngs – PRNG key collection
- class jaxgarden.models.ModernBERTEncoder(*args: Any, **kwargs: Any)[source]
Bases:
Module
ModernBERT encoder consisting of multiple transformer layers.
- __call__(hidden_states: Array, attention_mask: Array | None = None, sliding_window_mask: Array | None = None, position_ids: Array | None = None, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False) tuple[Array] | tuple[Array, list[Array]] | tuple[Array, list[Array], list[Array]] [source]
Apply transformer encoder.
- Parameters:
hidden_states – Input tensor
attention_mask – Optional attention mask
sliding_window_mask – Optional sliding window mask
position_ids – Optional position ids
deterministic – Whether to apply dropout
output_attentions – Whether to return attention weights
output_hidden_states – Whether to return all hidden states
- Returns:
A tuple containing the output tensor, optionally all hidden states (if output_hidden_states=True), and optionally all attention weights (if output_attentions=True).
- __init__(rngs: Rngs, hidden_size: int, num_attention_heads: int, intermediate_size: int, num_hidden_layers: int, attention_dropout: float = 0.0, hidden_dropout: float = 0.0, attention_bias: bool = True, norm_eps: float = 1e-12, norm_bias: bool = True, global_rope_theta: float = 10000.0, max_position_embeddings: int = 4096, local_attention: tuple[int, int] = (-1, -1), local_rope_theta: float | None = None, global_attn_every_n_layers: int = 4)[source]
Initialize encoder.
- Parameters:
rngs – PRNG key collection
hidden_size – Size of hidden states
num_attention_heads – Number of attention heads
intermediate_size – Size of MLP intermediate layer
num_hidden_layers – Number of transformer layers
attention_dropout – Dropout probability for attention
hidden_dropout – Dropout probability for hidden states
attention_bias – Whether to use bias in attention
norm_eps – Epsilon for layer normalization
norm_bias – Whether to use bias in layer normalization
global_rope_theta – Base for global RoPE
max_position_embeddings – Maximum sequence length
local_attention – Tuple of (left, right) window sizes
local_rope_theta – Base for local RoPE (optional)
global_attn_every_n_layers – Apply global attention every N layers
- class jaxgarden.models.ModernBERTForMaskedLM(*args: Any, **kwargs: Any)[source]
Bases:
BaseModel
ModernBERT model with masked language modeling head.
This implements the ModernBERT architecture as described in the paper “Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference” by Answer.AI.
The implementation includes modern improvements such as:
- Rotary Position Embeddings (RoPE)
- Mixed global/local attention mechanism
- Pre-LayerNorm architecture
- Efficient parameter sharing
- __call__(input_ids: Array, attention_mask: Array | None = None, sliding_window_mask: Array | None = None, position_ids: Array | None = None, inputs_embeds: Array | None = None, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False) dict[str, Array] [source]
Apply ModernBERT model.
- Parameters:
input_ids – Input token ids of shape [batch_size, seq_len]
attention_mask – Optional attention mask
sliding_window_mask – Optional sliding window mask
position_ids – Optional position ids
inputs_embeds – Optional pre-computed embeddings
deterministic – Whether to apply dropout
output_attentions – Whether to return attention weights
output_hidden_states – Whether to return all hidden states
- Returns:
A dictionary containing:
logits: Output logits of shape [batch_size, seq_len, vocab_size]
hidden_states: All hidden states (optional)
attentions: All attention weights (optional)
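A usage sketch of the forward pass and the returned dictionary; `model` is assumed to be an already-constructed ModernBERTForMaskedLM (its config class is not documented on this page), and the inputs are illustrative:
```python
import jax.numpy as jnp

input_ids = jnp.ones((2, 128), dtype=jnp.int32)        # [batch_size, seq_len]
attention_mask = jnp.ones((2, 128), dtype=jnp.int32)

outputs = model(
    input_ids,
    attention_mask=attention_mask,
    output_hidden_states=True,
)
print(outputs["logits"].shape)  # (2, 128, vocab_size)
```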
- __init__(config: ~jaxgarden.models.modernbert.ModernBERTConfig, *, dtype: ~numpy.dtype | None = None, param_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, precision: ~jax._src.lax.lax.Precision | str | None = None, rngs: ~flax.nnx.rnglib.Rngs)[source]
Initialize ModernBERT model.
- Parameters:
config – Configuration for ModernBERT model
dtype – Data type in which computation is performed
param_dtype – Data type in which params are stored
precision – Numerical precision
rngs – Random number generators for param initialization
- class jaxgarden.models.ModernBertAttention(*args: Any, **kwargs: Any)[source]
Bases:
Module
Multi-headed self attention implementation.
This implements the standard attention mechanism with RoPE (Rotary Position Embeddings). Supports both global attention and sliding window attention.
- __call__(hidden_states: Array, attention_mask: Array | None = None, sliding_window_mask: Array | None = None, position_ids: Array | None = None, deterministic: bool = True, output_attentions: bool = False) tuple[Array] | tuple[Array, Array] [source]
Apply attention module.
- Parameters:
hidden_states – Input tensor of shape [batch_size, seq_len, hidden_size]
attention_mask – Optional attention mask
sliding_window_mask – Optional sliding window mask for local attention
position_ids – Optional position ids for RoPE
deterministic – Whether to apply dropout
output_attentions – Whether to return attention probabilities
- Returns:
A tuple containing the output tensor of shape [batch_size, seq_len, hidden_size] and, if output_attentions=True, the attention probabilities of shape [batch_size, num_heads, seq_len, seq_len].
- __init__(rngs: Rngs, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0, attention_bias: bool = True, global_rope_theta: float = 10000.0, max_position_embeddings: int = 4096, local_attention: tuple[int, int] = (-1, -1), local_rope_theta: float | None = None, layer_id: int | None = None, global_attn_every_n_layers: int = 4)[source]
Initialize attention module.
- Parameters:
rngs – PRNG key collection
hidden_size – Size of hidden states
num_attention_heads – Number of attention heads
attention_dropout – Dropout probability for attention weights
attention_bias – Whether to use bias in linear layers
global_rope_theta – Base for global RoPE
max_position_embeddings – Maximum sequence length
local_attention – Tuple of (left, right) window sizes for local attention
local_rope_theta – Base for local RoPE (optional)
layer_id – Layer index for determining attention type
global_attn_every_n_layers – Apply global attention every N layers
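Conceptually, a sliding-window mask restricts query position i to keys within (left, right) positions of i. The sketch below builds such a boolean mask; the dtype and additive-versus-boolean convention actually expected by sliding_window_mask is an assumption here:
```python
import jax.numpy as jnp

def sliding_window_mask(seq_len, left, right):
    pos = jnp.arange(seq_len)
    offset = pos[None, :] - pos[:, None]               # key position minus query position
    allowed = (offset >= -left) & (offset <= right)    # [seq_len, seq_len]
    return allowed[None, None, :, :]                   # broadcast over batch and heads

mask = sliding_window_mask(8, left=2, right=2)
print(mask.shape)  # (1, 1, 8, 8)
```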
- class jaxgarden.models.ModernBertEmbeddings(*args: Any, **kwargs: Any)[source]
Bases:
Module
Token embeddings with normalization and dropout.
Similar to BERT embeddings but without position embeddings since we use RoPE.
- __call__(input_ids: Array, deterministic: bool = True, inputs_embeds: Array | None = None) Array [source]
Apply embeddings module.
- Parameters:
input_ids – Integer tokens of shape [batch_size, seq_len]
deterministic – Whether to apply dropout
inputs_embeds – Optional pre-computed embeddings
- Returns:
Embedded tokens with shape [batch_size, seq_len, hidden_size]
- __init__(rngs: Rngs, vocab_size: int, hidden_size: int, pad_token_id: int = 0, norm_eps: float = 1e-12, norm_bias: bool = True, embedding_dropout: float = 0.0)[source]
Initialize embeddings module.
- Parameters:
rngs – PRNG key collection
vocab_size – Size of the vocabulary
hidden_size – Size of the embeddings
pad_token_id – Token ID to use for padding
norm_eps – Epsilon for layer normalization
norm_bias – Whether to use bias in layer normalization
embedding_dropout – Dropout probability for embeddings
- class jaxgarden.models.ModernBertLayer(*args: Any, **kwargs: Any)[source]
Bases:
Module
ModernBERT transformer layer with pre-LayerNorm architecture.
This implements a transformer layer with:
1. Pre-LayerNorm for attention and MLP
2. Residual connections
3. Optional identity for the first layer’s attention norm
- __call__(hidden_states: Array, attention_mask: Array | None = None, sliding_window_mask: Array | None = None, position_ids: Array | None = None, deterministic: bool = True, output_attentions: bool = False) tuple[Array] | tuple[Array, Array] [source]
Apply transformer layer.
- Parameters:
hidden_states – Input tensor of shape [batch_size, seq_len, hidden_size]
attention_mask – Optional attention mask
sliding_window_mask – Optional sliding window mask
position_ids – Optional position ids for RoPE
deterministic – Whether to apply dropout
output_attentions – Whether to return attention probabilities
- Returns:
A tuple containing the output tensor of shape [batch_size, seq_len, hidden_size] and, if output_attentions=True, the attention probabilities of shape [batch_size, num_heads, seq_len, seq_len].
- __init__(rngs: Rngs, hidden_size: int, num_attention_heads: int, intermediate_size: int, layer_id: int | None = None, attention_dropout: float = 0.0, hidden_dropout: float = 0.0, attention_bias: bool = True, norm_eps: float = 1e-12, norm_bias: bool = True, global_rope_theta: float = 10000.0, max_position_embeddings: int = 4096, local_attention: tuple[int, int] = (-1, -1), local_rope_theta: float | None = None, global_attn_every_n_layers: int = 4)[source]
Initialize transformer layer.
- Parameters:
rngs – PRNG key collection
hidden_size – Size of hidden states
num_attention_heads – Number of attention heads
intermediate_size – Size of MLP intermediate layer
layer_id – Layer index (first layer uses identity for attn norm)
attention_dropout – Dropout probability for attention
hidden_dropout – Dropout probability for hidden states
attention_bias – Whether to use bias in attention
norm_eps – Epsilon for layer normalization
norm_bias – Whether to use bias in layer normalization
global_rope_theta – Base for global RoPE
max_position_embeddings – Maximum sequence length
local_attention – Tuple of (left, right) window sizes
local_rope_theta – Base for local RoPE (optional)
global_attn_every_n_layers – Apply global attention every N layers
- class jaxgarden.models.ModernBertMLP(*args: Any, **kwargs: Any)[source]
Bases:
Module
MLP with gated linear units.
Replaces the traditional intermediate + output layers with a single gated MLP.
- __call__(hidden_states: Array, deterministic: bool = True) Array [source]
Apply MLP module.
- Parameters:
hidden_states – Input tensor of shape [batch_size, seq_len, hidden_size]
deterministic – Whether to apply dropout
- Returns:
Output tensor of shape [batch_size, seq_len, hidden_size]
- __init__(rngs: Rngs, hidden_size: int, intermediate_size: int, mlp_bias: bool = True, mlp_dropout: float = 0.0)[source]
Initialize MLP module.
- Parameters:
rngs – PRNG key collection
hidden_size – Size of input and output
intermediate_size – Size of intermediate layer
mlp_bias – Whether to use bias in linear layers
mlp_dropout – Dropout probability
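A conceptual sketch of such a gated MLP: one input projection produces both the gate and the value, which are combined and projected back down. The GELU activation and the split convention below are assumptions, not ModernBertMLP's exact implementation:
```python
import jax
import jax.numpy as jnp
import flax.nnx as nnx

hidden_size, intermediate_size = 256, 512
rngs = nnx.Rngs(0)
wi = nnx.Linear(hidden_size, 2 * intermediate_size, rngs=rngs)   # gate and value together
wo = nnx.Linear(intermediate_size, hidden_size, rngs=rngs)

x = jnp.ones((1, 16, hidden_size))
gate, value = jnp.split(wi(x), 2, axis=-1)
y = wo(jax.nn.gelu(gate) * value)
print(y.shape)  # (1, 16, 256)
```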