Skip to content

Commit

Permalink
Merge branch 'master' into fs3
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Mar 1, 2024
2 parents 0a74a61 + d1dfce8 commit e3ba5ea
Showing 1 changed file with 48 additions and 33 deletions.
81 changes: 48 additions & 33 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,44 @@ def _check_rocm(self) -> None:
_INSTALL_FAILED = True


def _check_ld_config(lib: str) -> str:
""" Locate a library in ldconfig
Parameters
----------
lib: str The library to locate
Returns
-------
str
The library from ldconfig, or empty string if not found
"""
retval = ""
ldconfig = which("ldconfig")
if not ldconfig:
return retval

retval = next((line.decode("utf-8", errors="replace").strip()
for line in run([ldconfig, "-p"],
capture_output=True,
check=False).stdout.splitlines()
if lib.encode("utf-8") in line), "")

if retval or (not retval and not os.environ.get("LD_LIBRARY_PATH")):
return retval

for path in os.environ["LD_LIBRARY_PATH"].split(":"):
if not path:
continue

retval = next((fname.strip() for fname in reversed(os.listdir(path))
if lib in fname), "")
if retval:
break

return retval


class ROCmCheck(): # pylint:disable=too-few-public-methods
""" Find the location of system installed ROCm on Linux """
# TODO
Expand All @@ -694,16 +732,7 @@ def _rocm_check(self) -> None:
with ldconfig then attempt to find it in LD_LIBRARY_PATH. If found, set the
:attr:`rocm_version` to the discovered version
"""
ldconfig = os.popen("which ldconfig").read()
if not ldconfig:
return
chk = os.popen(f"{ldconfig} -p | grep -P \"librocm-core.so.\\d+\" | head -n 1").read()
if not chk and os.environ.get("LD_LIBRARY_PATH"):
for path in os.environ["LD_LIBRARY_PATH"].split(":"):
chk = os.popen(f"ls {path} | grep -P -o \"librocmcore.so.\\d+\" | "
"head -n 1").read()
if chk:
break
chk = _check_ld_config("librocm-core.so.")
if not chk:
return

Expand Down Expand Up @@ -750,8 +779,7 @@ def _cuda_check(self) -> None:
stdout.decode(locale.getpreferredencoding(), errors="ignore"))
if version is not None:
self.cuda_version = version.groupdict().get("cuda", None)
locate = "where" if self._os == "windows" else "which"
path = os.popen(f"{locate} nvcc").read()
path = which("nvcc")
if path:
path = path.split("\n")[0] # Split multiple entries and take first found
while True: # Get Cuda root folder
Expand All @@ -768,22 +796,15 @@ def _cuda_check(self) -> None:
def _cuda_check_linux(self) -> None:
""" For Linux check the dynamic link loader for libcudart. If not found with ldconfig then
attempt to find it in LD_LIBRARY_PATH. """
ldconfig = os.popen("which ldconfig").read()
if not ldconfig:
return
chk = os.popen(f"{ldconfig} -p | grep -P \"libcudart.so.\\d+.\\d+\" | head -n 1").read()
if not chk and os.environ.get("LD_LIBRARY_PATH"):
for path in os.environ["LD_LIBRARY_PATH"].split(":"):
chk = os.popen(f"ls {path} | grep -P -o \"libcudart.so.\\d+.\\d+\" | "
"head -n 1").read()
if chk:
break
chk = _check_ld_config("libcudart.so.")
if not chk: # Cuda not found
return

cudavers = chk.strip().replace("libcudart.so.", "")
self.cuda_version = cudavers[:cudavers.find(" ")]
self.cuda_path = chk[chk.find("=>") + 3:chk.find("targets") - 1]
self.cuda_version = cudavers[:cudavers.find(" ")] if " " in cudavers else cudavers
cuda_path = chk[chk.find("=>") + 3:chk.find("targets") - 1]
if os.path.exists(cuda_path):
self.cuda_path = cuda_path

def _cuda_check_windows(self) -> None:
""" Check Windows CUDA Version and path from Environment Variables"""
Expand Down Expand Up @@ -828,10 +849,7 @@ def _cudnn_check(self) -> None:
if self._os == "windows":
return

ldconfig = os.popen("which ldconfig").read()
if not ldconfig:
return
chk = os.popen(f"{ldconfig} -p | grep -P \"libcudnn.so.\" | head -n 1").read()
chk = _check_ld_config("libcudnn.so.")
if not chk:
return
cudnnvers = chk.strip().replace("libcudnn.so.", "").split()[0]
Expand All @@ -850,10 +868,7 @@ def _get_checkfiles_linux(self) -> list[str]:
list
List of header file locations to scan for cuDNN versions
"""
ldconfig = os.popen("which ldconfig").read()
if not ldconfig:
return []
chk = os.popen(f"{ldconfig} -p | grep -P \"libcudnn.so.\\d+\" | head -n 1").read()
chk = _check_ld_config("libcudnn.so.")
chk = chk.strip().replace("libcudnn.so.", "")
if not chk:
return []
Expand All @@ -876,7 +891,7 @@ def _get_checkfiles_windows(self) -> list[str]:
List of header file locations to scan for cuDNN versions
"""
# TODO A more reliable way of getting the windows location
if not self.cuda_path:
if not self.cuda_path or not os.path.exists(self.cuda_path):
return []
scandir = os.path.join(self.cuda_path, "include")
cudnn_checkfiles = [os.path.join(scandir, header) for header in self._cudnn_header_files]
Expand Down

0 comments on commit e3ba5ea

Please sign in to comment.