{"id":577,"date":"2023-06-21T12:25:56","date_gmt":"2023-06-21T10:25:56","guid":{"rendered":"https:\/\/www.wolter.tech\/?p=577"},"modified":"2023-06-21T12:25:56","modified_gmt":"2023-06-21T10:25:56","slug":"jax-on-juwels-booster","status":"publish","type":"post","link":"https:\/\/www.wolter.tech\/?p=577","title":{"rendered":"Jax on Juwels Booster"},"content":{"rendered":"\n<p>This post illustrates one way to set up multi-node <a rel=\"noreferrer noopener\" href=\"https:\/\/jax.readthedocs.io\/en\/latest\/index.html\" target=\"_blank\">Jax<\/a> computations on the <a rel=\"noreferrer noopener\" href=\"https:\/\/apps.fz-juelich.de\/jsc\/hps\/juwels\/booster-overview.html\" target=\"_blank\">Juwels Booster<\/a> partition at the J\u00fclich Supercomputing Centre.<\/p>\n\n\n\n<p>The following text adapts the instructions from the <a rel=\"noreferrer noopener\" href=\"https:\/\/jax.readthedocs.io\/en\/latest\/multi_process.html\" target=\"_blank\">official documentation<\/a> to the system in J\u00fclich. Let&#8217;s start with the Python code. The code snippet below determines which GPUs each node can use and tells Jax to run on multiple nodes in parallel. 
Finally, a sum is computed using every device in the cluster.<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code># test.py file content.\n# The following is run in parallel on each host.\nimport os\nimport jax\n\nnode_id = os.environ&#91;'SLURM_NODEID']\nvisible_devices = &#91;int(gpu) for gpu in os.environ&#91;'CUDA_VISIBLE_DEVICES'].split(',')]\n\ndef print_on_node_0(s):\n    if node_id == '0':\n        print(s)\n\nprint(\"Process now in Python.\")\nprint(f\"Nodeid: {node_id}\")\nprint(f\"Host: {os.environ&#91;'HOSTNAME']}\")\nprint(f\"cuda visible devices: {visible_devices}\")\n\n# Booster nodes have 4 GPUs per machine. Under SLURM, initialize() detects\n# the coordinator address, process count and process id on its own;\n# local_device_ids restricts this process to the visible GPUs.\njax.distributed.initialize(local_device_ids=visible_devices)\n# total number of accelerator devices in the cluster\nprint_on_node_0(f'total device count: {jax.device_count()}')\n# number of accelerator devices attached to this host\nprint_on_node_0(f'local device count: {jax.local_device_count()}')\nprint_on_node_0(\"Device list:\")\nprint_on_node_0(jax.devices())\n\n# psum adds up the values along the mapped axis 'i', which spans every\n# device of every process. Each device contributes a one, so every entry\n# of the result should equal jax.device_count().\nprint_on_node_0('Computing pmaps over all devices')\nxs = jax.numpy.ones(jax.local_device_count())\nprint_on_node_0('Pmap result')\npres = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)\nprint_on_node_0(pres)\n\n# we are done.\njax.distributed.shutdown()<\/code><\/pre>\n\n\n\n<p>The code snippet above runs on all nodes. 
To avoid duplicate log entries, only the process on node zero prints.<\/p>\n\n\n\n<p>To launch the Python code, we need a batch script for <a rel=\"noreferrer noopener\" href=\"https:\/\/slurm.schedmd.com\/sbatch.html\" target=\"_blank\">sbatch<\/a>.<\/p>\n\n\n\n<pre class=\"wp-block-code has-small-font-size\"><code>#!\/bin\/bash\n#\n#SBATCH -A TODO:enter-your-project-here\n#SBATCH --nodes=3\n#SBATCH --job-name=test_multi_node\n#SBATCH --output=test_multi_node-%j.out\n#SBATCH --error=test_multi_node-%j.err\n#SBATCH --time=00:20:00\n#SBATCH --gres gpu:4\n#SBATCH --partition develbooster\n\necho \"Got nodes:\"\necho $SLURM_JOB_NODELIST\necho \"Jobs per node:\"\necho $SLURM_JOB_NUM_NODES \n\nmodule load Python\nml CUDA\/.12.0\n\nexport LD_LIBRARY_PATH=\/p\/home\/jusers\/wolter1\/juwels\/project_drive\/cudnn-linux-x86_64-8.9.1.23_cuda12-archive\/lib:$LD_LIBRARY_PATH\n\nsource \/p\/home\/jusers\/wolter1\/juwels\/project_drive\/jax_env\/bin\/activate\n\nsrun --nodes=3 --gres gpu:4 python test.py\n<\/code><\/pre>\n\n\n\n<p>For this run file to work for you, all paths must point to the locations where you stored the required libraries. When I wrote this post, cuDNN was not yet available as a module for CUDA 12. I expect this to change in the near future. Once it does, the <code>export<\/code> line can be replaced with a simple <code>module load<\/code> command.<\/p>\n\n\n\n<p>The most important line is the one starting with <code>source<\/code>. It activates a virtual environment that contains the Jax installation. See the <a rel=\"noreferrer noopener\" href=\"https:\/\/docs.python.org\/3\/library\/venv.html\" target=\"_blank\">Python docs<\/a> for more information. 
After setting up the environment, install Jax for a locally installed CUDA, as described in the <a href=\"https:\/\/github.com\/google\/jax\/tree\/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff#pip-installation-gpu-cuda-installed-locally-harder\" target=\"_blank\" rel=\"noreferrer noopener\">project readme<\/a>.<\/p>\n\n\n\n<p>Finally, submitting the run file with <code>sbatch<\/code> produces the following output:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>Got nodes:\njwb&#91;0097,0117,0129]\nJobs per node:\n3\nProcess now in Python.\nNodeid: 2\nHost: jwb0129.juwels\ncuda visible devices: &#91;0, 1, 2, 3]\nProcess now in Python.\nNodeid: 1\nHost: jwb0117.juwels\ncuda visible devices: &#91;0, 1, 2, 3]\nProcess now in Python.\nNodeid: 0\nHost: jwb0097.juwels\ncuda visible devices: &#91;0, 1, 2, 3]\ntotal device count: 12\nlocal device count: 4\nDevice list:\n&#91;StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=4, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=5, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=6, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=7, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=8, process_index=2, slice_index=2), StreamExecutorGpuDevice(id=9, process_index=2, slice_index=2), StreamExecutorGpuDevice(id=10, process_index=2, slice_index=2), StreamExecutorGpuDevice(id=11, process_index=2, slice_index=2)]\nComputing pmaps over all devices\nPmap result\n&#91;12. 12. 12. 12.]<\/code><\/pre>\n\n\n\n<p>This output suggests that all twelve GPUs are recognized correctly: Jax reports a total device count of 12, and every entry of the pmap result equals 12. 
I would like to thank Stefan Kesselheim for helping me make this work!<\/p>\n","protected":false},"excerpt":{"rendered":"<p>This post illustrates a possible way to set up multinode Jax computations on the Juwels Booster partition at the J\u00fclich Supercomputing Centre. The following text adapts instructions from official documentation to run in J\u00fclich. Let&#8217;s start with the Python code. The code snippet below determines how many GPUs we have and tells Jax to run &hellip; <\/p>\n<p class=\"link-more\"><a href=\"https:\/\/www.wolter.tech\/?p=577\" class=\"more-link\">Continue reading<span class=\"screen-reader-text\"> &#8220;Jax on Juwels Booster&#8221;<\/span><\/a><\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"closed","ping_status":"","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[5],"tags":[],"class_list":["post-577","post","type-post","status-publish","format-standard","hentry","category-research-projects","entry"],"_links":{"self":[{"href":"https:\/\/www.wolter.tech\/index.php?rest_route=\/wp\/v2\/posts\/577","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.wolter.tech\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.wolter.tech\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.wolter.tech\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.wolter.tech\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=577"}],"version-history":[{"count":3,"href":"https:\/\/www.wolter.tech\/index.php?rest_route=\/wp\/v2\/posts\/577\/revisions"}],"predecessor-version":[{"id":580,"href":"https:\/\/www.wolter.tech\/index.php?rest_route=\/wp\/v2\/posts\/577\/revisions\/580"}],"wp:attachment":[{"href":"https:\/\/www.wolter.tech\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=577"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.wolter.tech\/index.
php?rest_route=%2Fwp%2Fv2%2Fcategories&post=577"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.wolter.tech\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=577"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}