Welcome to the JAXgarden documentation!
JAXgarden provides high-performance, hackable neural network model implementations in JAX, built on optimized kernels and layers such as FlashAttention.
Features
MultiHeadAttention: A Flax NNX-compatible implementation with support for different attention backends.
Installation
pip install git+https://github.com/ml-gde/jax-layers.git
For development installation:
# First, fork the repository to your account.
# Then, clone it to your machine.
git clone https://github.com/yourusername/jax-layers.git
cd jax-layers
pip install -e ".[dev]"