Our experience fine-tuning LLaMA 3.1 405B on AMD MI300x
Contents
- 1 Introduction
- 1.1 What is JAX and why we chose it
- 1.2 Why I love JAX so much:
- 1.3 JAX is especially great when working with non-NVIDIA hardware:
- 1.4 Preparing JAX for AMD is very easy!
- 1.5 LLaMA 405B Training: Performance and Scalability
- 1.6 Our training setup
- 1.7 Loading and sharding the model parameters
- 1.8 Visualization of sharding
- 1.9 Partitioning rules
- 1.10 How the parameters are sharded
- 1.11 Applying sharding constraints
- 1.12 Sharding the training batch
- 1.13 Implementation of LoRA training
- 1.14 LoRA A matrices (lora_a)
- 1.15 LoRA B matrices (lora_b)
- 1.16 Conclusion
Introduction
Open-source models keep getting bigger, so the need for reliable infrastructure for large-scale AI training is higher than ever. Recently, our company fine-tuned the LLaMA 3.1 405B model on AMD GPUs, demonstrating their ability to handle large-scale AI tasks effectively. Our experience was very positive, and we are happy to open-source all of our work on GitHub.
AMD GPUs, and especially the MI300X series, are a solid alternative to NVIDIA's AI hardware, providing more performance per dollar invested. Our system consisted of a single node with 8 AMD MI300x GPUs, and we used JAX for fine-tuning. In this article, we tell the whole story of fine-tuning LLaMA 405B, including the details of parameter sharding and our LoRA implementation.
What is JAX and why we chose it
JAX is a powerful machine learning library that combines NumPy-like APIs, automatic differentiation, and the Google XLA compiler. It has great APIs for parallelizing models, ideal for training huge models like LLaMA 3.1 405B.
Why I love JAX so much:
- Pure functions: JAX encourages writing pure functions (a requirement if you want to JIT-compile your code), which makes code easier to compile, debug, and read; a minimal example follows this list.
- Advanced parallelism: JAX's flexible JIT and sharding APIs support advanced data and model parallelism out of the box, which is critical for large-scale training.
- Cleaner codebases: JAX's design philosophy encourages writing code that is portable across hardware platforms (CPU, GPU, TPU) from the start, resulting in improved codebase cleanliness and maintainability.
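As a minimal illustration (not from our repository), this is how a pure function is JIT-compiled in JAX; the function's output depends only on its inputs, so XLA can compile and cache it:

```python
import jax
import jax.numpy as jnp

@jax.jit
def mse(params, x, y):
    # Pure function: no hidden state, output depends only on the inputs.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 4)), jnp.zeros((8, 1))
print(mse(params, x, y))  # compiled once, then cached for these shapes
```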
If you want to learn more about the advantages of JAX over PyTorch, I recommend reading the post PyTorch is dead. Long live JAX.
JAX is especially great when working with non-NVIDIA hardware:
When working with AMD, JAX provides many advantages:

- A hardware-independent approach: JAX uses the XLA (Accelerated Linear Algebra) compiler, which compiles computations into a hardware-independent intermediate representation (the HLO graph). This allows the same JAX code to be optimized and executed efficiently on different hardware backends, including AMD GPUs, without modification.
- Platform-independent optimizations: the XLA compiler performs hardware-independent optimizations that benefit all supported platforms.
- Simplified portability: with JAX, moving from NVIDIA to AMD (or other supported hardware) requires only minimal code changes. This sets it apart from PyTorch, which is much more tightly coupled to NVIDIA's CUDA ecosystem:
  - PyTorch often relies on CUDA-specific implementations (for example, `torch.cuda` calls or `scaled_dot_product_attention`).
  - Although PyTorch supports other backends such as ROCm for AMD GPUs, porting code can be difficult because of NVIDIA-specific execution paths.
  - "De-NVIDIA-fying" PyTorch code can add complexity and hinder portability.
Preparing JAX for AMD is very easy!
Configuring JAX on AMD GPUs is a very simple process:
```bash
# Pull the Docker image:
docker pull rocm/jax:latest

# Start the Docker container:
docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video \
  --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest

# Verify the installation:
python3 -c 'import jax; print(jax.devices())'
```
I was running an AMD node consisting of 8 AMD MI300x GPUs, each with 192 GB of HBM3 memory. They perform very well compared to the new NVIDIA H100 GPUs (see the comparison below; source: TensorWave).
LLaMA 405B Training: Performance and Scalability
Using JAX, I was able to train the LLaMA 405B model on AMD GPUs with impressive results.

We ran LoRA fine-tuning with all model weights and LoRA parameters in bfloat16 precision, with LoRA rank = 8 and LoRA alpha = 16:
- Model size: the LLaMA model weights take up approximately 800 GB of VRAM.
- LoRA weights + optimizer state: approximately 400 GB of VRAM.
- Total VRAM usage: 77% of total VRAM, approximately 1200 GB.
- Constraint: because of the sheer size of the 405B model, there was limited room for batch size and sequence length. I used a batch size of 16 and a sequence length of 64.
- JIT compilation: also because of the limited memory, I was unable to run the JIT-compiled version; it apparently needs slightly more memory than the eager-mode graph.
- Training speed: about 35 tokens per second in JAX eager mode (one training step took 30 s).
- Memory efficiency: stable at around 70%.
- Scaling: with JAX, scaling was roughly linear across all 8 GPUs.
Below are the GPU metrics, memory efficiency, and rocm-smi results for all 8 GPUs during one training step of the fine-tuning run:
| Device | Temperature | Power | Partitions | Fan | Perf | PwrCap | VRAM% | GPU% |
|---|---|---|---|---|---|---|---|---|
| 0 | 58.0°C | 232.0 W | NPS1, SPX, 0 | 0% | auto | 750.0 W | 77% | 27% |
| 1 | 58.0°C | 233.0 W | NPS1, SPX, 0 | 0% | auto | 750.0 W | 77% | 25% |
| 2 | 56.0°C | 236.0 W | NPS1, SPX, 0 | 0% | auto | 750.0 W | 77% | 24% |
| 3 | 52.0°C | 228.0 W | NPS1, SPX, 0 | 0% | auto | 750.0 W | 77% | 23% |
| 4 | 59.0°C | 232.0 W | NPS1, SPX, 0 | 0% | auto | 750.0 W | 77% | 22% |
| 5 | 51.0°C | 230.0 W | NPS1, SPX, 0 | 0% | auto | 750.0 W | 77% | 21% |
| 6 | 61.0°C | 235.0 W | NPS1, SPX, 0 | 0% | auto | 750.0 W | 77% | 18% |
| 7 | 56.0°C | 227.0 W | NPS1, SPX, 0 | 0% | auto | 750.0 W | 77% | 18% |
Full details of GPU utilization, VRAM usage, and rocm-smi data can be found in our GitHub repository.
Our training setup
We ported the LLaMA 3.1 architecture from PyTorch to JAX. Our implementation can be explored in the GitHub repository.
This migration opened up new opportunities for us in terms of performance and scalability.
Loading and sharding the model parameters
Working with a model as huge as LLaMA 405B requires efficient sharding of its parameters across multiple devices. Below we explain how we achieved this with JAX.
Parameter sharding in JAX
To distribute the huge LLaMA 405B model efficiently across 8 AMD GPUs, we used JAX's device mesh feature. A device mesh arranges the available devices into a multidimensional grid, letting us specify how computation and data are partitioned. In our setup, we created a mesh of shape (1, 8, 1), with axes for data parallelism ("dp"), fully sharded data parallelism ("fsdp"), and model parallelism ("mp"). We then applied specific sharding rules to the model parameters, specifying for each model tensor how its dimensions should be split across the mesh axes.
```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

DEVICES = jax.devices()
DEVICE_COUNT = len(DEVICES)
DEVICE_MESH = mesh_utils.create_device_mesh((1, 8, 1))
MESH = Mesh(devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))
```
Visualization of sharding
Array sharding can be visualized using `jax.debug.visualize_array_sharding`. This is extremely useful for verifying that sharding specifications are applied correctly.
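For example, a quick sketch (assuming the `MESH` defined above and an arbitrary test matrix rather than a real model weight):

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as PS

# Shard a test matrix row-wise across the "fsdp" axis and inspect the layout.
x = jnp.zeros((8192, 8192), dtype=jnp.bfloat16)
x = jax.device_put(x, NamedSharding(MESH, PS("fsdp", "mp")))
jax.debug.visualize_array_sharding(x)  # prints a per-device layout diagram
```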
Partitioning rules
We defined the partitioning rules for the different components of the model:
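The full rule table lives in our repository; as a hedged illustration, such rules map parameter-name patterns to `PartitionSpec`s over the ("dp", "fsdp", "mp") mesh axes. The patterns and specs below are representative assumptions, not our verbatim rules:

```python
from jax.sharding import PartitionSpec as PS

# Illustrative only: regex over the parameter path -> how its dims map to mesh axes.
example_partition_rules = (
    ("tok_embeddings/embedding", PS("mp", "fsdp")),     # embedding table
    ("attention/(wq|wk|wv)/kernel", PS("fsdp", "mp")),   # attention projections
    ("attention/wo/kernel", PS("mp", "fsdp")),           # attention output
    ("feed_forward/w1/kernel", PS("fsdp", "mp")),        # MLP up-projection
    ("feed_forward/w2/kernel", PS("mp", "fsdp")),        # MLP down-projection
    (".*", PS(None)),                                    # default: replicate
)
```

Patterns of this kind are typically matched against the flattened parameter paths to build a pytree of `PartitionSpec`s with the same structure as the parameter tree; that per-parameter spec tree is what the sharding functions in the next section consume.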
How the parameters are sharded
Applying sharding constraints
While loading the model, we incrementally shard the model weights using dedicated sharding functions:
```python
from jax.sharding import NamedSharding

def make_shard_and_gather_fns(partition_specs):
    def make_shard_fn(partition_spec):
        # MESH is the device mesh created above
        out_sharding = NamedSharding(MESH, partition_spec)
        def shard_fn(tensor):
            # Place the tensor on the devices prescribed by its partition spec
            return jax.device_put(tensor, out_sharding).block_until_ready()
        return shard_fn
    # (the matching gather functions are omitted here for brevity)
    shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
    return shard_fns

# Create the sharding functions from the partitioning rules
shard_fns = make_shard_and_gather_fns(partitioning_rules)
```
This lets us place each parameter on the appropriate devices with the specified sharding.
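As a hedged sketch of how the shard functions might be applied while loading a checkpoint (the `load_checkpoint_params` helper is hypothetical, not from our repository):

```python
params = load_checkpoint_params()  # hypothetical loader: a pytree of host arrays

# Apply the matching shard_fn to every leaf: each weight is placed directly
# onto its devices with the prescribed sharding, one tensor at a time.
params = jax.tree_util.tree_map(lambda fn, p: fn(p), shard_fns, params)
```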
Sharding the training batch
First, the training batch is created in the usual way. Before passing it to the model, we shard it across the GPUs with the following code:
```python
# PS is jax.sharding.PartitionSpec
train_batch = jax.device_put(
    train_batch, NamedSharding(self.mesh, PS("dp", "fsdp"))
)
```
Here we specify that the batch should be sharded across the data parallel ("dp") and fully sharded data parallel ("fsdp") axes, which in our case have sizes 1 and 8; this produces a (1, 8) layout, reproduced in the sketch below.
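A minimal standalone sketch of that layout (outside the training class, using the `MESH` defined earlier and a dummy batch; the shapes are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as PS

# Dummy batch: (batch=16, seq_len=64) token ids, just for visualization.
dummy_batch = jnp.zeros((16, 64), dtype=jnp.int32)
sharded_batch = jax.device_put(dummy_batch, NamedSharding(MESH, PS("dp", "fsdp")))

# With the (1, 8, 1) mesh, the first dimension maps to "dp" (size 1) and the
# second to "fsdp" (size 8), i.e. the (1, 8) layout described above.
jax.debug.visualize_array_sharding(sharded_batch)
```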
Implementation of LoRA training
LoRA (Low-Rank Adaptation) reduces the number of trainable parameters by decomposing weight updates into low-rank matrices. This is especially useful for fine-tuning large models.
Key aspects of our LoRA implementation:

- Separate parameterization: we keep the LoRA parameters (lora_a and lora_b) separate from the main model parameters.
- Gradient stopping: we use jax.lax.stop_gradient(kernel) to prevent the main model weights from being updated.
- Efficient matrix multiplication: we use lax.dot_general for fast, precision-controlled matrix operations.
- Scaling factor: the LoRA output is scaled by (self.lora_alpha / self.lora_rank) before being added to the main output.
LoRADense layer
We implemented a custom `LoRADense` layer that includes the LoRA parameters:
```python
class LoRADense(nn.Module):
    features: int
    lora_rank: int = 8
    lora_alpha: float = 16.0

    @nn.compact
    def __call__(self, inputs: Any) -> Any:
        # Original kernel parameter (frozen)
        kernel = self.param('kernel', ...)
        y = lax.dot_general(inputs, jax.lax.stop_gradient(kernel), ...)

        # LoRA parameters (trainable)
        lora_a = self.variable('lora_params', 'lora_a', ..., ...)
        lora_b = self.variable('lora_params', 'lora_b', ..., ...)

        # Compute the LoRA output
        lora_output = lax.dot_general(inputs, lora_a.value, ...)
        lora_output = lax.dot_general(lora_output, lora_b.value, ...)

        # Combine the original output with the LoRA modification
        y += (self.lora_alpha / self.lora_rank) * lora_output
        return y.astype(self.dtype)
```
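For readers who want something they can run, here is a minimal self-contained sketch of the same idea. It assumes Flax linen, uses plain matmuls instead of lax.dot_general, and picks illustrative initializers; it is not the exact layer from our repository:

```python
from typing import Any
import jax
import jax.numpy as jnp
import flax.linen as nn


class LoRADenseSketch(nn.Module):
    features: int
    lora_rank: int = 8
    lora_alpha: float = 16.0
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, inputs):
        in_features = inputs.shape[-1]

        # Frozen base kernel: stop_gradient blocks updates to the main weights.
        kernel = self.param('kernel', nn.initializers.lecun_normal(),
                            (in_features, self.features))
        y = inputs @ jax.lax.stop_gradient(kernel)

        # Trainable LoRA factors, kept in a separate 'lora_params' collection.
        lora_a = self.variable(
            'lora_params', 'lora_a',
            lambda: 0.02 * jax.random.normal(self.make_rng('params'),
                                             (in_features, self.lora_rank)))
        lora_b = self.variable(
            'lora_params', 'lora_b',
            lambda: jnp.zeros((self.lora_rank, self.features)))

        # Low-rank update (x @ A) @ B, scaled by alpha / rank.
        lora_output = (inputs @ lora_a.value) @ lora_b.value
        y = y + (self.lora_alpha / self.lora_rank) * lora_output
        return y.astype(self.dtype)


# Usage: 'params' holds the frozen kernel, 'lora_params' the trainable factors.
layer = LoRADenseSketch(features=16)
x = jnp.ones((2, 32))
variables = layer.init(jax.random.PRNGKey(0), x)
out = layer.apply(variables, x)
```

Keeping the LoRA factors in their own `lora_params` collection is what lets us later take gradients with respect to that collection alone.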
LoRA parameter sharding
To distribute the LoRA parameters efficiently across devices, we applied dedicated sharding rules in JAX. This keeps the LoRA parameters aligned with the sharding of the main model parameters, while optimizing memory usage and computational efficiency.
LoRA A matrices (lora_a)
- Partition spec used: `PS("fsdp", "mp")`.
- Visualization and axis sharding: the lora_a parameters are sharded as (8, 1), i.e. the first axis is split across 8 devices (the `fsdp` axis) and the second axis is not split, which is exactly what the visualization shows.
LoRA B matrices (lora_b)
- Partition spec used: `PS("mp", "fsdp")`.
- Visualization and axis sharding: the lora_b parameters are sharded as (1, 8), i.e. the second axis is split across 8 devices (the `fsdp` axis), dividing the matrix columns, while the first axis is not split. A short sketch reproducing both layouts follows below.
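As a hedged illustration (using the `MESH` defined earlier and small dummy matrices rather than real LoRA weights), the two partition specs produce the (8, 1) and (1, 8) layouts described above:

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as PS

# Dummy stand-ins for one layer's LoRA factors (in_dim=4096, rank=8, out_dim=4096).
lora_a = jnp.zeros((4096, 8), dtype=jnp.bfloat16)
lora_b = jnp.zeros((8, 4096), dtype=jnp.bfloat16)

# lora_a: rows split across the 8 "fsdp" devices -> (8, 1) layout.
lora_a = jax.device_put(lora_a, NamedSharding(MESH, PS("fsdp", "mp")))
# lora_b: columns split across the 8 "fsdp" devices -> (1, 8) layout.
lora_b = jax.device_put(lora_b, NamedSharding(MESH, PS("mp", "fsdp")))

jax.debug.visualize_array_sharding(lora_a)
jax.debug.visualize_array_sharding(lora_b)
```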
This sharding strategy optimizes the distribution of parameters, reduces communication overhead, and increases parallelism during training. It ensures that each device holds only a fraction of the LoRA parameters, enabling efficient scaling to models as large as LLaMA 405B.
Updating LoRA parameters only
To optimize training during fine-tuning of the LLaMA 405B model, we compute gradients only for the LoRA parameters and keep the main model parameters frozen. This approach reduces memory usage and speeds up training because we update far fewer parameters. Implementation details can be found in our GitHub repository.
In our training loop, each step runs a batch of inputs through the model. Since only the LoRA parameters are trainable, the model's predictions and the computed loss depend only on them. We then backpropagate gradients with respect to the LoRA parameters. Restricting updates to these parameters simplifies the training process and makes multi-GPU fine-tuning of extremely large models like LLaMA 405B practical. A minimal sketch of such a training step is shown below.
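The sketch assumes hypothetical names (`model` for the Flax module, `params` for the frozen base weights, `lora_params` for the trainable factors, and a `batch` dict with `input_ids` and `labels`) and uses optax; it illustrates the LoRA-only update, not our exact training step:

```python
import jax
import optax

optimizer = optax.adamw(learning_rate=1e-4)
opt_state = optimizer.init(lora_params)

# (For the 405B run we stayed in eager mode, since the jitted graph did not fit
# in memory; the jit decorator is shown here for the general case.)
@jax.jit
def train_step(params, lora_params, opt_state, batch):
    def loss_fn(lora_params):
        logits = model.apply(
            {'params': params, 'lora_params': lora_params},
            batch['input_ids'])
        # Next-token cross-entropy; the batch keys are hypothetical.
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch['labels']).mean()

    # Differentiate w.r.t. the LoRA parameters only; the frozen base weights
    # are passed in but never appear as a differentiable argument.
    loss, grads = jax.value_and_grad(loss_fn)(lora_params)
    updates, opt_state = optimizer.update(grads, opt_state, lora_params)
    lora_params = optax.apply_updates(lora_params, updates)
    return lora_params, opt_state, loss
```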
Conclusion
Fine-tuning the huge LLaMA 3.1 405B model on AMD GPUs with JAX left an extremely positive impression on us. Using JAX's powerful parallelism capabilities and hardware-agnostic approach, I was able to distribute the model efficiently across 8 AMD MI300x GPUs. Parameter sharding let us manage the model's enormous number of parameters across devices, giving us nearly linear scaling and high memory efficiency.
This experience highlights the capability of AMD GPUs as a powerful alternative to NVIDIA hardware for large-scale AI training. The seamless integration of JAX with ROCm eases the transition and opens up new opportunities for the AI research and development community. By sharing my experience and code, I hope to motivate others to explore and apply these tools in their own large-scale machine learning projects.