Skip to content

Commit

Permalink
Fix libhib discovery for Windows (#722)
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th authored Jan 11, 2025
1 parent c8a5309 commit 0659e31
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 107 deletions.
52 changes: 19 additions & 33 deletions src/discovery/discovery.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
module ROCmDiscovery

export lld_artifact, lld_path, libhsaruntime, libdevice_libs, libhip
export librocblas, librocsparse, librocsolver
export librocrand, librocfft, libMIOpen_path

using LLD_jll
using ROCmDeviceLibs_jll
using Preferences
using Libdl
Expand All @@ -12,18 +17,11 @@ function get_artifact_library(pkg::Symbol, libname::Symbol)::String
return res
end

function get_library(libname::String; rocm_path::String, ext::String = dlext)
find_rocm_library(libname, rocm_path, ext)
end

function get_ld_lld(rocm_path::String;
from_artifact::Bool, artifact_library::Symbol, artifact_field::Symbol,
)
if from_artifact
get_artifact_library(artifact_library, artifact_field)
else
find_ld_lld(rocm_path)
end
function get_ld_lld(rocm_path::String)::Tuple{String, Bool}
lld_path = find_ld_lld(rocm_path)
isempty(lld_path) || return (lld_path, false)
LLD_jll.is_available() || return (lld_path, false)
return (LLD_jll.lld_path, true)
end

function get_device_libs(from_artifact::Bool; rocm_path::String)
Expand All @@ -34,11 +32,6 @@ function get_device_libs(from_artifact::Bool; rocm_path::String)
end
end

export lld_artifact, lld_path, libhsaruntime, libdevice_libs, libhip
export librocblas, librocsparse, librocsolver
export librocrand, librocfft, libMIOpen_path
export julia_exeflags

function _hip_runtime_version()
v_ref = Ref{Cint}()
res = ccall((:hipRuntimeGetVersion, libhip), UInt32, (Ptr{Cint},), v_ref)
Expand Down Expand Up @@ -82,23 +75,16 @@ function __init__()

try
global libhsaruntime = Sys.islinux() ?
get_library("libhsa-runtime64"; rocm_path, ext="so.1") :
find_rocm_library("libhsa-runtime64"; rocm_path, ext="so.1") :
""

# Linker.
lld_path = get_ld_lld(rocm_path; from_artifact=false,
artifact_library=:LLD_jll, artifact_field=:lld_path)
lld_artifact = false
if isempty(lld_path)
lld_path = get_ld_lld(rocm_path; from_artifact=true,
artifact_library=:LLD_jll, artifact_field=:lld_path)
lld_artifact = true
end
lld_path, lld_artifact = get_ld_lld(rocm_path)
global lld_path = lld_path
global lld_artifact = lld_artifact

# HIP.
global libhip = get_library(Sys.islinux() ? "libamdhip64" : "amdhip64"; rocm_path)
global libhip = find_rocm_library(Sys.islinux() ? "libamdhip64" : "amdhip64"; rocm_path)

# Turn off artifacts for device libraries if:
# - Julia 1.12+;
Expand All @@ -111,12 +97,12 @@ function __init__()
global libdevice_libs = get_device_libs(from_artifact; rocm_path)

# HIP-based libraries.
global librocblas = get_library(lib_prefix * "rocblas"; rocm_path)
global librocsparse = get_library(lib_prefix * "rocsparse"; rocm_path)
global librocsolver = get_library(lib_prefix * "rocsolver"; rocm_path)
global librocrand = get_library(lib_prefix * "rocrand"; rocm_path)
global librocfft = get_library(lib_prefix * "rocfft"; rocm_path)
global libMIOpen_path = get_library(lib_prefix * "MIOpen"; rocm_path)
global librocblas = find_rocm_library(lib_prefix * "rocblas"; rocm_path)
global librocsparse = find_rocm_library(lib_prefix * "rocsparse"; rocm_path)
global librocsolver = find_rocm_library(lib_prefix * "rocsolver"; rocm_path)
global librocrand = find_rocm_library(lib_prefix * "rocrand"; rocm_path)
global librocfft = find_rocm_library(lib_prefix * "rocfft"; rocm_path)
global libMIOpen_path = find_rocm_library(lib_prefix * "MIOpen"; rocm_path)
catch err
@error """ROCm discovery failed!
Discovered ROCm path: $rocm_path.
Expand Down
69 changes: 8 additions & 61 deletions src/discovery/utils.jl
Original file line number Diff line number Diff line change
@@ -1,57 +1,3 @@
function detect_projects()
amdgpu_project = normpath(joinpath(@__DIR__, "..", ".."))
current_project = Base.ACTIVE_PROJECT[]
julia_project = if Base.JLOptions().project != C_NULL
unsafe_string(Base.JLOptions().project)
elseif current_project !== nothing
current_project
else
amdgpu_project
end
return (;amdgpu_project, current_project, julia_project)
end

julia_exeflags(projects = detect_projects()) = String[
"--startup-file=no",
"--project=$(projects.julia_project)",
"--threads=$(Threads.nthreads())"]

function julia_cmd_projects(jl_str)
projects = detect_projects()

cmd = Base.julia_cmd()
append!(cmd.exec, julia_exeflags(projects))

(;amdgpu_project, current_project, julia_project) = projects

if current_project !== nothing
current_project = replace(current_project, "\\" => "/")
jl_str = "push!(LOAD_PATH, \"$current_project\");" * jl_str
else
# If Julia is using global project, instantiate `julia_project`.
# Otherwise, we'll fail to discover artifacts.
jl_str = "import Pkg; Pkg.instantiate(;io=devnull); " * jl_str
end
amdgpu_project = replace(amdgpu_project, "\\" => "/")
jl_str = "push!(LOAD_PATH, \"$amdgpu_project\");" * jl_str
append!(cmd.exec, ("-e", jl_str))
return cmd
end

function safe_exec(str)
cmd = julia_cmd_projects(str)
success = false
res_str, err_str = mktemp() do path, _
p = run(pipeline(cmd; stdout=path, stderr="$path.err"); wait=false)
wait(p)
success = p.exitcode == 0
res_str = strip(String(read(path)))
err_str = strip(String(read("$path.err")))
res_str, err_str
end
return success, res_str
end

"""
Find root ROCm directory.
"""
Expand All @@ -78,7 +24,7 @@ function find_roc_path()::String
return ""
end

function find_device_libs(rocm_path::String)
function find_device_libs(rocm_path::String)::String
env_dir = get(ENV, "ROCM_PATH", "")
if isdir(env_dir)
path = joinpath(env_dir, "amdgcn", "bitcode")
Expand All @@ -97,26 +43,27 @@ function find_device_libs(rocm_path::String)
return ""
end

function find_rocm_library(libs::Vector, rocm_path::String, ext::String = dlext)
function find_rocm_library(libs::Vector; rocm_path::String, ext::String = dlext)::String
for lib in libs
path = find_rocm_library(lib, rocm_path, ext)
isempty(path) || return path
end
return ""
end

function find_rocm_library(lib::String, rocm_path::String, ext::String = dlext)
function find_rocm_library(lib::String; rocm_path::String, ext::String = dlext)::String
libdir = joinpath(rocm_path, Sys.islinux() ? "lib" : "bin")
isdir(libdir) || return ""

for file in readdir(libdir)
matched = startswith(basename(file), lib * ".$ext")
matched && return joinpath(libdir, file)
for file in readdir(libdir; join=true)
fname = basename(file)
matched = startswith(fname, lib) && endswith(fname, ext)
matched && return file
end
return ""
end

function find_ld_lld(rocm_path::String)
function find_ld_lld(rocm_path::String)::String
lld_name = "ld.lld" * (Sys.iswindows() ? ".exe" : "")
for subdir in (joinpath("llvm", "bin"), "bin")
exp_ld_path = joinpath(rocm_path, subdir, lld_name)
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/Runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Preferences: @load_preference, @set_preferences!
import ..HSA
import ..HIP
import ..AMDGPU
import ..AMDGPU: getinfo, LockedObject
import ..AMDGPU: LockedObject
import .HIP: HIPDevice

struct Adaptor end
Expand Down
12 changes: 0 additions & 12 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,6 @@ function print_build_diagnostics()
run(`id`)
end

function getinfo(::Type{String}, object, query)::String
value = Base.zeros(UInt8, 64)
getinfo(object, query, value) |> Runtime.check
return rstrip(String(value), '\0')
end

function getinfo(::Type{T}, object, query)::T where T
value = Ref{T}()
getinfo(object, query, value) |> Runtime.check
return value[]
end

function check end

macro check(f)
Expand Down

0 comments on commit 0659e31

Please sign in to comment.