Skip to content

Commit

Permalink
more tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
rpetit3 committed Jan 15, 2025
1 parent ee0b93b commit 063971e
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 37 deletions.
91 changes: 91 additions & 0 deletions dragonflye/dependency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Optional
from dataclasses import dataclass
from shutil import which
import subprocess
import re

from dragonflye.exceptions import (
InvalidDependencyError,
MissingDependencyError,
InvalidDependencyVersionError,
)

from packaging.version import parse, InvalidVersion, Version


@dataclass
class Dependency:
name: str
citation: Optional[str] = None
version_pattern: str = r"(\d+\.\d+\.\d+)" # Regex pattern to extract version
version_arg: Optional[str] = "--version"
version: Optional[str] = None
min_version: Optional[str] = None
max_version: Optional[str] = None
less_then: Optional[str] = None

def check(self):
if not which(self.name):
raise MissingDependencyError(
f"Could not find dependency {self.name}! Please install it."
)
return self._base_validator()

def format_version_requirements(self):
requirements = []
if self.version:
requirements.append(f"={self.version}")
if self.min_version:
requirements.append(f">={self.min_version}")
if self.max_version:
requirements.append(f"<={self.max_version}")
if self.less_then:
requirements.append(f"<{self.less_then}")
if not requirements:
return self.name
return f'{self.name} {",".join(requirements)}'

def _base_validator(self) -> Version:
cmd = [self.name]
if self.version_arg:
cmd.append(self.version_arg)

result = subprocess.run(
cmd, stderr=subprocess.STDOUT, stdout=subprocess.PIPE, text=True
).stdout

# Apply regex pattern from version_extractor to extract version
match = re.search(self.version_pattern, result)
if not match:
raise InvalidDependencyVersionError(
f"Could not extract version from '{result}' for {self.name}."
)

version = match.group(0) # Get the matched version string

try:
parsed_version = parse(version)
except InvalidVersion:
raise InvalidDependencyVersionError(
f"Could not parse version '{version}' for {self.name}."
)

# Validate the extracted version against the given constraints
if self.version and parsed_version != parse(self.version):
raise InvalidDependencyError(
f"{self.name} version must be {self.version} (found {parsed_version})"
)
if self.min_version and parsed_version < parse(self.min_version):
raise InvalidDependencyError(
f"{self.name} minimum version allowed is {self.min_version} (found {parsed_version})"
)
if self.max_version and parsed_version > parse(self.max_version):
raise InvalidDependencyError(
f"{self.name} maximum version allowed is {self.max_version} (found {parsed_version})"
)
if self.less_then and parsed_version >= parse(self.less_then):
raise InvalidDependencyError(
f"{self.name} version must be less than {self.less_then} (found {parsed_version})"
)

return parsed_version
14 changes: 14 additions & 0 deletions dragonflye/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
class DependencyError(Exception):
pass

class MissingDependencyError(DependencyError):
pass

class InvalidDependencyError(DependencyError):
pass

class InvalidDependencyVersionError(DependencyError):
pass

class MissingInputError(Exception):
pass
48 changes: 11 additions & 37 deletions dragonflye/tools/assemblyscan.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from dragonflye.dependencies import Dependency
from dragonflye.logging import Logger
from dragonflye.tools.base import BaseTool
from dragonflye.utils import execute, parse_version, which


class AssemblyScan(object):
class AssemblyScan(BaseTool):

def __init__(self, silent=False, verbose=False, show_time=False, show_level=False):
self.programs = {
Expand All @@ -21,21 +23,14 @@ def __init__(self, silent=False, verbose=False, show_time=False, show_level=Fals
show_level=show_level,
)

def check(self) -> bool:
"""
Check if the tools are installed and available in the PATH.
Returns:
bool: True if all tools are available, False otherwise.
"""
checks_passed = 0
for program in self.programs.keys():
success, program_path = which(program)
if success:
self.log.debug(f"{program} found: {program_path}")
self.programs[program]["path"] = program_path
checks_passed += 1
return False if checks_passed != len(self.programs) else True
_dependencies = [
Dependency(
name="assembly-scan",
min_version="1.0.0",
version_cmd="assembly-scan --version 2>&1",
version_pattern=r"^.*assembly-scan (.*)$",,
)
]

def run(self, input: str, output: str, args: dict = None, cwd: str = None):
"""
Expand All @@ -55,24 +50,3 @@ def run(self, input: str, output: str, args: dict = None, cwd: str = None):
cwd=cwd,
stdout=output,
)

def version(self) -> dict:
"""
Get the version of the tools.
Returns:
dict: Dictionary of tools and their versions.
"""
versions = {}
for program, values in self.programs.items():
e = execute(
values["version_cmd"],
stderr_handler=self.log.error,
stdout_handler=self.log.info,
)
self.programs[program]["version"] = parse_version(
e["stdout"][0], values["version_regex"]
)
versions[program] = self.programs[program]["version"]
self.log.info(f"{program}: {self.programs[program]['version']}")
return versions
54 changes: 54 additions & 0 deletions dragonflye/tools/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from abc import abstractmethod
from typing import List

from pydantic import BaseModel

from dragonflye.dependencies import Dependency
from dragonflye.utils import execute, parse_version, which


class BaseOutput(BaseModel):
pass


class BaseTool(BaseModel):
_dependencies: List[Dependency] = []

def check(self) -> bool:
"""
Check if the tools are installed and available in the PATH.
Returns:
bool: True if all tools are available, False otherwise.
"""
checks_passed = 0
for program in self.programs.keys():
success, program_path = which(program)
if success:
self.log.debug(f"{program} found: {program_path}")
self.programs[program]["path"] = program_path
checks_passed += 1
return False if checks_passed != len(self.programs) else True

def version(self) -> dict:
"""
Get the version of the tools.
Returns:
dict: Dictionary of tools and their versions.
"""
versions = {}
for program, values in self.programs.items():
e = execute(
values["version_cmd"],
stderr_handler=self.log.error,
stdout_handler=self.log.info,
max_lines=1,
ignore_truncation=True,
)
self.programs[program]["version"] = parse_version(
e["stdout"][0], values["version_regex"]
)
versions[program] = self.programs[program]["version"]
self.log.info(f"{program}: {self.programs[program]['version']}")
return versions

0 comments on commit 063971e

Please sign in to comment.