API Documentation

Attention Modules

Attention modules for JAXgarden.

class jaxgarden.attention.MultiHeadAttention(*args: Any, **kwargs: Any)[source]

Bases: MultiHeadAttention

Multi-head attention with support for Flash Attention.

This class extends Flax NNX’s MultiHeadAttention to support Flash Attention via the implementation parameter of JAX’s dot_product_attention.

Example usage:

```python
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jaxgarden.attention import MultiHeadAttention

# Create a MultiHeadAttention module with Flash Attention support
attention = MultiHeadAttention(
    num_heads=8,
    in_features=512,
    implementation="cudnn",  # Use cuDNN's Flash Attention if available
    rngs=nnx.Rngs(0),
)

# Initialize parameters
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 128, 512))  # (batch, seq_length, hidden_dim)

# Create a causal attention mask
mask = jnp.tril(jnp.ones((2, 1, 128, 128)))  # (batch, 1, q_len, kv_len)

# Apply the model
output = attention(x, mask=mask)
```

__init__(num_heads: int, in_features: int, qkv_features: int | None = None, out_features: int | None = None, *, dtype: ~numpy.dtype | None = None, param_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, broadcast_dropout: bool = True, dropout_rate: float = 0.0, deterministic: bool | None = None, precision: ~jax._src.lax.lax.Precision | str | None = None, kernel_init: ~collections.abc.Callable = <function variance_scaling.<locals>.init>, out_kernel_init: ~collections.abc.Callable | None = None, bias_init: ~collections.abc.Callable = <function zeros>, out_bias_init: ~collections.abc.Callable | None = None, use_bias: bool = True, attention_fn: ~collections.abc.Callable | None = None, decode: bool | None = None, normalize_qk: bool = False, qkv_dot_general: ~collections.abc.Callable | None = None, out_dot_general: ~collections.abc.Callable | None = None, qkv_dot_general_cls: type | None = None, out_dot_general_cls: type | None = None, implementation: ~typing.Literal['xla', 'cudnn', 'flash'] | None = None, rngs: ~flax.nnx.rnglib.Rngs)[source]

Initialize the MultiHeadAttention module.

Parameters:
  • num_heads – number of attention heads.

  • in_features – int or tuple with number of input features.

  • qkv_features – dimension of the key, query, and value.

  • out_features – dimension of the last projection.

  • dtype – the dtype of the computation.

  • param_dtype – the dtype passed to parameter initializers.

  • broadcast_dropout – bool: use a broadcasted dropout along batch dims.

  • dropout_rate – dropout rate.

  • deterministic – if false, the attention weight is masked randomly using dropout.

  • precision – numerical precision of the computation.

  • kernel_init – initializer for the kernel of the Dense layers.

  • out_kernel_init – initializer for the kernel of the output Dense layer.

  • bias_init – initializer for the bias of the Dense layers.

  • out_bias_init – initializer for the bias of the output Dense layer.

  • use_bias – bool: whether pointwise QKVO dense transforms use bias.

  • attention_fn – dot_product_attention or compatible function.

  • decode – whether to prepare and use an autoregressive cache.

  • normalize_qk – should QK normalization be applied.

  • qkv_dot_general – dot_general function for QKV projection.

  • out_dot_general – dot_general function for output projection.

  • qkv_dot_general_cls – dot_general class for QKV projection.

  • out_dot_general_cls – dot_general class for output projection.

  • implementation – which implementation to use for attention. Options are:
    – "xla": Use XLA's default implementation
    – "cudnn": Use cuDNN's Flash Attention implementation (if available)
    – "flash": Alias for "cudnn"
    – None: Automatically select the best available implementation

  • rngs – random number generator keys.

Functional Interfaces

Functional implementations for JAXgarden.

jaxgarden.functional.dot_product_attention(query: Array, key: Array, value: Array, bias: Array | None = None, mask: Array | None = None, broadcast_dropout: bool = True, dropout_rng: Array | None = None, dropout_rate: float = 0.0, deterministic: bool = False, dtype: dtype | None = None, precision: Precision | str | None = None, implementation: Literal['xla', 'cudnn', 'flash'] | None = None, module: object | None = None) Array[source]

Computes dot-product attention with optional Flash Attention support.

This function provides a wrapper around JAX’s dot_product_attention with the option to use Flash Attention when available. It follows the Flax NNX interface while allowing the use of different implementations through the implementation parameter.

Parameters:
  • query – queries for calculating attention with shape of [batch…, q_length, num_heads, qk_depth_per_head].

  • key – keys for calculating attention with shape of [batch…, kv_length, num_heads, qk_depth_per_head].

  • value – values to be used in attention with shape of [batch…, kv_length, num_heads, v_depth_per_head].

  • bias – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length].

  • mask – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length].

  • broadcast_dropout – bool: use a broadcasted dropout along batch dims.

  • dropout_rng – JAX PRNGKey: to be used for dropout.

  • dropout_rate – dropout rate.

  • deterministic – bool, deterministic or not (to apply dropout).

  • dtype – the dtype of the computation (default: infer from inputs).

  • precision – numerical precision of the computation.

  • implementation – which implementation to use. Options are:
    – "xla": Use XLA's default implementation
    – "cudnn": Use cuDNN's Flash Attention implementation (if available)
    – "flash": Alias for "cudnn"
    – None: Automatically select the best available implementation

  • module – the Module that will sow the attention weights.

Returns:

Output of shape [batch…, q_length, num_heads, v_depth_per_head].
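
A minimal usage sketch based on the signature above; the shapes and the causal mask below are illustrative, not required values:

```python
import jax
import jax.numpy as jnp
from jaxgarden.functional import dot_product_attention

# Illustrative shapes: batch=2, q_length=kv_length=128, 8 heads, 64 dims per head
q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(q_key, (2, 128, 8, 64))
key = jax.random.normal(k_key, (2, 128, 8, 64))
value = jax.random.normal(v_key, (2, 128, 8, 64))

# Causal mask, broadcastable to [batch..., num_heads, q_length, kv_length]
mask = jnp.tril(jnp.ones((1, 1, 128, 128), dtype=bool))

# implementation=None lets the wrapper pick the best available backend
output = dot_product_attention(query, key, value, mask=mask, implementation=None)
print(output.shape)  # (2, 128, 8, 64)
```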

Models

class jaxgarden.models.BaseConfig(seed: int = 42, log_level: str = 'info', extra: dict[str, ~typing.Any] = <factory>)[source]

Bases: object

Base configuration for all the models implemented in the JAXgarden library.

Each model implemented in JAXgarden should subclass this class for configuration management.

extra: dict[str, Any]
log_level: str = 'info'
seed: int = 42
to_dict() dict[str, Any][source]
update(**kwargs: dict) None[source]
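
A minimal sketch of subclassing BaseConfig. The MyModelConfig class and its fields are hypothetical, and the sketch assumes BaseConfig is a dataclass whose update method mutates fields in place:

```python
from dataclasses import dataclass
from jaxgarden.models import BaseConfig

# Hypothetical subclass, used only to illustrate the BaseConfig API
@dataclass
class MyModelConfig(BaseConfig):
    hidden_size: int = 256
    num_layers: int = 4

config = MyModelConfig(seed=0, log_level="debug")
config.update(num_layers=8)            # assumed to update fields in place
print(config.to_dict()["num_layers"])  # 8
```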
class jaxgarden.models.BaseModel(*args: Any, **kwargs: Any)[source]

Bases: Module

Base class for all the models implemented in the JAXgarden library.

__init__(config: ~jaxgarden.models.base.BaseConfig, *, dtype: ~numpy.dtype | None = None, param_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, precision: ~jax._src.lax.lax.Precision | str | None = None, rngs: ~flax.nnx.rnglib.Rngs)[source]

Initialize the model.

Parameters:
  • config – config class for this model.

  • dtype – Data type in which computation is performed.

  • param_dtype – Data type in which params are stored.

  • precision – Numerical precision.

  • rngs – Random number generators for param initialization etc.

convert_weights_from_hf(state: State, weights: Iterator[tuple[Any, Any]]) None[source]

Convert weights from Hugging Face Hub to the model’s state.

This method should be implemented in downstream classes to support conversion from HuggingFace format.

static download_from_hf(repo_id: str, local_dir: str, token: str | None = None, force_download: bool = False) None[source]

Downloads the model from the Hugging Face Hub.

Parameters:
  • repo_id – The repository ID of the model to download.

  • local_dir – The local directory to save the model to.

  • token – The token to use for authentication with the Hugging Face Hub.

  • force_download – Whether to re-download the model even if a local copy already exists.

from_hf(model_repo_or_id: str, token: str | None = None, force_download: bool = False, save_in_orbax: bool = True, remove_hf_after_conversion: bool = True) None[source]

Downloads the model from the Hugging Face Hub and returns a new instance of the model.

It can also save the converted weights in an Orbax checkpoint and remove the original HuggingFace checkpoint after conversion.

Parameters:
  • model_repo_or_id – The repository ID or name of the model to download.

  • token – The token to use for authentication with the Hugging Face Hub.

  • force_download – Whether to re-download the model even if a local copy already exists.

  • save_in_orbax – Whether to save the converted weights in an Orbax checkpoint.

  • remove_hf_after_conversion – Whether to remove the downloaded HuggingFace checkpoint after conversion.

static iter_safetensors(path_to_model_weights: str) Iterator[tuple[Any, Any]][source]

Helper function to lazily load params from safetensors file.

Use this static method to iterate over weights for conversion tasks.

Parameters:

path_to_model_weights – Path to the directory containing .safetensors files.

load(path: str) Module[source]

Loads the model state from a directory.

Parameters:

path – The directory path to load the model state from.

save(path: str) None[source]

Saves the model state to a directory.

Parameters:

path – The directory path to save the model state to.

property state: State

Splits the state from the graph and returns it.

property state_dict: dict[str, Array]

Splits state from the graph and returns it as a dictionary.

It can be used for serialization with orbax.
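
A checkpointing sketch using a concrete subclass (LlamaForCausalLM); the path is arbitrary and the reduced config is only there to keep the example small:

```python
import flax.nnx as nnx
from jaxgarden.models import LlamaConfig, LlamaForCausalLM

# Reduced config, illustration only
config = LlamaConfig(dim=256, n_layers=2, n_heads=4, n_kv_heads=2,
                     head_dim=64, intermediate_size=512, vocab_size=1000)
model = LlamaForCausalLM(config, rngs=nnx.Rngs(0))

model.save("/tmp/llama-ckpt")          # writes the model state to a directory
model = model.load("/tmp/llama-ckpt")  # load() returns a Module with the restored state
```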

class jaxgarden.models.GenerationMixin[source]

Bases: object

Mixin that adds text generation capabilities to causal language models, including sampling with temperature, top-k, top-p, and min-p filtering.

generate(input_ids: Array, attention_mask: Array | None = None, max_length: int = 20, temperature: float = 1.0, top_k: int | None = None, top_p: float | None = None, min_p: float | None = None, do_sample: bool = True, pad_token_id: int | None = None, eos_token_id: int | None = None, rng: Array | None = None, use_jit: bool = False) Array[source]

Generate tokens autoregressively with various sampling methods.

Parameters:
  • input_ids – Initial token IDs of shape [batch_size, seq_len].

  • attention_mask – Optional attention mask of shape [batch_size, seq_len].

  • max_length – Maximum length for generated sequences.

  • temperature – Temperature for sampling.

  • top_k – If specified, only sample from the top-k logits.

  • top_p – If specified, only sample from the smallest set of logits whose cumulative probability exceeds p.

  • min_p – If specified, only consider logits with prob >= min_p * prob(max logit).

  • do_sample – If True, use sampling; otherwise use greedy/beam search.

  • pad_token_id – Token ID to use for padding.

  • eos_token_id – Token ID that signals the end of generation.

  • rng – Optional PRNG key for sampling.

  • use_jit – If True, use jax.jit to compile the generate function.

Returns:

Generated token IDs of shape [batch_size, max_length].
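
A sampling sketch with LlamaForCausalLM, which mixes in GenerationMixin; the reduced config and the token IDs are illustrative only:

```python
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jaxgarden.models import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(dim=256, n_layers=2, n_heads=4, n_kv_heads=2,
                     head_dim=64, intermediate_size=512, vocab_size=1000)
model = LlamaForCausalLM(config, rngs=nnx.Rngs(0))

input_ids = jnp.array([[1, 42, 7]])  # [batch_size, seq_len], arbitrary token IDs
tokens = model.generate(
    input_ids,
    max_length=16,
    temperature=0.8,
    top_k=50,
    do_sample=True,
    eos_token_id=2,        # arbitrary for this sketch
    rng=jax.random.PRNGKey(0),
)
print(tokens.shape)  # (1, 16)
```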

class jaxgarden.models.LlamaAttention(*args: Any, **kwargs: Any)[source]

Bases: Module

Multi-headed attention with support for Group Query Attention (GQA).

This implements the LLama attention mechanism with rotary position embeddings (RoPE) and support for fewer key-value heads than query heads (GQA).

Attributes:
  • q_proj – Linear projection for queries

  • k_proj – Linear projection for keys

  • v_proj – Linear projection for values

  • o_proj – Linear projection for output

  • rotary_emb – Rotary position embeddings

  • head_dim – Dimension of each attention head

  • n_heads – Number of attention heads

  • n_kv_heads – Number of key/value heads

__call__(x: Array, position_ids: Array, attention_mask: Array) Array[source]

Apply self-attention using queries, keys, and values derived from input x.

Parameters:
  • x – Input tensor of shape [batch_size, seq_len, dim]

  • position_ids – Position indices of shape [batch_size, seq_len]

  • attention_mask – Additive attention mask of shape (batch_size, 1, q_len, kv_len) or a shape that can be broadcasted to this shape.

Returns:

Output tensor of shape [batch_size, seq_len, dim]

__init__(layer_idx: int, dim: int, n_heads: int, n_kv_heads: int, head_dim: int, rope_theta: float, *, rngs: Rngs)[source]

Initialize attention module.

Parameters:
  • layer_idx – Index of the layer

  • dim – Size of hidden states

  • n_heads – Number of attention heads

  • n_kv_heads – Number of key/value heads

  • head_dim – Dimension of each attention head

  • rope_theta – Base for rotary position embeddings

  • rngs – PRNG key collection

apply_rotary_pos_emb(q: Array, k: Array, cos: Array, sin: Array, unsqueeze_dim: int = 1) tuple[Array, Array][source]

Apply rotary position embeddings to query and key tensors.

Parameters:
  • q – Query tensor

  • k – Key tensor

  • cos – Cosine component of rotary embeddings

  • sin – Sine component of rotary embeddings

  • unsqueeze_dim – Dimension to unsqueeze cosine and sine components

Returns:

Tuple of (rotated_q, rotated_k)

repeat_kv(hidden_states: Array, n_repeat: int) Array[source]

Repeat key/value heads to match the number of query heads.

When using GQA, we need to repeat each key/value head to match the number of query heads.

Parameters:
  • hidden_states – Key or value tensor of shape [batch, n_kv_heads, seq_len, head_dim]

  • n_repeat – Number of times to repeat each key/value head

Returns:

Tensor with repeated key/value heads

rotate_half(x: Array) Array[source]

Rotate half the hidden dims of the input.

Parameters:

x – Input tensor

Returns:

Tensor with half the dimensions rotated
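
For reference, the rotate-half and RoPE application steps described above are commonly expressed as follows. This is a standalone sketch of the standard Llama formulation, not the module’s internal code:

```python
import jax.numpy as jnp

def rotate_half_sketch(x):
    # [x1, x2] -> [-x2, x1] over the last dimension
    half = x.shape[-1] // 2
    return jnp.concatenate([-x[..., half:], x[..., :half]], axis=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin, unsqueeze_dim=1):
    # Broadcast cos/sin over the heads axis, then rotate queries and keys
    cos = jnp.expand_dims(cos, unsqueeze_dim)
    sin = jnp.expand_dims(sin, unsqueeze_dim)
    return q * cos + rotate_half_sketch(q) * sin, k * cos + rotate_half_sketch(k) * sin
```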

class jaxgarden.models.LlamaConfig(seed: int = 42, log_level: str = 'info', extra: dict[str, ~typing.Any] = <factory>, dim: int = 2048, n_layers: int = 16, n_heads: int = 32, n_kv_heads: int = 8, head_dim: int = 64, intermediate_size: int = 14336, vocab_size: int = 128256, multiple_of: int = 256, norm_eps: float = 1e-05, rope_theta: float = 500000.0)[source]

Bases: BaseConfig

Configuration for LLama model.

This configuration class extends BaseConfig and contains all the parameters required to initialize a LLama model. It includes settings for model architecture, attention mechanisms, and other hyperparameters.

Attributes:
  • dim (int) – Size of hidden states

  • n_layers (int) – Number of transformer layers

  • n_heads (int) – Number of attention heads

  • n_kv_heads (int) – Number of key/value heads (for group query attention)

  • head_dim (int) – Dimension of each attention head

  • intermediate_size (int) – Size of MLP intermediate layer

  • vocab_size (int) – Size of vocabulary

  • multiple_of (int) – Ensure dimensions are multiples of this value

  • norm_eps (float) – Epsilon for layer normalization

  • rope_theta (float) – Base for rotary position embeddings

dim: int = 2048
head_dim: int = 64
intermediate_size: int = 14336
multiple_of: int = 256
n_heads: int = 32
n_kv_heads: int = 8
n_layers: int = 16
norm_eps: float = 1e-05
rope_theta: float = 500000.0
vocab_size: int = 128256
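
Instantiating the config with the defaults listed above, or with overrides:

```python
from jaxgarden.models import LlamaConfig

# Defaults match the values listed above; any field can be overridden
config = LlamaConfig(n_kv_heads=8)    # 32 query heads sharing 8 KV heads enables GQA
print(config.dim, config.rope_theta)  # 2048 500000.0
```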
class jaxgarden.models.LlamaForCausalLM(*args: Any, **kwargs: Any)[source]

Bases: BaseModel, GenerationMixin

LLama model for causal language modeling.

This implements the full LLama model for generating text. It consists of token embeddings, transformer layers, and language modeling head.

Attributes:
  • token_embed – Token embedding layer

  • layers – List of transformer blocks

  • lm_head – Linear layer for language modeling

  • norm – Final layer normalization

__call__(input_ids: Array, attention_mask: Array | None = None, deterministic: bool = True) Array[source]

Forward pass of the LLama model.

Parameters:
  • input_ids – Input token ids of shape [batch_size, seq_len]

  • attention_mask – Boolean attention mask of shape [batch_size, seq_len]

Returns:

Logits for next token prediction of shape [batch_size, seq_len, vocab_size]

__init__(config: ~jaxgarden.models.llama.LlamaConfig, *, dtype: ~numpy.dtype | None = None, param_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, precision: ~jax._src.lax.lax.Precision | str | None = None, rngs: ~flax.nnx.rnglib.Rngs)[source]

Initialize LlamaForCausalLM.

Parameters:
  • config – Model configuration

  • dtype – Data type for computation

  • param_dtype – Data type for parameters

  • precision – Precision for matrix multiplication

  • rngs – PRNG key collection

convert_weights_from_hf(state: State | dict[str, Array], weights: Iterator[tuple[Any, Any]]) None[source]

Convert weights from Hugging Face Hub to the model’s state.

This method should be implemented in downstream classes to support conversion from HuggingFace format.
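
A forward-pass sketch for LlamaForCausalLM with randomly initialized weights; the reduced config is illustrative, and pretrained weights would normally be pulled in with from_hf instead:

```python
import jax.numpy as jnp
import flax.nnx as nnx
from jaxgarden.models import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(dim=256, n_layers=2, n_heads=4, n_kv_heads=2,
                     head_dim=64, intermediate_size=512, vocab_size=1000)
model = LlamaForCausalLM(config, rngs=nnx.Rngs(0))

input_ids = jnp.ones((1, 8), dtype=jnp.int32)       # [batch_size, seq_len]
attention_mask = jnp.ones((1, 8), dtype=jnp.bool_)  # boolean mask per the docstring
logits = model(input_ids, attention_mask=attention_mask)
print(logits.shape)  # (1, 8, 1000)
```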

class jaxgarden.models.LlamaMLP(*args: Any, **kwargs: Any)[source]

Bases: Module

LLama’s MLP implementation with SwiGLU activation.

This implements the SwiGLU MLP used in LLama models: down_proj(silu(gate_proj(x)) * up_proj(x))

Attributes:
  • gate_proj – Linear projection for gate

  • up_proj – Linear projection for up-projection

  • down_proj – Linear projection for down-projection

__call__(x: Array) Array[source]

Apply SwiGLU MLP to input tensor.

Parameters:

x – Input tensor of shape [batch_size, seq_len, dim]

Returns:

Output tensor of shape [batch_size, seq_len, dim]

__init__(layer_idx: int, dim: int, intermediate_size: int, *, rngs: Rngs)[source]

Initialize MLP module.

Parameters:
  • layer_idx – Index of the layer

  • dim – Size of hidden states

  • intermediate_size – Size of intermediate layer

  • rngs – PRNG key collection
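
The SwiGLU computation noted above, down_proj(silu(gate_proj(x)) * up_proj(x)), written with plain weight matrices for clarity (a sketch, not the module’s code):

```python
import jax.numpy as jnp
from jax.nn import silu

def swiglu_sketch(x, w_gate, w_up, w_down):
    # down_proj(silu(gate_proj(x)) * up_proj(x)) with bias-free projections
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down
```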

class jaxgarden.models.LlamaRMSNorm(*args: Any, **kwargs: Any)[source]

Bases: Module

Root Mean Square Layer Normalization.

This implementation follows the RMSNorm paper (https://arxiv.org/abs/1910.07467). Instead of using the mean and variance like traditional LayerNorm, RMSNorm normalizes using only the root mean square of the inputs.

Attributes:
  • norm_weights – The learned scale parameters

  • norm_eps – Small constant for numerical stability

__call__(hidden_states: Array) Array[source]

Apply RMS normalization to input tensor.

Parameters:

hidden_states – Input tensor of shape […, dim]

Returns:

Normalized tensor with same shape as input

__init__(dim: int, *, norm_eps: float = 1e-05, rngs: Rngs)[source]

Initialize RMSNorm module.

Parameters:
  • dim – Dimension of the input tensor

  • norm_eps – Small constant for numerical stability

  • rngs – PRNG key collection
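
The normalization itself reduces to the following sketch of the RMSNorm formula from the paper referenced above (not the module’s internals):

```python
import jax.numpy as jnp

def rms_norm_sketch(x, weight, eps=1e-5):
    # Scale by the reciprocal root mean square of the last dimension,
    # then apply the learned weight
    mean_square = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(mean_square + eps) * weight
```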

class jaxgarden.models.LlamaRotaryEmbedding(*args: Any, **kwargs: Any)[source]

Bases: Module

Rotary Position Embedding (RoPE) implementation for LLama.

Based on: https://arxiv.org/abs/2104.09864

Attributes:
  • dim – Dimension of the embeddings

  • base – Base for the sinusoidal functions

__call__(position_ids: Array) tuple[Array, Array][source]

Generate rotary embeddings from position ids.

Parameters:

position_ids – Position indices of shape [batch_size, seq_len]

Returns:

Tuple of (cos, sin) tensors for rotary embeddings

__init__(dim: int, *, base: float = 10000.0, rngs: Rngs)[source]

Initialize RoPE module.

Parameters:
  • dim – Dimension of the embeddings (must be even)

  • base – Base for the sinusoidal functions

  • rngs – PRNG key collection
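
A sketch of how the cos/sin tables are typically derived from position ids and the base, assuming the common Llama convention of tiling the half-dimension angle table across the full embedding dimension:

```python
import jax.numpy as jnp

def rope_tables_sketch(position_ids, dim, base=10000.0):
    # Inverse frequencies for each pair of embedding dimensions
    inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2) / dim))
    # angles: [batch_size, seq_len, dim // 2]
    angles = position_ids[..., None].astype(jnp.float32) * inv_freq
    emb = jnp.concatenate([angles, angles], axis=-1)  # cover the full dim
    return jnp.cos(emb), jnp.sin(emb)
```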

class jaxgarden.models.LlamaTransformerBlock(*args: Any, **kwargs: Any)[source]

Bases: Module

LLama transformer block implementation.

This implements a single layer of the LLama transformer, consisting of attention, layer normalization, and MLP components.

Attributes:
  • input_layernorm – Layer normalization before attention

  • attention – Multi-headed attention

  • post_attention_layernorm – Layer normalization after attention

  • mlp – MLP block

__call__(x: Array, position_ids: Array, attention_mask: Array) Array[source]

Apply transformer block to input tensor.

Parameters:
  • x – Input tensor of shape [batch_size, seq_len, dim]

  • position_ids – Position indices of shape [batch_size, seq_len]

  • attention_mask – Additive attention mask of shape (batch_size, 1, q_len, kv_len) or another shape that can be broadcasted to this shape.

Returns:

Output tensor of shape [batch_size, seq_len, dim]

__init__(layer_idx: int, dim: int, n_heads: int, n_kv_heads: int, head_dim: int, rope_theta: float, intermediate_size: int, *, norm_eps: float = 1e-05, rngs: Rngs)[source]

Initialize transformer block.

Parameters:
  • layer_idx – Index of the layer

  • dim – Size of hidden states

  • n_heads – Number of attention heads

  • n_kv_heads – Number of key/value heads

  • head_dim – Dimension of each attention head

  • rope_theta – Base for rotary position embeddings

  • intermediate_size – Size of MLP intermediate layer

  • norm_eps – Epsilon for layer normalization

  • rngs – PRNG key collection
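
The data flow through the block reduces to a pre-norm residual pattern; a sketch in terms of the attributes listed above, assuming this standard ordering (not copied from the implementation):

```python
def llama_block_sketch(block, x, position_ids, attention_mask):
    # Pre-norm attention with a residual connection
    h = x + block.attention(block.input_layernorm(x), position_ids, attention_mask)
    # Pre-norm MLP with a second residual connection
    return h + block.mlp(block.post_attention_layernorm(h))
```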

class jaxgarden.models.ModernBERTEncoder(*args: Any, **kwargs: Any)[source]

Bases: Module

ModernBERT encoder consisting of multiple transformer layers.

__call__(hidden_states: Array, attention_mask: Array | None = None, sliding_window_mask: Array | None = None, position_ids: Array | None = None, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False) tuple[Array] | tuple[Array, list[Array]] | tuple[Array, list[Array], list[Array]][source]

Apply transformer encoder.

Parameters:
  • hidden_states – Input tensor

  • attention_mask – Optional attention mask

  • sliding_window_mask – Optional sliding window mask

  • position_ids – Optional position ids

  • deterministic – Whether to apply dropout

  • output_attentions – Whether to return attention weights

  • output_hidden_states – Whether to return all hidden states

Returns:

A tuple containing:

  • Output tensor

  • All hidden states (optional, if output_hidden_states=True)

  • All attention weights (optional, if output_attentions=True)

__init__(rngs: Rngs, hidden_size: int, num_attention_heads: int, intermediate_size: int, num_hidden_layers: int, attention_dropout: float = 0.0, hidden_dropout: float = 0.0, attention_bias: bool = True, norm_eps: float = 1e-12, norm_bias: bool = True, global_rope_theta: float = 10000.0, max_position_embeddings: int = 4096, local_attention: tuple[int, int] = (-1, -1), local_rope_theta: float | None = None, global_attn_every_n_layers: int = 4)[source]

Initialize encoder.

Parameters:
  • rngs – PRNG key collection

  • hidden_size – Size of hidden states

  • num_attention_heads – Number of attention heads

  • intermediate_size – Size of MLP intermediate layer

  • num_hidden_layers – Number of transformer layers

  • attention_dropout – Dropout probability for attention

  • hidden_dropout – Dropout probability for hidden states

  • attention_bias – Whether to use bias in attention

  • norm_eps – Epsilon for layer normalization

  • norm_bias – Whether to use bias in layer normalization

  • global_rope_theta – Base for global RoPE

  • max_position_embeddings – Maximum sequence length

  • local_attention – Tuple of (left, right) window sizes

  • local_rope_theta – Base for local RoPE (optional)

  • global_attn_every_n_layers – Apply global attention every N layers

class jaxgarden.models.ModernBERTForMaskedLM(*args: Any, **kwargs: Any)[source]

Bases: BaseModel

ModernBERT model with masked language modeling head.

This implements the ModernBERT architecture as described in the paper “Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference” by Answer.AI.

The implementation includes modern improvements such as:

  • Rotary Position Embeddings (RoPE)

  • Mixed global/local attention mechanism

  • Pre-LayerNorm architecture

  • Efficient parameter sharing

__call__(input_ids: Array, attention_mask: Array | None = None, sliding_window_mask: Array | None = None, position_ids: Array | None = None, inputs_embeds: Array | None = None, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False) dict[str, Array][source]

Apply ModernBERT model.

Parameters:
  • input_ids – Input token ids of shape [batch_size, seq_len]

  • attention_mask – Optional attention mask

  • sliding_window_mask – Optional sliding window mask

  • position_ids – Optional position ids

  • inputs_embeds – Optional pre-computed embeddings

  • deterministic – Whether to apply dropout

  • output_attentions – Whether to return attention weights

  • output_hidden_states – Whether to return all hidden states

Returns:

A dictionary containing:

  • logits: Output logits of shape [batch_size, seq_len, vocab_size]

  • hidden_states: All hidden states (optional)

  • attentions: All attention weights (optional)

__init__(config: ~jaxgarden.models.modernbert.ModernBERTConfig, *, dtype: ~numpy.dtype | None = None, param_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, precision: ~jax._src.lax.lax.Precision | str | None = None, rngs: ~flax.nnx.rnglib.Rngs)[source]

Initialize ModernBERT model.

Parameters:
  • config – Configuration for ModernBERT model

  • dtype – Data type in which computation is performed

  • param_dtype – Data type in which params are stored

  • precision – Numerical precision

  • rngs – Random number generators for param initialization
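
A usage sketch for ModernBERTForMaskedLM; it assumes ModernBERTConfig (referenced in the __init__ signature above, under jaxgarden.models.modernbert) can be constructed with its defaults:

```python
import jax.numpy as jnp
import flax.nnx as nnx
from jaxgarden.models import ModernBERTForMaskedLM
from jaxgarden.models.modernbert import ModernBERTConfig  # path taken from the __init__ signature above

config = ModernBERTConfig()  # assumed to be constructible with defaults
model = ModernBERTForMaskedLM(config, rngs=nnx.Rngs(0))

input_ids = jnp.ones((1, 16), dtype=jnp.int32)  # [batch_size, seq_len]
outputs = model(input_ids, output_hidden_states=True)
print(outputs["logits"].shape)  # [batch_size, seq_len, vocab_size]
```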

class jaxgarden.models.ModernBertAttention(*args: Any, **kwargs: Any)[source]

Bases: Module

Multi-headed self attention implementation.

This implements the standard attention mechanism with RoPE (Rotary Position Embeddings). Supports both global attention and sliding window attention.

__call__(hidden_states: Array, attention_mask: Array | None = None, sliding_window_mask: Array | None = None, position_ids: Array | None = None, deterministic: bool = True, output_attentions: bool = False) tuple[Array] | tuple[Array, Array][source]

Apply attention module.

Parameters:
  • hidden_states – Input tensor of shape [batch_size, seq_len, hidden_size]

  • attention_mask – Optional attention mask

  • sliding_window_mask – Optional sliding window mask for local attention

  • position_ids – Optional position ids for RoPE

  • deterministic – Whether to apply dropout

  • output_attentions – Whether to return attention probabilities

Returns:

A tuple containing:

  • Output tensor of shape [batch_size, seq_len, hidden_size]

  • Attention probabilities (optional) of shape [b_size, n_heads, seq_len, seq_len]

__init__(rngs: Rngs, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0, attention_bias: bool = True, global_rope_theta: float = 10000.0, max_position_embeddings: int = 4096, local_attention: tuple[int, int] = (-1, -1), local_rope_theta: float | None = None, layer_id: int | None = None, global_attn_every_n_layers: int = 4)[source]

Initialize attention module.

Parameters:
  • rngs – PRNG key collection

  • hidden_size – Size of hidden states

  • num_attention_heads – Number of attention heads

  • attention_dropout – Dropout probability for attention weights

  • attention_bias – Whether to use bias in linear layers

  • global_rope_theta – Base for global RoPE

  • max_position_embeddings – Maximum sequence length

  • local_attention – Tuple of (left, right) window sizes for local attention

  • local_rope_theta – Base for local RoPE (optional)

  • layer_id – Layer index for determining attention type

  • global_attn_every_n_layers – Apply global attention every N layers

class jaxgarden.models.ModernBertEmbeddings(*args: Any, **kwargs: Any)[source]

Bases: Module

Token embeddings with normalization and dropout.

Similar to BERT embeddings but without position embeddings since we use RoPE.

__call__(input_ids: Array, deterministic: bool = True, inputs_embeds: Array | None = None) Array[source]

Apply embeddings module.

Parameters:
  • input_ids – Integer tokens of shape [batch_size, seq_len]

  • deterministic – Whether to apply dropout

  • inputs_embeds – Optional pre-computed embeddings

Returns:

Embedded tokens with shape [batch_size, seq_len, hidden_size]

__init__(rngs: Rngs, vocab_size: int, hidden_size: int, pad_token_id: int = 0, norm_eps: float = 1e-12, norm_bias: bool = True, embedding_dropout: float = 0.0)[source]

Initialize embeddings module.

Parameters:
  • rngs – PRNG key collection

  • vocab_size – Size of the vocabulary

  • hidden_size – Size of the embeddings

  • pad_token_id – Token ID to use for padding

  • norm_eps – Epsilon for layer normalization

  • norm_bias – Whether to use bias in layer normalization

  • embedding_dropout – Dropout probability for embeddings

class jaxgarden.models.ModernBertLayer(*args: Any, **kwargs: Any)[source]

Bases: Module

ModernBERT transformer layer with pre-LayerNorm architecture.

This implements a transformer layer with:

  1. Pre-LayerNorm for attention and MLP

  2. Residual connections

  3. Optional identity for the first layer’s attention norm

__call__(hidden_states: Array, attention_mask: Array | None = None, sliding_window_mask: Array | None = None, position_ids: Array | None = None, deterministic: bool = True, output_attentions: bool = False) tuple[Array] | tuple[Array, Array][source]

Apply transformer layer.

Parameters:
  • hidden_states – Input tensor of shape [batch_size, seq_len, hidden_size]

  • attention_mask – Optional attention mask

  • sliding_window_mask – Optional sliding window mask

  • position_ids – Optional position ids for RoPE

  • deterministic – Whether to apply dropout

  • output_attentions – Whether to return attention probabilities

Returns:

A tuple containing:

  • Output tensor of shape [batch_size, seq_len, hidden_size]

  • Attention probabilities (optional) of shape [b_size, n_heads, seq_len, seq_len]

__init__(rngs: Rngs, hidden_size: int, num_attention_heads: int, intermediate_size: int, layer_id: int | None = None, attention_dropout: float = 0.0, hidden_dropout: float = 0.0, attention_bias: bool = True, norm_eps: float = 1e-12, norm_bias: bool = True, global_rope_theta: float = 10000.0, max_position_embeddings: int = 4096, local_attention: tuple[int, int] = (-1, -1), local_rope_theta: float | None = None, global_attn_every_n_layers: int = 4)[source]

Initialize transformer layer.

Parameters:
  • rngs – PRNG key collection

  • hidden_size – Size of hidden states

  • num_attention_heads – Number of attention heads

  • intermediate_size – Size of MLP intermediate layer

  • layer_id – Layer index (first layer uses identity for attn norm)

  • attention_dropout – Dropout probability for attention

  • hidden_dropout – Dropout probability for hidden states

  • attention_bias – Whether to use bias in attention

  • norm_eps – Epsilon for layer normalization

  • norm_bias – Whether to use bias in layer normalization

  • global_rope_theta – Base for global RoPE

  • max_position_embeddings – Maximum sequence length

  • local_attention – Tuple of (left, right) window sizes

  • local_rope_theta – Base for local RoPE (optional)

  • global_attn_every_n_layers – Apply global attention every N layers

class jaxgarden.models.ModernBertMLP(*args: Any, **kwargs: Any)[source]

Bases: Module

MLP with gated linear units.

Replaces the traditional intermediate + output layers with a single gated MLP.

__call__(hidden_states: Array, deterministic: bool = True) Array[source]

Apply MLP module.

Parameters:
  • hidden_states – Input tensor of shape [batch_size, seq_len, hidden_size]

  • deterministic – Whether to apply dropout

Returns:

Output tensor of shape [batch_size, seq_len, hidden_size]

__init__(rngs: Rngs, hidden_size: int, intermediate_size: int, mlp_bias: bool = True, mlp_dropout: float = 0.0)[source]

Initialize MLP module.

Parameters:
  • rngs – PRNG key collection

  • hidden_size – Size of input and output

  • intermediate_size – Size of intermediate layer

  • mlp_bias – Whether to use bias in linear layers

  • mlp_dropout – Dropout probability
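
The gated design described above can be sketched as a single input projection whose output is split into an activation input and a gate. This illustrates the gated-linear-unit idea only, not the module’s exact internals; the GELU activation here is an assumption:

```python
import jax.numpy as jnp
from jax.nn import gelu

def gated_mlp_sketch(x, w_in, w_out):
    # One projection produces both halves; activate one half and gate it with the other
    h, gate = jnp.split(x @ w_in, 2, axis=-1)
    return (gelu(h) * gate) @ w_out
```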