JAXgarden
Contents:
API Documentation
JAXgarden
Index
Index
_
|
A
|
B
|
C
|
D
|
E
|
F
|
G
|
H
|
I
|
J
|
K
|
L
|
M
|
N
|
O
|
P
|
Q
|
R
|
S
|
T
|
U
|
V
_
__call__() (jaxgarden.models.LlamaAttention method)
(jaxgarden.models.LlamaForCausalLM method)
(jaxgarden.models.LlamaMLP method)
(jaxgarden.models.LlamaRMSNorm method)
(jaxgarden.models.LlamaRotaryEmbedding method)
(jaxgarden.models.LlamaTransformerBlock method)
(jaxgarden.models.ModernBertAttention method)
(jaxgarden.models.ModernBertEmbeddings method)
(jaxgarden.models.ModernBERTEncoder method)
(jaxgarden.models.ModernBERTForMaskedLM method)
(jaxgarden.models.ModernBertLayer method)
(jaxgarden.models.ModernBertMLP method)
__init__() (jaxgarden.attention.MultiHeadAttention method)
(jaxgarden.models.BaseModel method)
(jaxgarden.models.LlamaAttention method)
(jaxgarden.models.LlamaForCausalLM method)
(jaxgarden.models.LlamaMLP method)
(jaxgarden.models.LlamaRMSNorm method)
(jaxgarden.models.LlamaRotaryEmbedding method)
(jaxgarden.models.LlamaTransformerBlock method)
(jaxgarden.models.ModernBertAttention method)
(jaxgarden.models.ModernBertEmbeddings method)
(jaxgarden.models.ModernBERTEncoder method)
(jaxgarden.models.ModernBERTForMaskedLM method)
(jaxgarden.models.ModernBertLayer method)
(jaxgarden.models.ModernBertMLP method)
A
apply_rotary_pos_emb() (jaxgarden.models.LlamaAttention method)
attention (jaxgarden.models.LlamaTransformerBlock attribute)
B
base (jaxgarden.models.LlamaRotaryEmbedding attribute)
BaseConfig (class in jaxgarden.models)
BaseModel (class in jaxgarden.models)
C
convert_weights_from_hf() (jaxgarden.models.BaseModel method)
(jaxgarden.models.LlamaForCausalLM method)
D
dim (jaxgarden.models.LlamaConfig attribute), [1]
(jaxgarden.models.LlamaRotaryEmbedding attribute)
dot_product_attention() (in module jaxgarden.functional)
down_proj (jaxgarden.models.LlamaMLP attribute)
download_from_hf() (jaxgarden.models.BaseModel static method)
E
extra (jaxgarden.models.BaseConfig attribute)
F
from_hf() (jaxgarden.models.BaseModel method)
G
gate_proj (jaxgarden.models.LlamaMLP attribute)
generate() (jaxgarden.models.GenerationMixin method)
GenerationMixin (class in jaxgarden.models)
H
head_dim (jaxgarden.models.LlamaAttention attribute)
(jaxgarden.models.LlamaConfig attribute), [1]
I
input_layernorm (jaxgarden.models.LlamaTransformerBlock attribute)
intermediate_size (jaxgarden.models.LlamaConfig attribute), [1]
iter_safetensors() (jaxgarden.models.BaseModel static method)
J
jaxgarden.attention
module
jaxgarden.functional
module
jaxgarden.models
module
K
k_proj (jaxgarden.models.LlamaAttention attribute)
L
layers (jaxgarden.models.LlamaForCausalLM attribute)
LlamaAttention (class in jaxgarden.models)
LlamaConfig (class in jaxgarden.models)
LlamaForCausalLM (class in jaxgarden.models)
LlamaMLP (class in jaxgarden.models)
LlamaRMSNorm (class in jaxgarden.models)
LlamaRotaryEmbedding (class in jaxgarden.models)
LlamaTransformerBlock (class in jaxgarden.models)
lm_head (jaxgarden.models.LlamaForCausalLM attribute)
load() (jaxgarden.models.BaseModel method)
log_level (jaxgarden.models.BaseConfig attribute)
M
mlp (jaxgarden.models.LlamaTransformerBlock attribute)
ModernBertAttention (class in jaxgarden.models)
ModernBertEmbeddings (class in jaxgarden.models)
ModernBERTEncoder (class in jaxgarden.models)
ModernBERTForMaskedLM (class in jaxgarden.models)
ModernBertLayer (class in jaxgarden.models)
ModernBertMLP (class in jaxgarden.models)
module
jaxgarden.attention
jaxgarden.functional
jaxgarden.models
MultiHeadAttention (class in jaxgarden.attention)
multiple_of (jaxgarden.models.LlamaConfig attribute), [1]
N
n_heads (jaxgarden.models.LlamaAttention attribute)
(jaxgarden.models.LlamaConfig attribute), [1]
n_kv_heads (jaxgarden.models.LlamaAttention attribute)
(jaxgarden.models.LlamaConfig attribute), [1]
n_layers (jaxgarden.models.LlamaConfig attribute), [1]
norm (jaxgarden.models.LlamaForCausalLM attribute)
norm_eps (jaxgarden.models.LlamaConfig attribute), [1]
(jaxgarden.models.LlamaRMSNorm attribute)
norm_weights (jaxgarden.models.LlamaRMSNorm attribute)
O
o_proj (jaxgarden.models.LlamaAttention attribute)
P
post_attention_layernorm (jaxgarden.models.LlamaTransformerBlock attribute)
Q
q_proj (jaxgarden.models.LlamaAttention attribute)
R
repeat_kv() (jaxgarden.models.LlamaAttention method)
rope_theta (jaxgarden.models.LlamaConfig attribute), [1]
rotary_emb (jaxgarden.models.LlamaAttention attribute)
rotate_half() (jaxgarden.models.LlamaAttention method)
S
save() (jaxgarden.models.BaseModel method)
seed (jaxgarden.models.BaseConfig attribute)
state (jaxgarden.models.BaseModel property)
state_dict (jaxgarden.models.BaseModel property)
T
to_dict() (jaxgarden.models.BaseConfig method)
token_embed (jaxgarden.models.LlamaForCausalLM attribute)
U
up_proj (jaxgarden.models.LlamaMLP attribute)
update() (jaxgarden.models.BaseConfig method)
V
v_proj (jaxgarden.models.LlamaAttention attribute)
vocab_size (jaxgarden.models.LlamaConfig attribute), [1]