From a5ee4cf3b66c97dd576d7c81d82b45ef8fb13209 Mon Sep 17 00:00:00 2001 From: _Sizhi Tan_ Date: Tue, 15 Oct 2024 00:02:15 +0000 Subject: [PATCH] unify working directory and add more gpu dependency files --- docker_build_dependency_image.sh | 5 +++++ docs/getting_started/run_maxdiffusion_via_xpk.md | 9 +++++++++ maxdiffusion_dependencies.Dockerfile | 2 +- maxdiffusion_jax_stable_stack_tpu.Dockerfile | 4 ++-- maxdiffusion_runner.Dockerfile | 4 ++-- 5 files changed, 19 insertions(+), 5 deletions(-) diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index 5660049..10ede04 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -61,6 +61,11 @@ COMMIT_HASH=$(git rev-parse --short HEAD) echo "Building MaxDiffusion with MODE=${MODE} at commit hash ${COMMIT_HASH} . . ." if [[ ${DEVICE} == "gpu" ]]; then + if [[ ${MODE} == "pinned" ]]; then + export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-05-07 + else + export BASEIMAGE=ghcr.io/nvidia/jax:base + fi docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . else if [[ "${MODE}" == "stable_stack" ]]; then diff --git a/docs/getting_started/run_maxdiffusion_via_xpk.md b/docs/getting_started/run_maxdiffusion_via_xpk.md index 87d43f2..9520dcd 100644 --- a/docs/getting_started/run_maxdiffusion_via_xpk.md +++ b/docs/getting_started/run_maxdiffusion_via_xpk.md @@ -80,6 +80,15 @@ after which log out and log back in to the machine. **Important Note:** The JAX Stable Stack is currently in the experimental phase. We encourage you to try it out and provide feedback. + + #### Run MaxDiffusion on GPU + Default device is TPU. To run MaxDiffusion on GPU, please explicitly specify GPU When building docker image. + + ```shell + # Default will pick base image. + bash docker_build_dependency_image.sh DEVICE=gpu + ``` + 3. After building the dependency image `maxdiffusion_base_image`, xpk can handle updates to the working directory when running `xpk workload create` and using `--base-docker-image`. See details on docker images in xpk here: https://github.com/google/xpk/blob/main/README.md#how-to-add-docker-images-to-a-xpk-workload diff --git a/maxdiffusion_dependencies.Dockerfile b/maxdiffusion_dependencies.Dockerfile index 778e9e4..c12707c 100644 --- a/maxdiffusion_dependencies.Dockerfile +++ b/maxdiffusion_dependencies.Dockerfile @@ -45,7 +45,7 @@ ARG JAX_VERSION ENV ENV_JAX_VERSION=$JAX_VERSION # Set the working directory in the container -WORKDIR /app +WORKDIR /deps # Copy all files from local workspace into docker container COPY . . diff --git a/maxdiffusion_jax_stable_stack_tpu.Dockerfile b/maxdiffusion_jax_stable_stack_tpu.Dockerfile index ff883b9..00bfcc2 100644 --- a/maxdiffusion_jax_stable_stack_tpu.Dockerfile +++ b/maxdiffusion_jax_stable_stack_tpu.Dockerfile @@ -7,10 +7,10 @@ ARG COMMIT_HASH ENV COMMIT_HASH=$COMMIT_HASH -RUN mkdir -p /app +RUN mkdir -p /deps # Set the working directory in the container -WORKDIR /app +WORKDIR /deps # Copy all files from local workspace into docker container COPY . . diff --git a/maxdiffusion_runner.Dockerfile b/maxdiffusion_runner.Dockerfile index 50a93e3..25bb66a 100644 --- a/maxdiffusion_runner.Dockerfile +++ b/maxdiffusion_runner.Dockerfile @@ -2,9 +2,9 @@ ARG BASEIMAGE=maxdiffusion_base_image FROM $BASEIMAGE # Set the working directory in the container -WORKDIR /app +WORKDIR /deps # Copy all files from local workspace into docker container COPY . . -WORKDIR /app \ No newline at end of file +WORKDIR /deps \ No newline at end of file