Skip to content

Commit

Permalink
Update XLA revision (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Nov 10, 2023
1 parent 9b22b7d commit af010a2
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ BUILD_MODE ?= opt # can also be dbg
BUILD_CACHE ?= $(TEMP)/xla_extension
OPENXLA_GIT_REPO ?= https://github.com/openxla/xla.git

OPENXLA_GIT_REV ?= b938cfdf2d4e9a5f69c494a316e92638c1a119ef
OPENXLA_GIT_REV ?= 771e38178340cbaaef8ff20f44da5407c15092cb

# Private configuration
BAZEL_FLAGS = --define "framework_shared_object=false" -c $(BUILD_MODE)
Expand Down
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ For GPU support, we primarily rely on CUDA, because of the popularity and availa
in the cloud. In case you use ROCm and it does not work, please open up an issue and
we will be happy to help.

In addition to building in a local environment, you can build the ROCm binary using
the Docker-based scripts in [`builds/`](./builds/). You may want to adjust the ROCm
version in `rocm.Dockerfile` accordingly.

When you encounter errors at runtime, you may want to set `ROCM_PATH=/opt/rocm-5.7.0`
and `LD_LIBRARY_PATH="/opt/rocm-5.7.0/lib"` (with your respective version). For further
issues, feel free to open an issue.

#### `XLA_BUILD`

Defaults to `false`. If `true` the binary is built locally, which may be intended
Expand Down
44 changes: 40 additions & 4 deletions builds/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ cd "$(dirname "$0")/.."
print_usage_and_exit() {
echo "Usage: $0 <variant>"
echo ""
echo "Compiles the project inside docker. Available variants: cpu, cuda118, cuda120."
echo "Compiles the project inside docker. Available variants: cpu, cuda118, cuda120, tpu, rocm."
exit 1
}

Expand All @@ -26,7 +26,23 @@ case "$1" in
--build-arg XLA_TARGET=cpu \
.

docker run --rm -v $(pwd)/builds/output/cpu/build:/build -v $(pwd)/builds/output/cpu/.cache:/root/.cache xla-cpu
docker run --rm \
-v $(pwd)/builds/output/cpu/build:/build \
-v $(pwd)/builds/output/cpu/.cache:/root/.cache \
$XLA_DOCKER_FLAGS \
xla-cpu
;;

"tpu")
docker build -t xla-tpu -f builds/cpu.Dockerfile \
--build-arg XLA_TARGET=tpu \
.

docker run --rm \
-v $(pwd)/builds/output/tpu/build:/build \
-v $(pwd)/builds/output/tpu/.cache:/root/.cache \
$XLA_DOCKER_FLAGS \
xla-tpu
;;

"cuda118")
Expand All @@ -36,7 +52,11 @@ case "$1" in
--build-arg XLA_TARGET=cuda118 \
.

docker run --rm -v $(pwd)/builds/output/cuda118/build:/build -v $(pwd)/builds/output/cuda118/.cache:/root/.cache xla-cuda118
docker run --rm \
-v $(pwd)/builds/output/cuda118/build:/build \
-v $(pwd)/builds/output/cuda118/.cache:/root/.cache \
$XLA_DOCKER_FLAGS \
xla-cuda118
;;

"cuda120")
Expand All @@ -46,7 +66,23 @@ case "$1" in
--build-arg XLA_TARGET=cuda120 \
.

docker run --rm -v $(pwd)/builds/output/cuda120/build:/build -v $(pwd)/builds/output/cuda120/.cache:/root/.cache xla-cuda120
docker run --rm \
-v $(pwd)/builds/output/cuda120/build:/build \
-v $(pwd)/builds/output/cuda120/.cache:/root/.cache \
$XLA_DOCKER_FLAGS \
xla-cuda120
;;

"rocm")
docker build -t xla-rocm -f builds/rocm.Dockerfile \
--build-arg XLA_TARGET=rocm \
.

docker run --rm \
-v $(pwd)/builds/output/rocm/build:/build \
-v $(pwd)/builds/output/rocm/.cache:/root/.cache \
$XLA_DOCKER_FLAGS \
xla-rocm
;;

*)
Expand Down
60 changes: 60 additions & 0 deletions builds/rocm.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
FROM hexpm/elixir:1.15.4-erlang-26.0.2-ubuntu-focal-20230126 AS elixir

FROM rocm/dev-ubuntu-20.04:5.7-complete

# Set the missing UTF-8 locale, otherwise Elixir warns
ENV LC_ALL C.UTF-8

# Make sure installing packages (like tzdata) doesn't prompt for configuration
ENV DEBIAN_FRONTEND noninteractive

# We need to install "add-apt-repository" first
RUN apt-get update && apt-get install -y software-properties-common && \
# Add repository with the latest git version
add-apt-repository ppa:git-core/ppa && \
# Install basic system dependencies
apt-get update && apt-get install -y ca-certificates curl git unzip wget

# Install Bazel using Bazelisk (works for both amd and arm)
RUN wget -O bazel "https://github.com/bazelbuild/bazelisk/releases/download/v1.18.0/bazelisk-linux-$(dpkg --print-architecture)" && \
chmod +x bazel && \
mv bazel /usr/local/bin/bazel

ENV USE_BAZEL_VERSION 6.1.2

# Install Python and the necessary global dependencies
RUN apt-get install -y python3 python3-pip && \
ln -s /usr/bin/python3 /usr/bin/python && \
python -m pip install --upgrade pip numpy

# Install Erlang and Elixir

# Erlang runtime dependencies, see https://github.com/hexpm/bob/blob/3b5721dccdfe9d59766f374e7b4fb7fb8a7c720e/priv/scripts/docker/erlang-ubuntu-focal.dockerfile#L41-L45
RUN apt-get install -y --no-install-recommends libodbc1 libssl1.1 libsctp1

# We copy the top-level directory first to preserve symlinks in /usr/local/bin
COPY --from=elixir /usr/local /usr/ELIXIR_LOCAL
RUN cp -r /usr/ELIXIR_LOCAL/lib/* /usr/local/lib && \
cp -r /usr/ELIXIR_LOCAL/bin/* /usr/local/bin && \
rm -rf /usr/ELIXIR_LOCAL

# ---

ENV ROCM_PATH "/opt/rocm-5.7.0"

# ---

ARG XLA_TARGET

ENV XLA_TARGET=${XLA_TARGET}
ENV XLA_CACHE_DIR=/build
ENV XLA_BUILD=true

COPY mix.exs mix.lock ./
RUN mix deps.get

COPY lib lib
COPY Makefile Makefile.win ./
COPY extension extension

CMD [ "mix", "compile" ]
65 changes: 54 additions & 11 deletions extension/BUILD
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
load("//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm",)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda",)
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm",)
load("@tsl//tsl:tsl.bzl", "if_with_tpu_support")
load("@tsl//tsl:tsl.bzl", "tsl_grpc_cc_dependencies",)
load("@tsl//tsl:tsl.bzl", "transitive_hdrs",)
load("@rules_pkg//pkg:tar.bzl", "pkg_tar")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_file")

package(default_visibility=["//visibility:private"])

# Static library which contains dependencies necessary for building on
# top of XLA
# Shared library which contains the subset of XLA required for EXLA
cc_binary(
name = "libxla_extension.so",
deps = [
"//xla:xla_proto_cc_impl",
"//xla:xla_data_proto_cc_impl",
"//xla/service:hlo_proto_cc_impl",
"//xla/service:memory_space_assignment_proto_cc_impl",
"//xla/service/memory_space_assignment:memory_space_assignment_proto_cc_impl",
"//xla/service:buffer_assignment_proto_cc_impl",
"//xla/service/gpu:backend_configs_cc_impl",
"//xla/service/gpu:hlo_op_profile_proto_cc_impl",
"//xla/stream_executor:dnn_proto_cc_impl",
"//xla/service/gpu/model:hlo_op_profile_proto_cc_impl",
"//xla/stream_executor:device_description_proto_cc_impl",
"//xla:autotune_results_proto_cc_impl",
"//xla/stream_executor:stream_executor_impl",
"//xla/stream_executor/gpu:gpu_init_impl",
"//xla/stream_executor/host:host_platform",
"//xla:literal",
"//xla:shape_util",
Expand All @@ -44,7 +49,6 @@ cc_binary(
"//xla/pjrt:pjrt_compiler",
"//xla/pjrt:tfrt_cpu_pjrt_client",
"//xla/pjrt:pjrt_c_api_client",
"//xla/pjrt:tpu_client",
"//xla/pjrt/distributed",
"//xla/pjrt/gpu:se_gpu_pjrt_client",
"//xla/pjrt/distributed:client",
Expand All @@ -61,8 +65,6 @@ cc_binary(
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:ReconcileUnrealizedCasts",
"@llvm-project//mlir:SparseTensorDialect",
"@tf_runtime//:core_runtime",
"@tf_runtime//:hostcontext",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:fingerprint",
"@tsl//tsl/platform:float8",
Expand Down Expand Up @@ -90,7 +92,21 @@ cc_binary(
"//xla/stream_executor:rocm_platform"
]),
copts = ["-fvisibility=default"],
linkopts = ["-shared"],
linkopts= select({
"@tsl//tsl:macos": [
# We set the install_name, such that the library is looked up
# in the RPATH at runtime, otherwise the install_name is an
# arbitrary path within bazel workspace
"-Wl,-install_name,@rpath/libxla_extension.so",
# We set RPATH to the same dir as libxla_extension.so, so that
# loading PjRt plugins in the same directory works out of the box
"-Wl,-rpath,@loader_path/",
],
"//conditions:default": [
"-Wl,-soname,libxla_extension.so",
"-Wl,-rpath='$$ORIGIN'",
],
}),
features = ["-use_header_modules"],
linkshared = 1,
)
Expand Down Expand Up @@ -144,6 +160,9 @@ genrule(
d="$${d/include\\/grpc/grpc}"
# Remap tfrt paths
d="$${d/include\\/tfrt/tfrt}"
# Remap ml_dtypes paths
d="$${d/_virtual_includes\\/int4\\/ml_dtypes/ml_dtypes}"
d="$${d/_virtual_includes\\/float8\\/ml_dtypes/ml_dtypes}"
mkdir -p "$@/$${d}"
cp "$${f}" "$@/$${d}/"
done
Expand All @@ -153,16 +172,40 @@ genrule(
""",
)

genrule(
name = "libtpu_whl",
outs = ["libtpu.whl"],
cmd = """
libtpu_version="0.1.dev20231102"
libtpu_storage_path="https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-$${libtpu_version}-py3-none-any.whl"
wget -O "$@" "$$libtpu_storage_path"
"""
)

genrule(
name = "libtpu_so",
srcs = [
":libtpu_whl"
],
outs = ["libtpu.so"],
cmd = """
unzip -p "$(SRCS)" libtpu/libtpu.so > "$@"
"""
)

# This genrule remaps libxla_extension.so to lib/libxla_extension.so
genrule(
name = "xla_extension_lib",
srcs = [
":libxla_extension.so",
],
]
+ if_with_tpu_support([
":libtpu_so"
]),
outs = ["lib"],
cmd = """
mkdir $@
mv $(location :libxla_extension.so) $@
mv $(SRCS) $@
"""
)

Expand Down
7 changes: 6 additions & 1 deletion lib/xla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,12 @@ defmodule XLA do
]

"rocm" <> _ ->
["--config=rocm", "--action_env=HIP_PLATFORM=hcc"]
[
"--config=rocm",
"--action_env=HIP_PLATFORM=hcc",
# See https://github.com/google/jax/blob/c9cf6b44239e373cba384936dcfeff60e39ad560/.bazelrc#L80
~s/--action_env=TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"/
]

"tpu" <> _ ->
["--config=tpu"]
Expand Down

0 comments on commit af010a2

Please sign in to comment.