JAXgarden
Contents:
API Documentation
JAXgarden
Index
Index
_ | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V
_
__call__() (jaxgarden.models.Gemma2RotaryEmbedding method)
(jaxgarden.models.Gemma3RotaryEmbedding method)
(jaxgarden.models.LlamaAttention method)
(jaxgarden.models.LlamaForCausalLM method)
(jaxgarden.models.LlamaMLP method)
(jaxgarden.models.LlamaRMSNorm method)
(jaxgarden.models.LlamaRotaryEmbedding method)
(jaxgarden.models.LlamaTransformerBlock method)
(jaxgarden.models.ModernBertAttention method)
(jaxgarden.models.ModernBertEmbeddings method)
(jaxgarden.models.ModernBERTEncoder method)
(jaxgarden.models.ModernBERTForMaskedLM method)
(jaxgarden.models.ModernBertLayer method)
(jaxgarden.models.ModernBertMLP method)
__init__() (jaxgarden.attention.MultiHeadAttention method)
(jaxgarden.models.BaseModel method)
(jaxgarden.models.LlamaAttention method)
(jaxgarden.models.LlamaForCausalLM method)
(jaxgarden.models.LlamaMLP method)
(jaxgarden.models.LlamaRMSNorm method)
(jaxgarden.models.LlamaRotaryEmbedding method)
(jaxgarden.models.LlamaTransformerBlock method)
(jaxgarden.models.ModernBertAttention method)
(jaxgarden.models.ModernBertEmbeddings method)
(jaxgarden.models.ModernBERTEncoder method)
(jaxgarden.models.ModernBERTForMaskedLM method)
(jaxgarden.models.ModernBertLayer method)
(jaxgarden.models.ModernBertMLP method)
A
apply_rotary_pos_emb() (jaxgarden.models.LlamaAttention method)
apply_soft_cap() (jaxgarden.models.Gemma2Attention static method)
(jaxgarden.models.Gemma3Attention static method)
attention (jaxgarden.models.LlamaTransformerBlock attribute)
attention_bias (jaxgarden.models.Gemma3Config attribute), [1]
attention_dropout (jaxgarden.models.Gemma3Config attribute), [1]
attn_logit_soft_cap (jaxgarden.models.Gemma3Config attribute), [1]
attn_logits_soft_cap (jaxgarden.models.Gemma2Config attribute)
B
base (jaxgarden.models.LlamaRotaryEmbedding attribute)
BaseConfig (class in jaxgarden.models)
BaseModel (class in jaxgarden.models)
bos_token_id (jaxgarden.models.Gemma3Config attribute), [1]
C
config (jaxgarden.models.Gemma2ForCausalLM attribute)
(jaxgarden.models.Gemma3ForCausalLM attribute)
context_length (jaxgarden.models.Gemma2Config attribute)
convert_weights_from_hf() (jaxgarden.models.BaseModel method)
(jaxgarden.models.LlamaForCausalLM method)
D
dim (jaxgarden.models.LlamaConfig attribute), [1]
(jaxgarden.models.LlamaRotaryEmbedding attribute)
dot_product_attention() (in module jaxgarden.functional)
down_proj (jaxgarden.models.LlamaMLP attribute)
download_from_hf() (jaxgarden.models.BaseModel static method)
dtype (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute)
E
eos_token_id (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute), [1]
extra (jaxgarden.models.BaseConfig attribute)
F
final_logit_soft_cap (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute), [1]
from_hf() (jaxgarden.models.BaseModel method)
G
gate_proj (jaxgarden.models.LlamaMLP attribute)
geglu() (jaxgarden.models.Gemma2MLP static method)
Gemma2Attention (class in jaxgarden.models)
Gemma2Config (class in jaxgarden.models)
Gemma2ForCausalLM (class in jaxgarden.models)
Gemma2MLP (class in jaxgarden.models)
Gemma2RMSNorm (class in jaxgarden.models)
Gemma2RotaryEmbedding (class in jaxgarden.models)
Gemma3Attention (class in jaxgarden.models)
Gemma3Config (class in jaxgarden.models)
Gemma3ForCausalLM (class in jaxgarden.models)
Gemma3MLP (class in jaxgarden.models)
Gemma3RMSNorm (class in jaxgarden.models)
Gemma3RotaryEmbedding (class in jaxgarden.models)
generate() (jaxgarden.models.GenerationMixin method)
GenerationMixin (class in jaxgarden.models)
H
head_dim (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute), [1]
(jaxgarden.models.LlamaAttention attribute)
(jaxgarden.models.LlamaConfig attribute), [1]
hidden_activation (jaxgarden.models.Gemma3Config attribute), [1]
hidden_size (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute), [1]
I
initializer_range (jaxgarden.models.Gemma3Config attribute), [1]
input_layernorm (jaxgarden.models.LlamaTransformerBlock attribute)
intermediate_size (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute), [1]
(jaxgarden.models.LlamaConfig attribute), [1]
iter_safetensors() (jaxgarden.models.BaseModel static method)
J
jaxgarden.attention
module
jaxgarden.functional
module
jaxgarden.models
module
K
k_proj (jaxgarden.models.LlamaAttention attribute)
L
layers (jaxgarden.models.LlamaForCausalLM attribute)
LlamaAttention (class in jaxgarden.models)
LlamaConfig (class in jaxgarden.models)
LlamaForCausalLM (class in jaxgarden.models)
LlamaMLP (class in jaxgarden.models)
LlamaRMSNorm (class in jaxgarden.models)
LlamaRotaryEmbedding (class in jaxgarden.models)
LlamaTransformerBlock (class in jaxgarden.models)
lm_head (jaxgarden.models.LlamaForCausalLM attribute)
load() (jaxgarden.models.BaseModel method)
log_level (jaxgarden.models.BaseConfig attribute)
M
max_position_embeddings (jaxgarden.models.Gemma3Config attribute), [1]
mlp (jaxgarden.models.LlamaTransformerBlock attribute)
ModernBertAttention (class in jaxgarden.models)
ModernBertEmbeddings (class in jaxgarden.models)
ModernBERTEncoder (class in jaxgarden.models)
ModernBERTForMaskedLM (class in jaxgarden.models)
ModernBertLayer (class in jaxgarden.models)
ModernBertMLP (class in jaxgarden.models)
module
jaxgarden.attention
jaxgarden.functional
jaxgarden.models
MultiHeadAttention (class in jaxgarden.attention)
multiple_of (jaxgarden.models.LlamaConfig attribute), [1]
N
n_heads (jaxgarden.models.LlamaAttention attribute)
(jaxgarden.models.LlamaConfig attribute), [1]
n_kv_heads (jaxgarden.models.LlamaAttention attribute)
(jaxgarden.models.LlamaConfig attribute), [1]
n_layers (jaxgarden.models.LlamaConfig attribute), [1]
norm (jaxgarden.models.LlamaForCausalLM attribute)
norm_eps (jaxgarden.models.LlamaConfig attribute), [1]
(jaxgarden.models.LlamaRMSNorm attribute)
norm_weights (jaxgarden.models.LlamaRMSNorm attribute)
num_attention_heads (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute), [1]
num_hidden_layers (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute), [1]
num_key_value_heads (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute), [1]
O
o_proj (jaxgarden.models.LlamaAttention attribute)
P
pad_token_id (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute), [1]
param_dtype (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute)
post_attention_layernorm (jaxgarden.models.LlamaTransformerBlock attribute)
Q
q_proj (jaxgarden.models.LlamaAttention attribute)
query_pre_attn_scalar (jaxgarden.models.Gemma3Config attribute), [1]
R
repeat_kv() (jaxgarden.models.LlamaAttention static method)
rms_norm_eps (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute), [1]
rope_local_base_freq (jaxgarden.models.Gemma3Config attribute), [1]
rope_scaling (jaxgarden.models.Gemma3Config attribute), [1]
rope_theta (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute), [1]
(jaxgarden.models.LlamaConfig attribute), [1]
rotary_emb (jaxgarden.models.LlamaAttention attribute)
rotate_half() (jaxgarden.models.Gemma2Attention static method)
(jaxgarden.models.LlamaAttention static method)
S
save() (jaxgarden.models.BaseModel method)
seed (jaxgarden.models.BaseConfig attribute)
sliding_window (jaxgarden.models.Gemma3Config attribute), [1]
sliding_window_pattern (jaxgarden.models.Gemma3Config attribute), [1]
sliding_window_size (jaxgarden.models.Gemma2Config attribute)
state (jaxgarden.models.BaseModel property)
state_dict (jaxgarden.models.BaseModel property)
T
tie_word_embeddings (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute), [1]
to_dict() (jaxgarden.models.BaseConfig method)
token_embed (jaxgarden.models.LlamaForCausalLM attribute)
U
up_proj (jaxgarden.models.LlamaMLP attribute)
update() (jaxgarden.models.BaseConfig method)
use_cache (jaxgarden.models.Gemma3Config attribute), [1]
V
v_proj (jaxgarden.models.LlamaAttention attribute)
vocab_size (jaxgarden.models.Gemma2Config attribute)
(jaxgarden.models.Gemma3Config attribute), [1]
(jaxgarden.models.LlamaConfig attribute), [1]