diff --git a/gpu_multi_process_run.sh b/gpu_multi_process_run.sh
new file mode 100644
index 0000000..efe6b9c
--- /dev/null
+++ b/gpu_multi_process_run.sh
@@ -0,0 +1,194 @@
+#! /bin/bash
+#
+# Runs ${COMMAND} in the background on one node of a multi-node JAX GPU job,
+# then polls it and exits non-zero if it fails.
+#
+# Required environment: NNODES, NODE_RANK, JAX_COORDINATOR_PORT,
+#   JAX_COORDINATOR_ADDRESS, GPUS_PER_NODE, COMMAND.
+# Optional environment: USE_GPUDIRECT ("tcpx" or "fastrak"); when it is unset
+#   or holds any other value, no GPUDirect/NCCL tuning is applied.
+set -e
+set -u
+set -o pipefail
+
+: "${NNODES:?Must set NNODES}"
+: "${NODE_RANK:?Must set NODE_RANK}"
+: "${JAX_COORDINATOR_PORT:?Must set JAX_COORDINATOR_PORT}"
+: "${JAX_COORDINATOR_ADDRESS:?Must set JAX_COORDINATOR_ADDRESS}"
+: "${GPUS_PER_NODE:?Must set GPUS_PER_NODE}"
+: "${COMMAND:?Must set COMMAND}"
+
+# Make the launch parameters visible to ${COMMAND} and its children.
+export GPUS_PER_NODE
+export JAX_COORDINATOR_PORT
+export JAX_COORDINATOR_ADDRESS
+
+# Exports the NCCL/XLA environment for GPUDirect-TCPX or GPUDirect-TCPFasTrak.
+# Globals:
+#   USE_GPUDIRECT   (read, optional) "tcpx" or "fastrak" selects a profile;
+#                   anything else, or unset, leaves the environment untouched.
+#   LD_LIBRARY_PATH (read, exported) extended with CUDA compat and TCPX dirs.
+# Outputs: one line on stdout naming the selected mode.
+set_nccl_gpudirect_tcpx_specific_configuration() {
+  # USE_GPUDIRECT is optional; default it so "set -u" cannot abort the script.
+  local gpudirect_mode=${USE_GPUDIRECT:-}
+  if [[ "$gpudirect_mode" == "tcpx" || "$gpudirect_mode" == "fastrak" ]]; then
+    export CUDA_DEVICE_MAX_CONNECTIONS=1
+    export NCCL_CROSS_NIC=0
+    export NCCL_DEBUG=INFO
+    export NCCL_DYNAMIC_CHUNK_SIZE=524288
+    export NCCL_NET_GDR_LEVEL=PIX
+    export NCCL_NVLS_ENABLE=0
+    export NCCL_P2P_NET_CHUNKSIZE=524288
+    export NCCL_P2P_NVL_CHUNKSIZE=1048576
+    export NCCL_P2P_PCI_CHUNKSIZE=524288
+    export NCCL_PROTO=Simple
+    export NCCL_SOCKET_IFNAME=eth0
+    export NVTE_FUSED_ATTN=1
+    export TF_CPP_MAX_LOG_LEVEL=100
+    export TF_CPP_VMODULE=profile_guided_latency_estimator=10
+    export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
+
+    # Search order: CUDA compat dirs, the inherited path, then the TCPX libs.
+    # Empty components are never emitted because ld.so reads an empty entry
+    # as the current working directory.
+    local library_path=""
+    local compat_dir
+    shopt -s nullglob # an unmatched glob must contribute no entry
+    for compat_dir in /usr/local/cuda-*/compat; do
+      library_path+="${compat_dir}:"
+    done
+    shopt -u nullglob
+    library_path+="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/usr/local/tcpx/lib64"
+    export LD_LIBRARY_PATH="$library_path"
+
+    if [[ "$gpudirect_mode" == "tcpx" ]]; then
+      echo "Using GPUDirect-TCPX"
+      export NCCL_ALGO=Ring
+      export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION
+      export NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0
+      export NCCL_GPUDIRECTTCPX_FORCE_ACK=0
+      export NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000
+      export NCCL_GPUDIRECTTCPX_RX_BINDINGS="eth1:22-35,124-139;eth2:22-35,124-139;eth3:74-87,178-191;eth4:74-87,178-191"
+      export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4
+      export NCCL_GPUDIRECTTCPX_TX_BINDINGS="eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177"
+      export NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000
+      export NCCL_MAX_NCHANNELS=12
+      export NCCL_MIN_NCHANNELS=12
+      export NCCL_NSOCKS_PERTHREAD=4
+      export NCCL_P2P_PXN_LEVEL=0
+      export NCCL_SOCKET_NTHREADS=1
+    elif [[ "$gpudirect_mode" == "fastrak" ]]; then
+      echo "Using GPUDirect-TCPFasTrak"
+      export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+      export NCCL_ALGO=Ring,Tree
+      export NCCL_BUFFSIZE=8388608
+      export NCCL_FASTRAK_CTRL_DEV=eth0
+      export NCCL_FASTRAK_ENABLE_CONTROL_CHANNEL=0
+      export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0
+      export NCCL_FASTRAK_IFNAME=eth1,eth2,eth3,eth4,eth5,eth6,eth7,eth8
+      export NCCL_FASTRAK_NUM_FLOWS=2
+      export NCCL_FASTRAK_USE_LLCM=1
+      export NCCL_FASTRAK_USE_SNAP=1
+      export NCCL_MIN_NCHANNELS=4
+      export NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/usr/local/nvidia/lib64/a3plus_guest_config.textproto
+      export NCCL_TUNER_CONFIG_PATH=/usr/local/nvidia/lib64/a3plus_tuner_config.textproto
+      export NCCL_TUNER_PLUGIN=libnccl-tuner.so
+    fi
+  else
+    echo "NOT using GPUDirect"
+  fi
+}
+
+echo "LD_LIBRARY_PATH ${LD_LIBRARY_PATH:-}"
+
+set_nccl_gpudirect_tcpx_specific_configuration
+
+# Polls the given PIDs every 5 seconds until every one has exited with 0.
+# Exits the whole script with the first non-zero exit code it observes.
+# Arguments: the PIDs to watch.
+wait_all_success_or_exit() {
+  # https://www.baeldung.com/linux/background-process-get-exit-code
+  local pids=("$@")
+  local pid code all_success
+  while [[ ${#pids[@]} -ne 0 ]]; do
+    all_success="true"
+    for pid in "${pids[@]}"; do
+      code=$(non_blocking_wait "$pid")
+      # 127 is the marker non_blocking_wait prints for a still-running PID.
+      if [[ $code -ne 127 ]]; then
+        if [[ $code -ne 0 ]]; then
+          echo "PID $pid failed with exit code $code"
+          exit "$code"
+        fi
+      else
+        all_success="false"
+      fi
+    done
+    if [[ $all_success == "true" ]]; then
+      echo "All pids succeeded"
+      break
+    fi
+    sleep 5
+  done
+}
+
+# Prints the exit code of PID $1 if it has finished, or 127 if /proc/$1 still
+# exists (i.e. the process has not gone away yet).
+non_blocking_wait() {
+  # https://www.baeldung.com/linux/background-process-get-exit-code
+  local pid=$1
+  local code=127 # special code to indicate not-finished
+  if [[ ! -d "/proc/$pid" ]]; then
+    wait "$pid"
+    code=$?
+  fi
+  echo "$code"
+}
+
+# Looks up JAX_COORDINATOR_ADDRESS with nslookup, retrying once per second for
+# up to 500 attempts, and exports the first address found as
+# JAX_COORDINATOR_IP.
+# Returns: 0 once an address is found, 1 when every attempt failed.
+resolve_coordinator_ip() {
+  local lookup_attempt=1
+  local max_coordinator_lookups=500
+  local coordinator_found=false
+  local coordinator_ip_address=""
+
+  echo "Coordinator Address $JAX_COORDINATOR_ADDRESS"
+
+  while [[ "$coordinator_found" = false && $lookup_attempt -le $max_coordinator_lookups ]]; do
+    coordinator_ip_address=$(nslookup "$JAX_COORDINATOR_ADDRESS" 2>/dev/null | awk '/^Address: / { print $2 }' | head -n 1)
+    if [[ -n "$coordinator_ip_address" ]]; then
+      coordinator_found=true
+      echo "Coordinator IP address: $coordinator_ip_address"
+      export JAX_COORDINATOR_IP=$coordinator_ip_address
+      return 0
+    else
+      echo "Failed to recognize coordinator address $JAX_COORDINATOR_ADDRESS on attempt $lookup_attempt, retrying..."
+      lookup_attempt=$((lookup_attempt + 1))
+      sleep 1
+    fi
+  done
+
+  if [[ "$coordinator_found" = false ]]; then
+    echo "Failed to resolve coordinator address after $max_coordinator_lookups attempts."
+    return 1
+  fi
+}
+
+# Resolving coordinator IP. Best effort: a failed lookup is reported by the
+# function but does not abort the launch, so errexit is suspended around it.
+set +e
+resolve_coordinator_ip
+set -e
+
+PIDS=()
+# COMMAND holds a complete shell command line; quoting hands it to eval
+# verbatim instead of after word splitting and pathname expansion.
+eval "${COMMAND}" &
+PID=$!
+PIDS+=("$PID")
+
+wait_all_success_or_exit "${PIDS[@]}"
diff --git a/maxdiffusion_gpu_dependencies.Dockerfile b/maxdiffusion_gpu_dependencies.Dockerfile
index 3cd770c..854ea48 100644
--- a/maxdiffusion_gpu_dependencies.Dockerfile
+++ b/maxdiffusion_gpu_dependencies.Dockerfile
@@ -2,6 +2,8 @@
 ARG BASEIMAGE=ghcr.io/nvidia/jax:base
 FROM $BASEIMAGE
 
+ENV PYTHONPATH="${PYTHONPATH}:src"
+
 # Stopgaps measure to circumvent gpg key setup issue.
 RUN echo "deb [trusted=yes] https://developer.download.nvidia.com/devtools/repos/ubuntu2204/amd64/ /" > /etc/apt/sources.list.d/devtools-ubuntu2204-amd64.list
 
diff --git a/setup.py b/setup.py
index 114fbbd..d2b2643 100644
--- a/setup.py
+++ b/setup.py
@@ -140,7 +140,6 @@
 #
 # some of the values are versioned whereas others aren't.
 deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)}
-
 # since we save this data in src/maxdiffusion/dependency_versions_table.py it can be easily accessed from
 # anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with:
 #