Build a Transformer with JAX

The general-purpose transformer architecture has truly “transformed” the AI landscape. Learn about its origins and structure, and watch one being built from scratch! We’ll walk through building a small transformer in JAX, using Flax NNX to define the model architecture and Optax to create the loss function and optimizer, then training it on accelerated hardware with the help of XLA, with Orbax handling checkpointing. Get a taste of development in JAX, and prepare to take your own next steps in building and training AI models.

Resources:
Colab notebook – Transformers workshop → https://goo.gle/44w7bI9
Kaggle notebook – Transformers workshop → https://goo.gle/4dc12mS

What is JAX → https://goo.gle/4j6UQ0G
“Attention is All You Need” research paper → https://goo.gle/3Z1a7ZZ
Open Web Text dataset → https://goo.gle/3GKsUSM
The Illustrated Transformer → https://goo.gle/452N0BP
All the Transformer Math You Need to Know → https://goo.gle/4m0iFKe

JAX docs → https://goo.gle/452mUyL
JAX AI Stack → https://goo.gle/3GMHPvH
Colab Notebooks → https://goo.gle/4m7JZpR
Tips for using TPUs on Kaggle → https://goo.gle/4maY8CU

Speaker: Yufeng Guo

Check out all the keynote sessions from Google I/O 2025 → https://goo.gle/io25-keynote-sessions
Check out the AI session track from Google I/O 2025 → https://goo.gle/io25-ai-yt
Check out all of the sessions from Google I/O 2025 → https://goo.gle/io25-sessions-yt

Subscribe to Google for Developers → https://goo.gle/developers

Event: Google I/O 2025

Products Mentioned: AI/Machine Learning