Jax on Juwels Booster

This post illustrates a possible way to set up multinode Jax computations on the Juwels Booster partition at the Jülich Supercomputing Centre.

The following text adapts the instructions from the official Jax documentation to the Jülich system. Let's start with the Python code. The snippet below determines how many GPUs are visible on each host, tells Jax to run across multiple nodes in parallel, and finally computes a sum that uses every device in the cluster.

# test.py file content.
# The following is run in parallel on each host.
import os
import jax

node_id = os.environ['SLURM_NODEID']
visible_devices = [int(gpu) for gpu in os.environ['CUDA_VISIBLE_DEVICES'].split(',')]

def print_on_node_0(s):
    if node_id == '0':
        print(s)

print(f"Process now in Python.")
print(f"Nodeid: {node_id}")
print(f"Host: {os.environ['HOSTNAME']}")
print(f"cuda visible devices: {visible_devices}")

# booster nodes use 4 GPUs per machine.
jax.distributed.initialize(local_device_ids=visible_devices)
# total number of accelerator devices in the cluster
print_on_node_0(f'total device count: {jax.device_count()}') 
# number of accelerator devices attached to this host
print_on_node_0(f'local device count: {jax.local_device_count()}')
print_on_node_0("Device list:")
print_on_node_0(jax.devices())

# The psum is performed over all mapped devices across all hosts.
print_on_node_0('Computing pmaps over all devices')
xs = jax.numpy.ones(jax.local_device_count())
print_on_node_0('Pmap result')
pres = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
print_on_node_0(pres)

# we are done.
jax.distributed.shutdown()

The code snippet above runs on every node. To avoid duplicated log entries, only node zero gets to print.
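
If you would rather not read the SLURM variable, Jax also exposes the process index directly. Below is a minimal alternative sketch; note that jax.process_index() is only meaningful for multi-process runs once jax.distributed.initialize() has been called.

# Alternative to the SLURM_NODEID check; the index is only meaningful
# after jax.distributed.initialize() has set up the process group.
def print_on_process_0(s):
    if jax.process_index() == 0:
        print(s)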

To launch the Python code, we need a batch script that can be submitted with sbatch.

#!/bin/bash
#
#SBATCH -A TODO:enter-your-project-here
#SBATCH --nodes=3
#SBATCH --job-name=test_multi_node
#SBATCH --output=test_multi_node-%j.out
#SBATCH --error=test_multi_node-%j.err
#SBATCH --time=00:20:00
#SBATCH --gres gpu:4
#SBATCH --partition develbooster

echo "Got nodes:"
echo $SLURM_JOB_NODELIST
echo "Jobs per node:"
echo $SLURM_JOB_NUM_NODES 

module load Python
ml CUDA/.12.0

export LD_LIBRARY_PATH=/p/home/jusers/wolter1/juwels/project_drive/cudnn-linux-x86_64-8.9.1.23_cuda12-archive/lib:$LD_LIBRARY_PATH

source /p/home/jusers/wolter1/juwels/project_drive/jax_env/bin/activate

srun --nodes=3 --gres gpu:4 python test.py

To make this job script work for you, adjust all paths so that they point to the locations where you stored the required libraries. When I wrote this post, cuDNN was not yet available as a module for CUDA 12. I expect the situation to change in the near future; when it does, the export line can be replaced with a simple module load, as sketched below.
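
Once a cuDNN build for CUDA 12 shows up in the module system, the job script could load it directly instead of exporting LD_LIBRARY_PATH. The module name and version below are placeholders; check what is actually installed first, for example with module spider cuDNN:

# Hypothetical replacement for the LD_LIBRARY_PATH export, assuming a
# cuDNN module built against CUDA 12 exists on the system.
module load Python
ml CUDA/.12.0
ml cuDNN   # placeholder; pick the version built against CUDA 12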

The most important line is the one starting with source: it activates a virtual environment that contains a Jax installation. See the Python venv documentation for more information. After setting up the environment, install Jax against the locally installed CUDA, as described in the Jax project README.
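
For completeness, here is a rough sketch of that setup. The environment path is only an example, and the exact pip command should be taken from the Jax README, since the recommended wheels change over time:

# Create and activate a virtual environment (the path is an example).
module load Python
python -m venv /p/project/<your-project>/jax_env
source /p/project/<your-project>/jax_env/bin/activate

# Install Jax against the locally installed CUDA and cuDNN; check the
# Jax README for the currently recommended install command.
pip install --upgrade pip
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html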

Finally, submitting the batch script with sbatch produces the following output:

Got nodes:
jwb[0097,0117,0129]
Number of nodes:
3
Process now in Python.
Nodeid: 2
Host: jwb0129.juwels
cuda visible devices: [0, 1, 2, 3]
Process now in Python.
Nodeid: 1
Host: jwb0117.juwels
cuda visible devices: [0, 1, 2, 3]
Process now in Python.
Nodeid: 0
Host: jwb0097.juwels
cuda visible devices: [0, 1, 2, 3]
total device count: 12
local device count: 4
Device list:
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=4, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=5, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=6, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=7, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=8, process_index=2, slice_index=2), StreamExecutorGpuDevice(id=9, process_index=2, slice_index=2), StreamExecutorGpuDevice(id=10, process_index=2, slice_index=2), StreamExecutorGpuDevice(id=11, process_index=2, slice_index=2)]
Computing pmaps over all devices
Pmap result
[12. 12. 12. 12.]

Each device holds a one, and the psum adds these values across all twelve devices, so every entry of the result is twelve. This confirms that all twelve GPUs across the three nodes are recognized and working together. I would like to thank Stefan Kesselheim for helping me make this work!