"""LLama model implementation in JAX using Flax NNX.
This module implements the LLama architecture as described in Meta's LLama series of models.
The implementation uses a modern decoder-only transformer architecture with rotary position
embeddings (RoPE), group query attention (GQA), and a SwiGLU feed-forward network.
See: https://arxiv.org/abs/2302.13971
"""
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any
import jax
import jax.numpy as jnp
from flax import nnx
from jaxgarden.models.base import BaseConfig, BaseModel
from jaxgarden.models.generation_utils import GenerationMixin
@dataclass
class LlamaConfig(BaseConfig):
"""
Configuration for LLama model.
This configuration class extends BaseConfig and contains all the parameters
required to initialize a LLama model. It includes settings for model architecture,
attention mechanisms, and other hyperparameters.
Attributes:
dim: Size of hidden states
n_layers: Number of transformer layers
n_heads: Number of attention heads
n_kv_heads: Number of key/value heads (for group query attention)
head_dim: Dimension of each attention head
intermediate_size: Size of MLP intermediate layer
vocab_size: Size of vocabulary
        multiple_of: Ensure dimensions are multiples of this value (not used directly in this module)
norm_eps: Epsilon for layer normalization
rope_theta: Base for rotary position embeddings
"""
dim: int = 2048
n_layers: int = 16
n_heads: int = 32
n_kv_heads: int = 8
head_dim: int = 64
intermediate_size: int = 14336
vocab_size: int = 128256
multiple_of: int = 256
norm_eps: float = 1e-05
rope_theta: float = 500000.0
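

# Usage sketch (not part of the original module): a scaled-down configuration of the kind one
# might use for quick experiments or unit tests. All field names come from LlamaConfig above;
# whether BaseConfig adds further required fields is an assumption.
def _tiny_llama_config() -> LlamaConfig:
    """Illustrative only: a small LlamaConfig that fits easily on CPU."""
    return LlamaConfig(
        dim=64,
        n_layers=2,
        n_heads=4,
        n_kv_heads=2,
        head_dim=16,
        intermediate_size=128,
        vocab_size=1000,
    )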
class LlamaRMSNorm(nnx.Module):
"""Root Mean Square Layer Normalization.
This implementation follows the RMSNorm paper: https://arxiv.org/abs/1910.07467
Instead of using mean and variance like traditional LayerNorm, RMSNorm only uses
the root mean square of the inputs for normalization.
Attributes:
norm_weights: The learned scale parameters
norm_eps: Small constant for numerical stability
"""
def __init__(self, dim: int, *, norm_eps: float = 1e-05, rngs: nnx.Rngs):
"""Initialize RMSNorm module.
Args:
dim: Dimension of the input tensor
norm_eps: Small constant for numerical stability
rngs: PRNG key collection
"""
super().__init__()
        # Learned scale, initialized to ones (the standard RMSNorm initialization).
        self.norm_weights = nnx.Param(jnp.ones((dim,), dtype=jnp.bfloat16))
self.norm_eps = norm_eps
@nnx.jit()
def __call__(self, hidden_states: jnp.ndarray) -> jnp.ndarray:
"""Apply RMS normalization to input tensor.
Args:
hidden_states: Input tensor of shape [..., dim]
Returns:
Normalized tensor with same shape as input
"""
input_dtype = hidden_states.dtype
hidden_states = hidden_states.astype(jnp.float32)
squared_mean = jnp.mean(jnp.square(hidden_states), axis=-1, keepdims=True)
hidden_states = hidden_states * jnp.reciprocal(jnp.sqrt(squared_mean + self.norm_eps))
return self.norm_weights * hidden_states.astype(input_dtype)
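

# Usage sketch (not part of the original module): a small numeric check of the RMSNorm rule
# y = w * x / sqrt(mean(x**2) + eps) implemented by LlamaRMSNorm above.
def _rmsnorm_example() -> jnp.ndarray:
    """Illustrative only: normalize a toy activation vector with LlamaRMSNorm."""
    norm = LlamaRMSNorm(dim=4, norm_eps=1e-5, rngs=nnx.Rngs(0))
    x = jnp.array([[1.0, 2.0, 3.0, 4.0]], dtype=jnp.bfloat16)
    # The output keeps the input shape; each vector is divided by its root mean square
    # and scaled element-wise by the learned weights.
    return norm(x)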
class LlamaRotaryEmbedding(nnx.Module):
"""Rotary Position Embedding (RoPE) implementation for LLama.
Based on: https://arxiv.org/abs/2104.09864
Attributes:
dim: Dimension of the embeddings
base: Base for the sinusoidal functions
"""
def __init__(self, dim: int, *, base: float = 10000.0, rngs: nnx.Rngs):
"""Initialize RoPE module.
Args:
dim: Dimension of the embeddings (must be even)
base: Base for the sinusoidal functions
rngs: PRNG key collection
"""
super().__init__()
self.dim = dim
self.base = base
@nnx.jit()
def __call__(self, position_ids: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Generate rotary embeddings from position ids.
Args:
            position_ids: 1-D array of position indices of shape [seq_len]
                (LlamaAttention passes a single row of its [batch_size, seq_len] positions)
Returns:
Tuple of (cos, sin) tensors for rotary embeddings
"""
inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
inv_freq_expanded = jnp.expand_dims(inv_freq, axis=(0, 1))
position_ids_expanded = jnp.expand_dims(position_ids, axis=(0, 2)).astype(jnp.float32)
freqs = jnp.einsum("bij,bjk->bijk", position_ids_expanded, inv_freq_expanded)
emb = jnp.concatenate([freqs, freqs], axis=-1)
cos = jnp.cos(emb).squeeze(2).astype(jnp.bfloat16)
sin = jnp.sin(emb).squeeze(2).astype(jnp.bfloat16)
return cos, sin
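

# Usage sketch (not part of the original module): the rotary embedding produces one
# (cos, sin) pair per position, each of width dim, which LlamaAttention later applies
# to its query and key tensors.
def _rope_example() -> tuple[jnp.ndarray, jnp.ndarray]:
    """Illustrative only: cos/sin tables for 8 positions with head_dim=16."""
    rope = LlamaRotaryEmbedding(dim=16, base=10000.0, rngs=nnx.Rngs(0))
    position_ids = jnp.arange(8)  # a 1-D array of positions, as passed by LlamaAttention
    cos, sin = rope(position_ids)
    # Both cos and sin have shape (1, 8, 16): (batch, seq_len, dim).
    return cos, sin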
class LlamaAttention(nnx.Module):
"""Multi-headed attention with support for Group Query Attention (GQA).
This implements the LLama attention mechanism with rotary position embeddings (RoPE)
and support for fewer key-value heads than query heads (GQA).
Attributes:
q_proj: Linear projection for queries
k_proj: Linear projection for keys
v_proj: Linear projection for values
o_proj: Linear projection for output
rotary_emb: Rotary position embeddings
head_dim: Dimension of each attention head
n_heads: Number of attention heads
n_kv_heads: Number of key/value heads
"""
def __init__(
self,
layer_idx: int,
dim: int,
n_heads: int,
n_kv_heads: int,
head_dim: int,
rope_theta: float,
*,
rngs: nnx.Rngs,
):
"""Initialize attention module.
Args:
layer_idx: Index of the layer
dim: Size of hidden states
n_heads: Number of attention heads
n_kv_heads: Number of key/value heads
head_dim: Dimension of each attention head
rope_theta: Base for rotary position embeddings
rngs: PRNG key collection
"""
self.q_proj = nnx.Linear(
dim, n_heads * head_dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16
)
self.k_proj = nnx.Linear(
dim, n_kv_heads * head_dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16
)
self.v_proj = nnx.Linear(
dim, n_kv_heads * head_dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16
)
self.o_proj = nnx.Linear(
n_heads * head_dim, dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16
)
self.rotary_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, rngs=rngs)
self.head_dim = head_dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
def apply_rotary_pos_emb(
self,
q: jnp.ndarray,
k: jnp.ndarray,
cos: jnp.ndarray,
sin: jnp.ndarray,
unsqueeze_dim: int = 1,
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Apply rotary position embeddings to query and key tensors.
Args:
q: Query tensor
k: Key tensor
cos: Cosine component of rotary embeddings
sin: Sine component of rotary embeddings
unsqueeze_dim: Dimension to unsqueeze cosine and sine components
Returns:
Tuple of (rotated_q, rotated_k)
"""
cos = jnp.expand_dims(cos, axis=unsqueeze_dim)
sin = jnp.expand_dims(sin, axis=unsqueeze_dim)
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed
def rotate_half(self, x: jnp.ndarray) -> jnp.ndarray:
"""Rotate half the hidden dims of the input.
Args:
x: Input tensor
Returns:
Tensor with half the dimensions rotated
"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return jnp.concatenate([-x2, x1], axis=-1)
def repeat_kv(self, hidden_states: jnp.ndarray, n_repeat: int) -> jnp.ndarray:
"""Repeat key/value heads to match the number of query heads.
When using GQA, we need to repeat each key/value head to match
the number of query heads.
Args:
hidden_states: Key or value tensor of shape [batch, n_kv_heads, seq_len, head_dim]
n_repeat: Number of times to repeat each key/value head
Returns:
Tensor with repeated key/value heads
"""
batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
if n_repeat == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].repeat(n_repeat, axis=2)
return hidden_states.reshape(batch, n_kv_heads * n_repeat, seq_len, head_dim)
@nnx.jit()
def __call__(
self, x: jnp.ndarray, position_ids: jnp.ndarray, attention_mask: jnp.ndarray
) -> jnp.ndarray:
"""Apply self-attention using queries, keys, and values derived from input x.
Args:
x: Input tensor of shape [batch_size, seq_len, dim]
position_ids: Position indices of shape [batch_size, seq_len]
attention_mask: Additive attention mask of shape (batch_size, 1, q_len, kv_len)
or a shape that can be broadcasted to this shape.
Returns:
Output tensor of shape [batch_size, seq_len, dim]
"""
batch_size, seq_len, _ = x.shape
query = (
self.q_proj(x)
.reshape(batch_size, seq_len, self.n_heads, self.head_dim)
.transpose((0, 2, 1, 3))
)
key = (
self.k_proj(x)
.reshape(batch_size, seq_len, self.n_kv_heads, self.head_dim)
.transpose((0, 2, 1, 3))
)
value = (
self.v_proj(x)
.reshape(batch_size, seq_len, self.n_kv_heads, self.head_dim)
.transpose((0, 2, 1, 3))
)
        # RoPE tables depend only on token positions; with the batch_size=1 restriction
        # we pass the single row of position_ids.
cos, sin = self.rotary_emb(position_ids[0])
query, key = self.apply_rotary_pos_emb(query, key, cos, sin)
key = self.repeat_kv(key, self.n_heads // self.n_kv_heads)
value = self.repeat_kv(value, self.n_heads // self.n_kv_heads)
attn_weights = jnp.matmul(query, jnp.transpose(key, (0, 1, 3, 2)))
attn_weights = (attn_weights.astype(jnp.float32) / jnp.sqrt(self.head_dim)).astype(
jnp.bfloat16
)
attn_weights = attn_weights + attention_mask
attn_weights = jax.nn.softmax(attn_weights.astype(jnp.float32), axis=-1).astype(
jnp.bfloat16
)
attn_output = (
jnp.matmul(attn_weights, value).transpose((0, 2, 1, 3)).reshape(batch_size, seq_len, -1)
)
output = self.o_proj(attn_output)
return output
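

# Usage sketch (not part of the original module): with n_heads=4 and n_kv_heads=2, each
# key/value head is repeated twice by repeat_kv so that GQA reduces to an ordinary batched
# attention matmul. The additive causal mask below is built by the caller here; in the full
# model it is constructed in LlamaForCausalLM.__call__.
def _attention_example() -> jnp.ndarray:
    """Illustrative only: run LlamaAttention on a toy sequence of length 5."""
    attn = LlamaAttention(
        layer_idx=0,
        dim=32,
        n_heads=4,
        n_kv_heads=2,
        head_dim=8,
        rope_theta=10000.0,
        rngs=nnx.Rngs(0),
    )
    x = jnp.ones((1, 5, 32), dtype=jnp.bfloat16)
    position_ids = jnp.arange(5)[None, :]
    # Additive mask: 0.0 where attention is allowed, -inf above the diagonal.
    causal = jnp.tril(jnp.ones((5, 5), dtype=bool))[None, None, :, :]
    attention_mask = jnp.where(causal, 0.0, -jnp.inf)
    return attn(x, position_ids, attention_mask)  # shape (1, 5, 32)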
class LlamaMLP(nnx.Module):
"""LLama's MLP implementation with SwiGLU activation.
This implements the SwiGLU MLP used in LLama models:
down_proj(silu(gate_proj(x)) * up_proj(x))
Attributes:
gate_proj: Linear projection for gate
up_proj: Linear projection for up-projection
down_proj: Linear projection for down-projection
"""
def __init__(self, layer_idx: int, dim: int, intermediate_size: int, *, rngs: nnx.Rngs):
"""Initialize MLP module.
Args:
layer_idx: Index of the layer
dim: Size of hidden states
intermediate_size: Size of intermediate layer
rngs: PRNG key collection
"""
self.gate_proj = nnx.Linear(
dim, intermediate_size, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16
)
self.up_proj = nnx.Linear(
dim, intermediate_size, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16
)
self.down_proj = nnx.Linear(
intermediate_size, dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16
)
@nnx.jit()
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""Apply SwiGLU MLP to input tensor.
Args:
x: Input tensor of shape [batch_size, seq_len, dim]
Returns:
Output tensor of shape [batch_size, seq_len, dim]
"""
return self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x))
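

# LlamaTransformerBlock is instantiated by LlamaForCausalLM below, but its definition is not
# shown in this section. The class here is a minimal sketch reconstructed from the constructor
# arguments used in LlamaForCausalLM and the attribute names expected by convert_weights_from_hf
# (attention, mlp, input_layernorm, post_attention_layernorm); it follows the standard pre-norm
# LLama block and may differ in detail from the original.
class LlamaTransformerBlock(nnx.Module):
    """A single LLama decoder layer: pre-norm attention and pre-norm SwiGLU MLP."""

    def __init__(
        self,
        layer_idx: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        rope_theta: float,
        intermediate_size: int,
        norm_eps: float,
        *,
        rngs: nnx.Rngs,
    ):
        self.input_layernorm = LlamaRMSNorm(dim=dim, norm_eps=norm_eps, rngs=rngs)
        self.attention = LlamaAttention(
            layer_idx=layer_idx,
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            head_dim=head_dim,
            rope_theta=rope_theta,
            rngs=rngs,
        )
        self.post_attention_layernorm = LlamaRMSNorm(dim=dim, norm_eps=norm_eps, rngs=rngs)
        self.mlp = LlamaMLP(
            layer_idx=layer_idx, dim=dim, intermediate_size=intermediate_size, rngs=rngs
        )

    @nnx.jit()
    def __call__(
        self, x: jnp.ndarray, position_ids: jnp.ndarray, attention_mask: jnp.ndarray
    ) -> jnp.ndarray:
        # Residual connection around pre-norm self-attention.
        x = x + self.attention(self.input_layernorm(x), position_ids, attention_mask)
        # Residual connection around pre-norm SwiGLU MLP.
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x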
class LlamaForCausalLM(BaseModel, GenerationMixin):
"""LLama model for causal language modeling.
This implements the full LLama model for generating text.
It consists of token embeddings, transformer layers, and language modeling head.
Attributes:
token_embed: Token embedding layer
layers: List of transformer blocks
lm_head: Linear layer for language modeling
norm: Final layer normalization
"""
def __init__(
self,
config: LlamaConfig,
*,
dtype: jnp.dtype | None = None,
param_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.Precision | str | None = None,
rngs: nnx.Rngs,
):
"""Initialize LlamaForCausalLM.
Args:
config: Model configuration
dtype: Data type for computation
param_dtype: Data type for parameters
precision: Precision for matrix multiplication
rngs: PRNG key collection
"""
super().__init__(
config, dtype=dtype, param_dtype=param_dtype, precision=precision, rngs=rngs
)
self.token_embed = nnx.Embed(
num_embeddings=config.vocab_size,
features=config.dim,
param_dtype=param_dtype,
rngs=rngs,
)
self.layers = [
LlamaTransformerBlock(
layer_idx=idx,
dim=config.dim,
n_heads=config.n_heads,
n_kv_heads=config.n_kv_heads,
head_dim=config.head_dim,
rope_theta=config.rope_theta,
intermediate_size=config.intermediate_size,
norm_eps=config.norm_eps,
rngs=rngs,
)
for idx in range(config.n_layers)
]
self.lm_head = nnx.Linear(
config.dim, config.vocab_size, use_bias=False, rngs=rngs, param_dtype=param_dtype
)
        self.norm = LlamaRMSNorm(dim=config.dim, norm_eps=config.norm_eps, rngs=rngs)
@nnx.jit()
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: jnp.ndarray | None = None,
deterministic: bool = True,
) -> jnp.ndarray:
"""Forward pass of the LLama model.
Args:
input_ids: Input token ids of shape [batch_size, seq_len]
            attention_mask: Optional boolean (or 0/1) attention mask of shape
                [batch_size, seq_len]; if None, all positions are attended to.
                A causal mask is combined with it internally.
            deterministic: Unused; kept for interface compatibility.
        Returns:
            Logits for next token prediction of shape [batch_size, seq_len, vocab_size]
        """
        assert input_ids.shape[0] == 1, "Only batch size 1 is supported"
        position_ids = jnp.arange(input_ids.shape[-1])[None, :].astype(jnp.int32)
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        # Build an additive mask: 0.0 where attention is allowed, -inf elsewhere.
        # The padding mask is combined with a lower-triangular causal mask.
        seq_len = input_ids.shape[-1]
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        allowed = causal_mask[None, None, :, :] & attention_mask[:, None, None, :].astype(bool)
        attention_mask = jnp.where(allowed, 0.0, -jnp.inf)
x = self.token_embed(input_ids)
for layer in self.layers:
x = layer(x, position_ids, attention_mask)
x = self.norm(x)
logits = self.lm_head(x)
return logits
def convert_weights_from_hf(
self, state: nnx.State | dict[str, jnp.ndarray], weights: Iterator[tuple[Any, Any]]
    ) -> None:
        """Copy weights from a HuggingFace LLama checkpoint into this model's state.
        Args:
            state: This model's nnx.State (or a compatible mapping), updated in place.
            weights: Iterator of (name, tensor) pairs using HuggingFace parameter names,
                e.g. model.layers.0.self_attn.q_proj.weight.
        """
        for wholekey, tensor in weights:
keys = wholekey.split(".")
if keys[1] == "layers":
if keys[3] == "self_attn":
keys[3] = "attention"
if (keys[1] == "layers" and keys[3] == "attention") or (
keys[1] == "layers" and keys[3] == "mlp"
):
state["layers"][int(keys[2])][keys[3]][keys[4]]["kernel"].value = tensor.T # type: ignore
elif (keys[1] == "layers" and keys[3] == "input_layernorm") or (
keys[1] == "layers" and keys[3] == "post_attention_layernorm"
):
state["layers"][int(keys[2])][keys[3]]["norm_weights"].value = tensor # type: ignore
            elif keys[1] == "embed_tokens":
                # The input embedding matrix is also used (tied) as the LM head projection.
                state["token_embed"].embedding.value = tensor  # type: ignore
                state["lm_head"].kernel.value = tensor.T  # type: ignore
elif keys[1] == "norm":
state["norm"].norm_weights.value = tensor # type: ignore
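

# Usage sketches (not part of the original module). The first runs a forward pass with the tiny
# configuration defined near the top of this file; the second shows one way to stream weights
# from a local HuggingFace LLama checkpoint into the model. The safetensors file name and the
# use of nnx.split/nnx.update are assumptions about the surrounding workflow, not part of the
# jaxgarden API shown above.
def _forward_example() -> jnp.ndarray:
    """Illustrative only: logits for a dummy prompt with a tiny LLama model."""
    model = LlamaForCausalLM(_tiny_llama_config(), rngs=nnx.Rngs(0))
    input_ids = jnp.array([[1, 2, 3, 4]], dtype=jnp.int32)  # batch size must be 1
    return model(input_ids)  # shape (1, 4, vocab_size)


def _load_hf_weights_example(checkpoint_path: str = "model.safetensors") -> LlamaForCausalLM:
    """Illustrative only: populate a LlamaForCausalLM from HuggingFace-named tensors."""
    from safetensors import safe_open  # assumed to be installed separately

    def iter_weights() -> Iterator[tuple[str, jnp.ndarray]]:
        with safe_open(checkpoint_path, framework="flax") as f:
            for name in f.keys():  # e.g. "model.layers.0.self_attn.q_proj.weight"
                yield name, f.get_tensor(name)

    model = LlamaForCausalLM(LlamaConfig(), rngs=nnx.Rngs(0))
    _, state = nnx.split(model)  # split the module into graph definition and parameter state
    model.convert_weights_from_hf(state, iter_weights())
    nnx.update(model, state)  # write the converted parameters back into the live module
    return model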