Skip to content

Commit

Permalink
setup.py:
Browse files Browse the repository at this point in the history
  - Don't try to call ldconfig when it doesn't exist
  - Adjust Cuda versions to TF compiled versions
  - Multiple package install handling
  • Loading branch information
torzdf committed Dec 17, 2023
1 parent a62a85c commit 9d231bc
Showing 1 changed file with 43 additions and 22 deletions.
65 changes: 43 additions & 22 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@
"all": ["imageio-ffmpeg"]} # 17/11/23 Conda forge uses incorrect ffmpeg, so fallback to pip
# Revisions of tensorflow GPU and cuda/cudnn requirements. These relate specifically to the
# Tensorflow builds available from pypi
_TENSORFLOW_REQUIREMENTS = {">=2.10.0,<2.11.0": [">=11.0,<12.0", ">=8.0,<9.0"]}
_TENSORFLOW_REQUIREMENTS = {">=2.10.0,<2.11.0": [">=11.2,<11.3", ">=8.1,<8.2"]}
# ROCm min/max version requirements for Tensorflow
_TENSORFLOW_ROCM_REQUIREMENTS = {">=2.10.0,<2.11.0": ((5, 2, 0), (5, 4, 0))}
# TODO tensorflow-metal versioning

# Mapping of Python packages to their conda names if different from pip or in non-default channel
_CONDA_MAPPING: dict[str, tuple[str, str]] = {
"cudatoolkit": ("cudatoolkit", "conda-forge"),
"cudnn": ("cudnn", "conda-forge"),
"fastcluster": ("fastcluster", "conda-forge"),
"ffmpy": ("ffmpy", "conda-forge"),
# "imageio-ffmpeg": ("imageio-ffmpeg", "conda-forge"),
Expand Down Expand Up @@ -298,13 +300,14 @@ class Packages():
"""
def __init__(self, environment: Environment) -> None:
self._env = environment
self._conda_required_packages: list[tuple[str, ...]] = [("tk", ), ("git", )]
self._conda_required_packages: list[tuple[list[str] | str, str]] = [("tk", "default"),
("git", "default")]
self._update_backend_specific_conda()
self._installed_packages = self._get_installed_packages()
self._conda_installed_packages = self._get_installed_conda_packages()
self._required_packages: list[tuple[str, list[tuple[str, str]]]] = []
self._missing_packages: list[tuple[str, list[tuple[str, str]]]] = []
self._conda_missing_packages: list[tuple[str, ...]] = []
self._conda_missing_packages: list[tuple[list[str] | str, str]] = []

@property
def prerequisites(self) -> list[tuple[str, list[tuple[str, str]]]]:
Expand Down Expand Up @@ -333,7 +336,7 @@ def to_install(self) -> list[tuple[str, list[tuple[str, str]]]]:
return self._missing_packages

@property
def to_install_conda(self) -> list[tuple[str, ...]]:
def to_install_conda(self) -> list[tuple[list[str] | str, str]]:
""" list: The required conda packages that need to be installed """
return self._conda_missing_packages

Expand All @@ -351,6 +354,8 @@ def _update_backend_specific_conda(self) -> None:
logger.debug("No backend packages to add for '%s'. All optional packages: %s",
self._env.backend, _BACKEND_SPECIFIC_CONDA)
return

combined_cuda = []
for pkg in to_add:
pkg, channel = _CONDA_MAPPING.get(pkg, (pkg, ""))
if pkg == "zlib-wapi" and self._env.os_version[0].lower() != "windows":
Expand All @@ -359,14 +364,18 @@ def _update_backend_specific_conda(self) -> None:
if pkg in ("cudatoolkit", "cudnn"): # TODO Handle multiple cuda/cudnn requirements
idx = 0 if pkg == "cudatoolkit" else 1
pkg = f"{pkg}{list(_TENSORFLOW_REQUIREMENTS.values())[0][idx]}"
if pkg.startswith("cudnn"):
# We add cudnn first so that dependency resolver does not need to re-download cuda
# if an incompatible version was installed
self._conda_required_packages.insert(0, (pkg, channel))
else:
self._conda_required_packages.append((pkg, channel))
logger.debug("Adding conda required package '%s' for backend '%s')",
pkg, self._env.backend)

combined_cuda.append(pkg)
continue

self._conda_required_packages.append((pkg, channel))
logger.info("Adding conda required package '%s' for backend '%s')",
pkg, self._env.backend)

if combined_cuda:
self._conda_required_packages.append((combined_cuda, channel))
logger.info("Adding conda required package '%s' for backend '%s')",
combined_cuda, self._env.backend)

@classmethod
def _format_requirements(cls, packages: list[str]
Expand Down Expand Up @@ -770,7 +779,10 @@ def _rocm_check(self) -> None:
with ldconfig then attempt to find it in LD_LIBRARY_PATH. If found, set the
:attr:`rocm_version` to the discovered version
"""
chk = os.popen("ldconfig -p | grep -P \"librocm-core.so.\\d+\" | head -n 1").read()
ldconfig = os.popen("which ldconfig").read()
if not ldconfig:
return
chk = os.popen(f"{ldconfig} -p | grep -P \"librocm-core.so.\\d+\" | head -n 1").read()
if not chk and os.environ.get("LD_LIBRARY_PATH"):
for path in os.environ["LD_LIBRARY_PATH"].split(":"):
chk = os.popen(f"ls {path} | grep -P -o \"librocmcore.so.\\d+\" | "
Expand Down Expand Up @@ -841,7 +853,10 @@ def _cuda_check(self) -> None:
def _cuda_check_linux(self) -> None:
""" For Linux check the dynamic link loader for libcudart. If not found with ldconfig then
attempt to find it in LD_LIBRARY_PATH. """
chk = os.popen("ldconfig -p | grep -P \"libcudart.so.\\d+.\\d+\" | head -n 1").read()
ldconfig = os.popen("which ldconfig").read()
if not ldconfig:
return
chk = os.popen(f"{ldconfig} -p | grep -P \"libcudart.so.\\d+.\\d+\" | head -n 1").read()
if not chk and os.environ.get("LD_LIBRARY_PATH"):
for path in os.environ["LD_LIBRARY_PATH"].split(":"):
chk = os.popen(f"ls {path} | grep -P -o \"libcudart.so.\\d+.\\d+\" | "
Expand Down Expand Up @@ -898,7 +913,10 @@ def _cudnn_check(self) -> None:
if self._os == "windows":
return

chk = os.popen("ldconfig -p | grep -P \"libcudnn.so.\" | head -n 1").read()
ldconfig = os.popen("which ldconfig").read()
if not ldconfig:
return
chk = os.popen(f"{ldconfig} -p | grep -P \"libcudnn.so.\" | head -n 1").read()
if not chk:
return
cudnnvers = chk.strip().replace("libcudnn.so.", "").split()[0]
Expand Down Expand Up @@ -1078,15 +1096,15 @@ def _install_missing_dep(self) -> None:
self._install_python_packages()

def _from_conda(self,
package: str,
package: list[str] | str,
channel: str = "",
conda_only: bool = False) -> bool:
""" Install a conda package
Parameters
----------
package: str
The full formatted package, with version, to be installed
package: list[str] | str
The full formatted package(s), with version(s), to be installed
channel: str, optional
The Conda channel to install from. Select empty string for default channel.
Default: ``""`` (empty string)
Expand All @@ -1104,11 +1122,14 @@ def _from_conda(self,
if channel:
condaexe.extend(["-c", channel])

if any(char in package for char in (" ", "<", ">", "*", "|")):
package = f"\"{package}\""
condaexe.append(package)
pkgs = package if isinstance(package, list) else [package]

for i, pkg in enumerate(pkgs):
if any(char in pkg for char in (" ", "<", ">", "*", "|")):
pkgs[i] = f"\"{pkg}\""
condaexe.extend(pkgs)

clean_pkg = package.replace("\"", "")
clean_pkg = " ".join([p.replace("\"", "") for p in pkgs])
installer = self._installer(self._env, clean_pkg, condaexe, self._is_gui)
retcode = installer()

Expand Down

0 comments on commit 9d231bc

Please sign in to comment.