Skip to content

Commit

Permalink
⚡ Don't install torch just to look for NodeBlockFunctions
Browse files Browse the repository at this point in the history
  • Loading branch information
shnizzedy committed Jan 9, 2025
1 parent 20ce428 commit 780600a
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions CPAC/pipeline/resource_inventory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
from itertools import chain
import os
from pathlib import Path
from typing import Any, cast, Iterable
from typing import Any, cast, Iterable, Optional
from unittest.mock import patch

from traits.trait_errors import TraitError
import yaml

from CPAC.pipeline.engine import template_dataframe
Expand All @@ -36,15 +38,22 @@
from CPAC.utils.outputs import Outputs


def import_nodeblock_functions(package_name: str) -> list[NodeBlockFunction]:
def import_nodeblock_functions(
package_name: str, exclude: Optional[list[str]] = None
) -> list[NodeBlockFunction]:
"""
Import all functions with the @nodeblock decorator from all modules and submodules in a package.
Parameters
----------
package_name
The name of the package to import from.
exclude
A list of module names to exclude from the import.
"""
if exclude is None:
exclude = []
functions: list[NodeBlockFunction] = []
package = importlib.import_module(package_name)
package_path = package.__path__[0] # Path to the package directory
Expand All @@ -55,11 +64,16 @@ def import_nodeblock_functions(package_name: str) -> list[NodeBlockFunction]:
# Get the module path
rel_path = os.path.relpath(os.path.join(root, file), package_path)
module_name = f"{package_name}.{rel_path[:-3].replace(os.sep, '.')}"
if module_name in exclude:
continue

# Import the module
try:
module = importlib.import_module(module_name)
except ImportError as e:
with patch.dict(
"sys.modules", {exclusion: None for exclusion in exclude}
):
module = importlib.import_module(module_name)
except (ImportError, TraitError, ValueError) as e:
UTLOGGER.debug(f"Failed to import {module_name}: {e}")
continue
# Extract nodeblock-decorated functions from the module
Expand Down Expand Up @@ -258,7 +272,14 @@ def resource_inventory(package: str = "CPAC") -> dict[str, ResourceIO]:
"""Gather all inputs and outputs for a list of NodeBlockFunctions."""
resources: dict[str, ResourceIO] = {}
# Node block function inputs and outputs
for nbf in import_nodeblock_functions(package):
for nbf in import_nodeblock_functions(
package,
[
# No nodeblock functions in these modules that dynamically isntall torch
"CPAC.unet.__init__",
"CPAC.unet._torch",
],
):
nbf_name = f"{nbf.__module__}.{nbf.__qualname__}"
if hasattr(nbf, "inputs"):
for nbf_input in _flatten_io(cast(list[Iterable], nbf.inputs)):
Expand Down

0 comments on commit 780600a

Please sign in to comment.