"""ModernBERT implementation in JAX using Flax NNX.
This module implements the ModernBERT architecture as described in the paper
"Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder
for Fast, Memory Efficient, and Long Context Finetuning and Inference" by Answer.AI.
The implementation includes modern improvements such as rotary position embeddings (RoPE), gated (GLU) MLPs, and alternating global/local (sliding-window) attention.
See: https://arxiv.org/abs/2412.13663
"""
from dataclasses import dataclass
import flax.nnx as nnx
import jax
import jax.numpy as jnp
from jaxgarden.models.base import BaseConfig, BaseModel
class Identity(nnx.Module):
"""Identity layer that simply returns its input."""
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
return x
def create_sinusoidal_positions(max_length: int, dim: int, base: float = 10000.0) -> jnp.ndarray:
"""Create sinusoidal position embeddings.
Args:
max_length: Maximum sequence length
dim: Dimension of the embeddings (must be even)
base: Base for the sinusoidal functions
Returns:
Array of shape (max_length, dim//2, 2) containing (cos, sin) values
"""
# Create position indices [0, 1, ..., max_length-1]
positions = jnp.arange(max_length, dtype=jnp.float32)
# Create dimension indices [0, 2, ..., dim-2] for half the dimensions
dim_indices = jnp.arange(0, dim, 2, dtype=jnp.float32)
# Calculate theta: base^(-2i/dim)
theta = 1.0 / (base ** (dim_indices / dim))
# Calculate angles: pos * theta using einsum
# Shape: (max_length, dim//2)
angles = jnp.einsum("i,j->ij", positions, theta)
# Return stacked (cos, sin) tuple of shape (max_length, dim//2, 2)
return jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=-1)
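# Example (sketch): building a RoPE cache for a 64-dim attention head. The
# sizes below are illustrative only.
#
#     cache = create_sinusoidal_positions(max_length=128, dim=64)
#     cache.shape  # (128, 32, 2) -> (position, dim // 2, (cos, sin))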
def apply_rotary_pos_emb(
x: jnp.ndarray, cache: jnp.ndarray, positions: jnp.ndarray | None = None
) -> jnp.ndarray:
"""Apply rotary position embeddings to input tensor.
Args:
x: Input tensor of shape (batch_size, seq_len, num_heads, head_dim)
cache: Cache tensor of shape (max_seq_len, head_dim//2, 2) containing (cos, sin)
positions: Optional position indices of shape (batch_size, seq_len)
Returns:
Tensor with rotary embeddings applied
"""
# Get sequence length from input
seq_len = x.shape[1]
# Get cache values based on positions
if positions is not None:
rope_cache = cache[positions] # [batch, seq, dim//2, 2]
else:
# Take seq_len positions, wrapping around if seq_len exceeds the cache length
positions = jnp.arange(seq_len) % cache.shape[0]
rope_cache = cache[positions] # [seq, dim//2, 2]
# Add batch dimension for broadcasting
rope_cache = jnp.expand_dims(rope_cache, 0) # [1, seq, dim//2, 2]
# Reshape input for rotation
x_reshaped = x.reshape(*x.shape[:-1], -1, 2) # [batch, seq, heads, dim//2, 2]
# Add head dimension to rope_cache for broadcasting
rope_cache = jnp.expand_dims(rope_cache, 2) # [batch/1, seq, 1, dim//2, 2]
# Apply rotation using complex multiplication
x_out = jnp.stack(
[
x_reshaped[..., 0] * rope_cache[..., 0] - x_reshaped[..., 1] * rope_cache[..., 1],
x_reshaped[..., 1] * rope_cache[..., 0] + x_reshaped[..., 0] * rope_cache[..., 1],
],
axis=-1,
)
# Reshape back to original shape
return x_out.reshape(x.shape)
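# Example (sketch): rotating a dummy query tensor with a cache built as above.
# Batch, sequence, and head sizes are arbitrary illustrations.
#
#     q = jnp.ones((2, 16, 12, 64))  # [batch, seq, heads, head_dim]
#     cache = create_sinusoidal_positions(max_length=128, dim=64)
#     q_rot = apply_rotary_pos_emb(q, cache)
#     q_rot.shape  # (2, 16, 12, 64), same as the input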
class RoPEPositionalEmbedding(nnx.Module):
"""Rotary Position Embedding (RoPE) implementation.
Based on https://arxiv.org/abs/2104.09864
"""
def __init__(
self, rngs: nnx.Rngs, dim: int, max_position_embeddings: int = 4096, base: float = 10000.0
):
"""Initialize RoPE module.
Args:
rngs: PRNG key collection
dim: Dimension of the embeddings (must be even)
max_position_embeddings: Maximum sequence length to cache
base: Base for the sinusoidal functions
"""
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Create and store the cos/sin cache
self.cache = create_sinusoidal_positions(self.max_position_embeddings, self.dim, self.base)
def __call__(
self,
x: jnp.ndarray,
positions: jnp.ndarray | None = None,
) -> jnp.ndarray:
"""Apply rotary position embeddings to input tensor.
Args:
x: Input tensor of shape [batch_size, seq_len, num_heads, head_dim]
positions: Optional position ids of shape [batch_size, seq_len]
Returns:
Tensor with position embeddings applied
"""
return apply_rotary_pos_emb(x, self.cache, positions)
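# Example (sketch): the module wrapper around the helpers above; sizes are
# illustrative only.
#
#     rope = RoPEPositionalEmbedding(rngs=nnx.Rngs(0), dim=64)
#     x = jnp.zeros((2, 16, 12, 64))  # [batch, seq, heads, head_dim]
#     rope(x).shape  # (2, 16, 12, 64)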
def create_sliding_window_mask(
seq_len: int, window_size: tuple[int, int], dtype: jnp.dtype = jnp.float32
) -> jnp.ndarray:
"""Create a sliding window attention mask.
Args:
seq_len: Length of the sequence
window_size: Tuple of (left, right) window sizes
dtype: Data type of the mask
Returns:
Mask tensor of shape [1, 1, seq_len, seq_len] with -inf outside window
"""
# Create position indices
positions = jnp.arange(seq_len)
# Create relative position matrix [seq_len, seq_len]
relative_positions = positions[:, None] - positions[None, :]
# Create window mask
left_window, right_window = window_size
mask = (relative_positions >= -left_window) & (relative_positions <= right_window)
# Convert to float and replace False with -inf
mask = mask.astype(dtype)
mask = jnp.where(mask, 0.0, -jnp.inf)
# Add batch and head dimensions [1, 1, seq_len, seq_len]
return mask[None, None, :, :]
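# Example (sketch): a 5-token sequence with a (1, 1) window. Entries more than
# one position apart get -inf and vanish after softmax.
#
#     mask = create_sliding_window_mask(seq_len=5, window_size=(1, 1))
#     mask.shape     # (1, 1, 5, 5)
#     mask[0, 0, 2]  # [-inf, 0., 0., 0., -inf]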
class ModernBertAttention(nnx.Module):
"""Multi-headed self attention implementation.
This implements the standard attention mechanism with RoPE (Rotary Position Embeddings).
Supports both global attention and sliding window attention.
"""
def __init__(
self,
rngs: nnx.Rngs,
hidden_size: int,
num_attention_heads: int,
attention_dropout: float = 0.0,
attention_bias: bool = True,
global_rope_theta: float = 10000.0,
max_position_embeddings: int = 4096,
local_attention: tuple[int, int] = (-1, -1), # (-1, -1) means global attention
local_rope_theta: float | None = None,
layer_id: int | None = None,
global_attn_every_n_layers: int = 4,
):
"""Initialize attention module.
Args:
rngs: PRNG key collection
hidden_size: Size of hidden states
num_attention_heads: Number of attention heads
attention_dropout: Dropout probability for attention weights
attention_bias: Whether to use bias in linear layers
global_rope_theta: Base for global RoPE
max_position_embeddings: Maximum sequence length
local_attention: Tuple of (left, right) window sizes for local attention
local_rope_theta: Base for local RoPE (optional)
layer_id: Layer index for determining attention type
global_attn_every_n_layers: Apply global attention every N layers
"""
if hidden_size % num_attention_heads != 0:
raise ValueError(
f"The hidden size ({hidden_size}) is not a multiple of the number "
f"of attention heads ({num_attention_heads})"
)
# Store configuration
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.head_dim = hidden_size // num_attention_heads
self.all_head_size = self.head_dim * num_attention_heads
self.attention_dropout = attention_dropout
# Compute local attention window if needed
if layer_id is not None and layer_id % global_attn_every_n_layers != 0:
local_attention = (local_attention[0] // 2, local_attention[1] // 2)
# Select RoPE parameters based on attention type
rope_theta = global_rope_theta
max_pos = max_position_embeddings
if local_attention != (-1, -1):
if local_rope_theta is not None:
rope_theta = local_rope_theta
max_pos = local_attention[0] + local_attention[1]
# Initialize layers
self.Wqkv = nnx.Linear(
rngs=rngs,
in_features=hidden_size,
out_features=3 * self.all_head_size,
use_bias=attention_bias,
)
self.rotary_emb = RoPEPositionalEmbedding(
rngs=rngs,
dim=self.head_dim,
max_position_embeddings=max_pos,
base=rope_theta,
)
self.Wo = nnx.Linear(
rngs=rngs,
in_features=hidden_size,
out_features=hidden_size,
use_bias=attention_bias,
)
# Store configuration
self.local_attention = local_attention
self._rngs = rngs # Store rngs for dropout
def __call__(
self,
hidden_states: jax.Array,
attention_mask: jax.Array | None = None,
sliding_window_mask: jax.Array | None = None,
position_ids: jax.Array | None = None,
deterministic: bool = True,
output_attentions: bool = False,
) -> tuple[jax.Array] | tuple[jax.Array, jax.Array]:
"""Apply attention module.
Args:
hidden_states: Input tensor of shape [batch_size, seq_len, hidden_size]
attention_mask: Optional attention mask
sliding_window_mask: Optional sliding window mask for local attention
position_ids: Optional position ids for RoPE
deterministic: Whether to apply dropout
output_attentions: Whether to return attention probabilities
Returns:
Tuple of:
- Output tensor of shape [batch_size, seq_len, hidden_size]
- Attention probabilities (optional) of shape [b_size, n_heads, seq_len, seq_len]
"""
# Project to Q, K, V
qkv = self.Wqkv(hidden_states) # [batch_size, seq_len, 3 * hidden_size]
# Reshape to [batch_size, seq_len, 3, num_heads, head_dim]
batch_size = hidden_states.shape[0]
qkv = jnp.reshape(qkv, (batch_size, -1, 3, self.num_attention_heads, self.head_dim))
# Split into query, key, value and apply RoPE
# Move the head axis forward so that splitting on axis 3 yields per-head Q, K, V
qkv = jnp.transpose(qkv, (0, 3, 1, 2, 4)) # [batch_size, num_heads, seq_len, 3, head_dim]
query, key, value = jnp.split(
qkv, 3, axis=3
) # Each: [batch_size, num_heads, seq_len, 1, head_dim]
query = jnp.squeeze(query, axis=3) # [batch_size, num_heads, seq_len, head_dim]
key = jnp.squeeze(key, axis=3)
value = jnp.squeeze(value, axis=3)
# Apply rotary embeddings
# First transpose to [batch_size, seq_len, num_heads, head_dim] for RoPE
query = jnp.transpose(query, (0, 2, 1, 3))
key = jnp.transpose(key, (0, 2, 1, 3))
query = apply_rotary_pos_emb(query, self.rotary_emb.cache, position_ids)
key = apply_rotary_pos_emb(key, self.rotary_emb.cache, position_ids)
# Transpose back to [batch_size, num_heads, seq_len, head_dim]
query = jnp.transpose(query, (0, 2, 1, 3))
key = jnp.transpose(key, (0, 2, 1, 3))
# Compute attention scores
scale = jnp.sqrt(self.head_dim)
attention_scores = jnp.matmul(query, jnp.swapaxes(key, -2, -1)) / scale
# Apply attention masks
if self.local_attention != (-1, -1):
# Create sliding window mask if not provided
if sliding_window_mask is None:
sliding_window_mask = create_sliding_window_mask(
query.shape[2], self.local_attention, query.dtype
)
# Ensure the mask is properly broadcasted to [batch_size, num_heads, seq_len, seq_len]
if len(sliding_window_mask.shape) == 4 and sliding_window_mask.shape[:2] == (1, 1):
sliding_window_mask = jnp.broadcast_to(
sliding_window_mask,
(batch_size, self.num_attention_heads, query.shape[2], query.shape[2]),
)
# Apply sliding window mask first
attention_scores = attention_scores + sliding_window_mask
# Apply additional attention mask if provided
if attention_mask is not None and attention_mask is not sliding_window_mask:
# Ensure proper broadcasting for attention mask
if len(attention_mask.shape) == 4 and attention_mask.shape[:2] == (1, 1):
attention_mask = jnp.broadcast_to(
attention_mask,
(batch_size, self.num_attention_heads, query.shape[2], query.shape[2]),
)
attention_scores = attention_scores + attention_mask
# Apply softmax and dropout
attention_probs = jax.nn.softmax(attention_scores, axis=-1)
# Explicitly zero out probabilities where sliding window mask was -inf
if self.local_attention != (-1, -1):
# Create a binary mask from the sliding window mask
binary_mask = jnp.where(sliding_window_mask == -jnp.inf, 0.0, 1.0)
attention_probs = attention_probs * binary_mask
if not deterministic and self.attention_dropout > 0:
attention_probs = nnx.Dropout(
self.attention_dropout,
rngs=self._rngs,
deterministic=deterministic,
)(attention_probs)
# Compute attention output
attention_output = jnp.matmul(attention_probs, value)
# Reshape and apply output projection
attention_output = jnp.transpose(attention_output, (0, 2, 1, 3))
attention_output = jnp.reshape(attention_output, (batch_size, -1, self.hidden_size))
attention_output = self.Wo(attention_output)
# Apply output dropout
if not deterministic and self.attention_dropout > 0:
attention_output = nnx.Dropout(
self.attention_dropout,
rngs=self._rngs,
deterministic=deterministic,
)(attention_output)
if output_attentions:
return (attention_output, attention_probs)
return (attention_output,)
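# Example (sketch): a single global-attention block on a dummy batch. The
# hyper-parameters are illustrative, not those of any released checkpoint.
#
#     attn = ModernBertAttention(
#         rngs=nnx.Rngs(0), hidden_size=768, num_attention_heads=12
#     )
#     (out,) = attn(jnp.zeros((2, 16, 768)))
#     out.shape  # (2, 16, 768)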
class ModernBertLayer(nnx.Module):
"""ModernBERT transformer layer with pre-LayerNorm architecture.
This implements a transformer layer with:
1. Pre-LayerNorm for attention and MLP
2. Residual connections
3. Optional identity for first layer's attention norm
"""
def __init__(
self,
rngs: nnx.Rngs,
hidden_size: int,
num_attention_heads: int,
intermediate_size: int,
layer_id: int | None = None,
attention_dropout: float = 0.0,
hidden_dropout: float = 0.0,
attention_bias: bool = True,
norm_eps: float = 1e-12,
norm_bias: bool = True,
global_rope_theta: float = 10000.0,
max_position_embeddings: int = 4096,
local_attention: tuple[int, int] = (-1, -1),
local_rope_theta: float | None = None,
global_attn_every_n_layers: int = 4,
):
"""Initialize transformer layer.
Args:
rngs: PRNG key collection
hidden_size: Size of hidden states
num_attention_heads: Number of attention heads
intermediate_size: Size of MLP intermediate layer
layer_id: Layer index (first layer uses identity for attn norm)
attention_dropout: Dropout probability for attention
hidden_dropout: Dropout probability for hidden states
attention_bias: Whether to use bias in attention
norm_eps: Epsilon for layer normalization
norm_bias: Whether to use bias in layer normalization
global_rope_theta: Base for global RoPE
max_position_embeddings: Maximum sequence length
local_attention: Tuple of (left, right) window sizes
local_rope_theta: Base for local RoPE (optional)
global_attn_every_n_layers: Apply global attention every N layers
"""
super().__init__()
# Initialize attention normalization
self.attn_norm = (
Identity()
if layer_id == 0
else nnx.LayerNorm(
rngs=rngs,
num_features=hidden_size,
epsilon=norm_eps,
use_bias=norm_bias,
reduction_axes=(-1,),
feature_axes=(-1,),
)
) # type: ignore
# Initialize attention
self.attn = ModernBertAttention(
rngs=rngs,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=attention_dropout,
attention_bias=attention_bias,
global_rope_theta=global_rope_theta,
max_position_embeddings=max_position_embeddings,
local_attention=local_attention,
local_rope_theta=local_rope_theta,
layer_id=layer_id,
global_attn_every_n_layers=global_attn_every_n_layers,
)
# Initialize MLP normalization
self.mlp_norm = nnx.LayerNorm(
rngs=rngs,
num_features=hidden_size,
epsilon=norm_eps,
use_bias=norm_bias,
reduction_axes=(-1,),
feature_axes=(-1,),
)
# Initialize MLP
self.mlp = ModernBertMLP(
rngs=rngs,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
mlp_dropout=hidden_dropout,
)
def __call__(
self,
hidden_states: jax.Array,
attention_mask: jax.Array | None = None,
sliding_window_mask: jax.Array | None = None,
position_ids: jax.Array | None = None,
deterministic: bool = True,
output_attentions: bool = False,
) -> tuple[jax.Array] | tuple[jax.Array, jax.Array]:
"""Apply transformer layer.
Args:
hidden_states: Input tensor of shape [batch_size, seq_len, hidden_size]
attention_mask: Optional attention mask
sliding_window_mask: Optional sliding window mask
position_ids: Optional position ids for RoPE
deterministic: Whether to apply dropout
output_attentions: Whether to return attention probabilities
Returns:
Tuple of:
- Output tensor of shape [batch_size, seq_len, hidden_size]
- Attention probabilities (optional) of shape [b_size, n_heads, seq_len, seq_len]
"""
# Apply attention with pre-norm and residual
attn_outputs = self.attn(
self.attn_norm(hidden_states),
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = attn_outputs[0]
attention_weights = attn_outputs[1] if output_attentions else None # type: ignore
# Residual connection
hidden_states = hidden_states + attention_output
# Apply MLP with pre-norm and residual
mlp_output = self.mlp(self.mlp_norm(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + mlp_output
if output_attentions:
return (hidden_states, attention_weights) # type: ignore
return (hidden_states,)
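# Example (sketch): one pre-norm transformer block. layer_id=0 swaps the
# attention norm for Identity, as noted in the class docstring; sizes are
# illustrative.
#
#     layer = ModernBertLayer(
#         rngs=nnx.Rngs(0),
#         hidden_size=768,
#         num_attention_heads=12,
#         intermediate_size=1152,
#         layer_id=0,
#     )
#     (out,) = layer(jnp.zeros((2, 16, 768)))
#     out.shape  # (2, 16, 768)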
class ModernBERTEncoder(nnx.Module):
"""ModernBERT encoder consisting of multiple transformer layers."""
def __init__(
self,
rngs: nnx.Rngs,
hidden_size: int,
num_attention_heads: int,
intermediate_size: int,
num_hidden_layers: int,
attention_dropout: float = 0.0,
hidden_dropout: float = 0.0,
attention_bias: bool = True,
norm_eps: float = 1e-12,
norm_bias: bool = True,
global_rope_theta: float = 10000.0,
max_position_embeddings: int = 4096,
local_attention: tuple[int, int] = (-1, -1),
local_rope_theta: float | None = None,
global_attn_every_n_layers: int = 4,
):
"""Initialize encoder.
Args:
rngs: PRNG key collection
hidden_size: Size of hidden states
num_attention_heads: Number of attention heads
intermediate_size: Size of MLP intermediate layer
num_hidden_layers: Number of transformer layers
attention_dropout: Dropout probability for attention
hidden_dropout: Dropout probability for hidden states
attention_bias: Whether to use bias in attention
norm_eps: Epsilon for layer normalization
norm_bias: Whether to use bias in layer normalization
global_rope_theta: Base for global RoPE
max_position_embeddings: Maximum sequence length
local_attention: Tuple of (left, right) window sizes
local_rope_theta: Base for local RoPE (optional)
global_attn_every_n_layers: Apply global attention every N layers
"""
super().__init__()
# Initialize transformer layers
self.layers = [
ModernBertLayer(
rngs=rngs,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
layer_id=layer_id,
attention_dropout=attention_dropout,
hidden_dropout=hidden_dropout,
attention_bias=attention_bias,
norm_eps=norm_eps,
norm_bias=norm_bias,
global_rope_theta=global_rope_theta,
max_position_embeddings=max_position_embeddings,
local_attention=local_attention,
local_rope_theta=local_rope_theta,
global_attn_every_n_layers=global_attn_every_n_layers,
)
for layer_id in range(num_hidden_layers)
]
# Initialize final layer normalization
self.final_norm = nnx.LayerNorm(
rngs=rngs,
num_features=hidden_size,
epsilon=norm_eps,
use_bias=norm_bias,
reduction_axes=(-1,),
feature_axes=(-1,),
)
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: jnp.ndarray | None = None,
sliding_window_mask: jnp.ndarray | None = None,
position_ids: jnp.ndarray | None = None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
) -> (
tuple[jnp.ndarray]
| tuple[jnp.ndarray, list[jnp.ndarray]]
| tuple[jnp.ndarray, list[jnp.ndarray], list[jnp.ndarray]]
):
"""Apply transformer encoder.
Args:
hidden_states: Input tensor
attention_mask: Optional attention mask
sliding_window_mask: Optional sliding window mask
position_ids: Optional position ids
deterministic: Whether to apply dropout
output_attentions: Whether to return attention weights
output_hidden_states: Whether to return all hidden states
Returns:
Tuple of:
- Output tensor
- All hidden states (optional, if output_hidden_states=True)
- All attention weights (optional, if output_attentions=True)
"""
all_hidden_states: tuple | None = () if output_hidden_states else None # type: ignore
all_self_attentions: tuple | None = () if output_attentions else None
# Process through each layer
for layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore #noqa: RUF005
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions and len(layer_outputs) > 1 and all_self_attentions is not None:
all_self_attentions = all_self_attentions + (layer_outputs[1],) # noqa: RUF005
# Add final hidden state if requested
if output_hidden_states and all_hidden_states is not None:
all_hidden_states = all_hidden_states + (hidden_states,) # noqa: RUF005
# Apply final layer normalization
hidden_states = self.final_norm(hidden_states)
if not output_hidden_states and not output_attentions:
return (hidden_states,)
elif output_hidden_states and not output_attentions:
return (hidden_states, all_hidden_states) # type: ignore
elif output_attentions and not output_hidden_states:
return (hidden_states, all_self_attentions) # type: ignore
else: # both output_hidden_states and output_attentions
return (hidden_states, all_hidden_states, all_self_attentions) # type: ignore
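# Example (sketch): a tiny two-layer encoder that also returns intermediate
# hidden states. Sizes are illustrative only.
#
#     encoder = ModernBERTEncoder(
#         rngs=nnx.Rngs(0),
#         hidden_size=768,
#         num_attention_heads=12,
#         intermediate_size=1152,
#         num_hidden_layers=2,
#     )
#     out, all_hidden = encoder(jnp.zeros((2, 16, 768)), output_hidden_states=True)
#     len(all_hidden)  # 3: the input plus one entry per layer (before the final norm)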
class ModernBERTMLMHead(nnx.Module):
"""ModernBERT masked language modeling head."""
def __init__(
self,
rngs: nnx.Rngs,
hidden_size: int,
vocab_size: int,
norm_eps: float = 1e-12,
norm_bias: bool = True,
):
"""Initialize MLM head.
Args:
rngs: PRNG key collection
hidden_size: Size of hidden states
vocab_size: Size of vocabulary
norm_eps: Epsilon for layer normalization
norm_bias: Whether to use bias in layer normalization
"""
super().__init__()
# Layer norm
self.norm = nnx.LayerNorm(
rngs=rngs,
num_features=hidden_size,
epsilon=norm_eps,
use_bias=norm_bias,
reduction_axes=(-1,),
feature_axes=(-1,),
)
# Dense projection
self.dense = nnx.Linear(
rngs=rngs,
in_features=hidden_size,
out_features=hidden_size,
use_bias=True,
)
# Output projection to the vocabulary (a separate Linear; not weight-tied here)
self.decoder = nnx.Linear(
rngs=rngs,
in_features=hidden_size,
out_features=vocab_size,
use_bias=True,
)
def __call__(self, hidden_states: jnp.ndarray) -> jnp.ndarray:
"""Apply MLM head.
Args:
hidden_states: Input tensor of shape [batch_size, seq_len, hidden_size]
Returns:
Logits of shape [batch_size, seq_len, vocab_size]
"""
hidden_states = self.norm(hidden_states)
hidden_states = self.dense(hidden_states)
hidden_states = jax.nn.gelu(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
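# Example (sketch): projecting encoder outputs to vocabulary logits; sizes are
# illustrative.
#
#     head = ModernBERTMLMHead(rngs=nnx.Rngs(0), hidden_size=768, vocab_size=50368)
#     logits = head(jnp.zeros((2, 16, 768)))
#     logits.shape  # (2, 16, 50368)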
@dataclass
class ModernBERTConfig(BaseConfig):
"""
Configuration for ModernBERT model.
This configuration class extends BaseConfig and contains all the parameters
required to initialize a ModernBERT model. It includes settings for model architecture,
attention mechanisms, dropout rates, and other hyperparameters.
Attributes:
vocab_size: Size of vocabulary
hidden_size: Size of hidden states
num_hidden_layers: Number of transformer layers
num_attention_heads: Number of attention heads
intermediate_size: Size of MLP intermediate layer
max_position_embeddings: Maximum sequence length
attention_dropout: Dropout probability for attention
hidden_dropout: Dropout probability for hidden states
attention_bias: Whether to use bias in attention
norm_eps: Epsilon for layer normalization
norm_bias: Whether to use bias in layer normalization
global_rope_theta: Base for global RoPE
local_attention: Tuple of (left, right) window sizes
local_rope_theta: Base for local RoPE (optional)
global_attn_every_n_layers: Apply global attention every N layers
pad_token_id: Token ID to use for padding
"""
# Default ModernBERT configuration
# https://huggingface.co/docs/transformers/main//model_doc/modernbert#transformers.ModernBertConfig
vocab_size: int = 50368
hidden_size: int = 768
intermediate_size: int = 1152
num_hidden_layers: int = 22
num_attention_heads: int = 12
max_position_embeddings: int = 8192
attention_dropout: float = 0.0
hidden_dropout: float = 0.0
attention_bias: bool = False
norm_eps: float = 1e-05
norm_bias: bool = False
global_rope_theta: float = 160000.0
local_attention: tuple[int, int] = (-1, -1)
local_rope_theta: float | None = None
global_attn_every_n_layers: int = 3
pad_token_id: int = 50283
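# Example (sketch): the field defaults above follow the ModernBERT-base
# configuration; individual fields can be overridden as keyword arguments
# (assuming BaseConfig itself requires no extra arguments).
#
#     config = ModernBERTConfig()
#     tiny_config = ModernBERTConfig(num_hidden_layers=2, max_position_embeddings=512)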
class ModernBertEmbeddings(nnx.Module):
"""Token embeddings with normalization and dropout.
Similar to BERT embeddings but without position embeddings since we use RoPE.
"""
def __init__(
self,
rngs: nnx.Rngs,
vocab_size: int,
hidden_size: int,
pad_token_id: int = 0,
norm_eps: float = 1e-12,
norm_bias: bool = True,
embedding_dropout: float = 0.0,
):
"""Initialize embeddings module.
Args:
rngs: PRNG key collection
vocab_size: Size of the vocabulary
hidden_size: Size of the embeddings
pad_token_id: Token ID to use for padding
norm_eps: Epsilon for layer normalization
norm_bias: Whether to use bias in layer normalization
embedding_dropout: Dropout probability for embeddings
"""
super().__init__()
# Create embeddings table with non-zero initialization
self.token_embeddings = nnx.Embed(
rngs=rngs,
num_embeddings=vocab_size,
features=hidden_size,
embedding_init=nnx.initializers.normal(stddev=1.0), # Increased stddev
)
# Layer norm with explicit feature axes
self.norm = nnx.LayerNorm(
rngs=rngs,
num_features=hidden_size,
epsilon=norm_eps,
use_bias=norm_bias,
use_scale=True,
scale_init=nnx.initializers.ones,
bias_init=nnx.initializers.zeros,
reduction_axes=(-1,), # Explicit tuple
feature_axes=(-1,), # Explicit tuple
)
self.dropout = embedding_dropout
self.deterministic = True # Will be overridden in __call__
self._rngs = rngs # Store rngs for dropout
def __call__(
self,
input_ids: jnp.ndarray,
deterministic: bool = True,
inputs_embeds: jnp.ndarray | None = None,
) -> jnp.ndarray:
"""Apply embeddings module.
Args:
input_ids: Integer tokens of shape [batch_size, seq_len]
deterministic: Whether to apply dropout
inputs_embeds: Optional pre-computed embeddings
Returns:
Embedded tokens with shape [batch_size, seq_len, hidden_size]
"""
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.token_embeddings(input_ids)
# Scale embeddings
hidden_states = hidden_states * (hidden_states.shape[-1] ** 0.5)
# Apply layer norm
hidden_states = self.norm(hidden_states)
# Apply dropout if in training
if not deterministic and self.dropout > 0.0:
hidden_states = nnx.Dropout(
rngs=self._rngs,
rate=self.dropout,
)(hidden_states, deterministic=deterministic)
return hidden_states
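# Example (sketch): embedding a batch of token ids; sizes are illustrative.
#
#     emb = ModernBertEmbeddings(rngs=nnx.Rngs(0), vocab_size=50368, hidden_size=768)
#     tokens = jnp.zeros((2, 16), dtype=jnp.int32)
#     emb(tokens).shape  # (2, 16, 768)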
class ModernBertMLP(nnx.Module):
"""MLP with gated linear units.
Replaces the traditional intermediate + output layers with a single gated MLP.
"""
def __init__(
self,
rngs: nnx.Rngs,
hidden_size: int,
intermediate_size: int,
mlp_bias: bool = True,
mlp_dropout: float = 0.0,
):
"""Initialize MLP module.
Args:
rngs: PRNG key collection
hidden_size: Size of input and output
intermediate_size: Size of intermediate layer
mlp_bias: Whether to use bias in linear layers
mlp_dropout: Dropout probability
"""
super().__init__()
# Input projection (creates both input and gate values)
self.Wi = nnx.Linear(
rngs=rngs,
in_features=hidden_size,
out_features=intermediate_size * 2,
use_bias=mlp_bias,
kernel_init=nnx.initializers.normal(stddev=0.02),
)
# Output projection
self.Wo = nnx.Linear(
rngs=rngs,
in_features=intermediate_size,
out_features=hidden_size,
use_bias=mlp_bias,
kernel_init=nnx.initializers.normal(stddev=0.02),
)
self.dropout = mlp_dropout
self.deterministic = True # Will be overridden in __call__
self._rngs = rngs # Store rngs for dropout
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
"""Apply MLP module.
Args:
hidden_states: Input tensor of shape [batch_size, seq_len, hidden_size]
deterministic: Whether to apply dropout
Returns:
Output tensor of shape [batch_size, seq_len, hidden_size]
"""
# Project to intermediate size
hidden_states = self.Wi(hidden_states)
# Apply GELU activation
hidden_states = jax.nn.gelu(hidden_states)
# Split into input and gate by taking every other feature
hidden_states = self.Wo(hidden_states[..., ::2] * hidden_states[..., 1::2])
# Apply dropout if in training
if not deterministic and self.dropout > 0.0:
hidden_states = nnx.Dropout(
rngs=self._rngs,
rate=self.dropout,
)(hidden_states, deterministic=deterministic)
return hidden_states
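# Example (sketch): the gated MLP on a dummy input; sizes are illustrative.
#
#     mlp = ModernBertMLP(rngs=nnx.Rngs(0), hidden_size=768, intermediate_size=1152)
#     mlp(jnp.zeros((2, 16, 768))).shape  # (2, 16, 768)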