diff --git a/.gitignore b/.gitignore index bd3acbce..338c779e 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ __pycache__ /staramr/databases/data/ /.eggs /.venv +/.mypy_cache diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 00000000..1a6fb910 --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,15 @@ +[mypy] +python_version = 3.5 +warn_unused_configs = False + +[mypy-pandas.*] +ignore_missing_imports = True + +[mypy-Bio.*] +ignore_missing_imports = True + +[mypy-git.*] +ignore_missing_imports = True + +[mypy-numpy.*] +ignore_missing_imports = True diff --git a/.travis.yml b/.travis.yml index e15ba9c5..a8ad2a3f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,7 @@ sudo: required language: python python: + - "3.5" - "3.6" env: @@ -18,7 +19,7 @@ install: - conda create -c bioconda -c conda-forge -q -y -n test-environment python=$TRAVIS_PYTHON_VERSION blast=2.7.1 git - source activate test-environment - python setup.py install - - mkdir -p staramr/databases/data/ - - staramr db build --dir staramr/databases/data/update $DATABASE_COMMITS + - staramr db build --dir staramr/databases/data $DATABASE_COMMITS + - pip install mypy==0.600 -script: python setup.py test +script: ./scripts/mypy && python setup.py test diff --git a/CHANGELOG.md b/CHANGELOG.md index 54860f61..191291ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +# Version 0.4.0 + +* Add support for campylobacter from PointFinder database. +* Fix `read_table` deprecation warnings by replacing `read_table` with `read_csv`. +* Handling issue with name of `16S` gene in PointFinder database for salmonella. +* Refactoring and simplifying some of the git ResFinder/PointFinder database code. +* Added automated type checking with [mypy](https://mypy.readthedocs.io). + # Version 0.3.0 * Exclusion of `aac(6')-Iaa` from results by default. Added ability to override this with `--no-exclude-genes` or pass a custom list of genes to exclude from results with `--exclude-genes-file`. diff --git a/README.md b/README.md index 23d31dc9..724e1b90 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ To include acquired point-mutation resistances using PointFinder, please run: staramr search --pointfinder-organism salmonella -o out *.fasta ``` -Where `--pointfinder-organism` is the specific organism you are interested in (currently only *salmonella* is supported). +Where `--pointfinder-organism` is the specific organism you are interested in (currently only *salmonella* and *campylobacter* are supported). ## Database Info @@ -194,7 +194,7 @@ staramr db restore-default ## Dependencies -* Python 3 +* Python 3.5+ * BLAST+ * Git @@ -386,7 +386,7 @@ positional arguments: optional arguments: -h, --help show this help message and exit --pointfinder-organism POINTFINDER_ORGANISM - The organism to use for pointfinder {salmonella}. Defaults to disabling search for point mutations. [None]. + The organism to use for pointfinder {salmonella, campylobacter}. Defaults to disabling search for point mutations. [None]. -d DATABASE, --database DATABASE The directory containing the resfinder/pointfinder databases [staramr/databases/data]. -n NPROCS, --nprocs NPROCS @@ -528,7 +528,7 @@ Example: # Caveats -This software is still a work-in-progress. In particular, not all organisms stored in the PointFinder database are supported (only *salmonella* is currently supported). Additionally, the predicted phenotypes are for microbiological resistance and *not* clinical resistance. Phenotype/drug resistance predictions are an experimental feature which is continually being improved. +This software is still a work-in-progress. In particular, not all organisms stored in the PointFinder database are supported (only *salmonella* and *campylobacter* are currently supported). Additionally, the predicted phenotypes are for microbiological resistance and *not* clinical resistance. Phenotype/drug resistance predictions are an experimental feature which is continually being improved. `staramr` only works on assembled genomes and not directly on reads. A quick genome assembler you could use is [Shovill][shovill]. Or, you may also wish to try out the [ResFinder webservice][resfinder-web], or the command-line tools [rgi][] or [ariba][] which will work on sequence reads as well as genome assemblies. You may also wish to check out the [CARD webservice][card-web]. diff --git a/bin/staramr b/bin/staramr index 1035891c..b0272d79 100755 --- a/bin/staramr +++ b/bin/staramr @@ -67,10 +67,10 @@ if __name__ == '__main__': try: args.run_command(args) except CommandParseException as e: - logger.error(e) + logger.error(str(e)) if e.print_help(): e.get_parser().print_help() sys.exit(1) except Exception as e: - logger.exception(e) + logger.exception(str(e)) sys.exit(1) diff --git a/scripts/mypy b/scripts/mypy new file mode 100755 index 00000000..dc1fe277 --- /dev/null +++ b/scripts/mypy @@ -0,0 +1,6 @@ +#!/bin/bash + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +ROOT_DIR="$SCRIPT_DIR/.." + +mypy --config $ROOT_DIR/.mypy.ini $ROOT_DIR/bin/staramr $ROOT_DIR/staramr diff --git a/staramr/__init__.py b/staramr/__init__.py index 0404d810..abeeedbf 100644 --- a/staramr/__init__.py +++ b/staramr/__init__.py @@ -1 +1 @@ -__version__ = '0.3.0' +__version__ = '0.4.0' diff --git a/staramr/blast/AbstractBlastDatabase.py b/staramr/blast/AbstractBlastDatabase.py index 3469479a..4927cb0c 100644 --- a/staramr/blast/AbstractBlastDatabase.py +++ b/staramr/blast/AbstractBlastDatabase.py @@ -33,6 +33,14 @@ def get_path(self, database_name): """ pass + @abc.abstractmethod + def get_name(self) -> str: + """ + Gets a name for this blast database implementation. + :return: A name for this implementation. + """ + pass + def get_database_paths(self): """ Gets a list of all database paths. diff --git a/staramr/blast/BlastHandler.py b/staramr/blast/BlastHandler.py index 791100c8..cceb262e 100644 --- a/staramr/blast/BlastHandler.py +++ b/staramr/blast/BlastHandler.py @@ -3,9 +3,11 @@ import subprocess from concurrent.futures import ThreadPoolExecutor from os import path +from typing import Dict from Bio.Blast.Applications import NcbiblastnCommandline +from staramr.blast.AbstractBlastDatabase import AbstractBlastDatabase from staramr.exceptions.BlastProcessError import BlastProcessError logger = logging.getLogger('BlastHandler') @@ -32,16 +34,14 @@ class BlastHandler: qseq '''.strip().split('\n')] - def __init__(self, resfinder_database, threads, output_directory, pointfinder_database=None): + def __init__(self, blast_database_objects_map: Dict[str, AbstractBlastDatabase], threads: int, + output_directory: str) -> None: """ Creates a new BlastHandler. - :param resfinder_database: The staramr.blast.resfinder.ResfinderBlastDatabase for the particular ResFinder database. + :param blast_database_objects_map: A map containing the blast databases. :param threads: The maximum number of threads to use, where one BLAST process gets assigned to one thread. :param output_directory: The output directory to store BLAST results. - :param pointfinder_database: The staramr.blast.pointfinder.PointfinderBlastDatabase to use for the particular PointFinder database. """ - self._resfinder_database = resfinder_database - if threads is None: raise Exception("threads is None") @@ -53,11 +53,13 @@ def __init__(self, resfinder_database, threads, output_directory, pointfinder_da self._output_directory = output_directory self._input_genomes_tmp_dir = path.join(output_directory, 'input-genomes') - if (pointfinder_database == None): - self._pointfinder_configured = False + self._blast_database_objects_map = blast_database_objects_map + + if (self._blast_database_objects_map['pointfinder'] is None): + self._pointfinder_configured = False # type: bool + del self._blast_database_objects_map['pointfinder'] else: - self._pointfinder_database = pointfinder_database - self._pointfinder_configured = True + self._pointfinder_configured = True # type: bool self._thread_pool_executor = None self.reset() @@ -70,10 +72,8 @@ def reset(self): if self._thread_pool_executor is not None: self._thread_pool_executor.shutdown() self._thread_pool_executor = ThreadPoolExecutor(max_workers=self._threads) - self._resfinder_blast_map = {} - self._pointfinder_blast_map = {} - self._pointfinder_future_blasts = [] - self._resfinder_future_blasts = [] + self._blast_map = {} + self._future_blasts_map = {} if path.exists(self._input_genomes_tmp_dir): logger.debug("Directory [%s] already exists", self._input_genomes_tmp_dir) @@ -86,23 +86,15 @@ def run_blasts(self, files): :param files: The files to scan. :return: None """ - database_names_resfinder = self._resfinder_database.get_database_names() - logger.debug("Resfinder Databases: %s", database_names_resfinder) - - if self.is_pointfinder_configured(): - database_names_pointfinder = self._pointfinder_database.get_database_names() - logger.debug("Pointfinder Databases: %s", database_names_pointfinder) - else: - database_names_pointfinder = None - db_files = self._make_db_from_input_files(self._input_genomes_tmp_dir, files) logger.debug("Done making blast databases for input files") for file in db_files: - logger.info("Scheduling blast for %s", path.basename(file)) - self._schedule_resfinder_blast(file, database_names_resfinder) - if self.is_pointfinder_configured(): - self._schedule_pointfinder_blast(file, database_names_pointfinder) + logger.info("Scheduling blasts for %s", path.basename(file)) + + for name in self._blast_database_objects_map: + database_object = self._blast_database_objects_map[name] + self._schedule_blast(file, database_object) def _make_db_from_input_files(self, db_dir, files): logger.info("Making BLAST databases for input files") @@ -126,33 +118,34 @@ def _make_db_from_input_files(self, db_dir, files): return db_files - def _schedule_resfinder_blast(self, file, database_names): + def _schedule_blast(self, file, blast_database): + database_names = blast_database.get_database_names() + logger.debug("%s databases: %s", blast_database.get_name(), database_names) for database_name in database_names: - database = self._resfinder_database.get_path(database_name) + database = blast_database.get_path(database_name) file_name = os.path.basename(file) - blast_out = os.path.join(self._output_directory, file_name + "." + database_name + ".resfinder.blast.xml") + blast_out = os.path.join(self._output_directory, + file_name + "." + database_name + "." + blast_database.get_name() + ".blast.tsv") if os.path.exists(blast_out): raise Exception("Error, blast_out [%s] already exists", blast_out) - self._resfinder_blast_map.setdefault(file_name, {})[database_name] = blast_out + self._get_blast_map(blast_database.get_name()).setdefault(file_name, {})[database_name] = blast_out future_blast = self._thread_pool_executor.submit(self._launch_blast, database, file, blast_out) - self._resfinder_future_blasts.append(future_blast) + self._get_future_blasts_from_map(blast_database.get_name()).append(future_blast) - def _schedule_pointfinder_blast(self, file, database_names): - for database_name in database_names: - database = self._pointfinder_database.get_path(database_name) - file_name = os.path.basename(file) + def _get_blast_map(self, name): + if name not in self._blast_map: + self._blast_map[name] = {} - blast_out = os.path.join(self._output_directory, file_name + "." + database_name + ".pointfinder.blast.xml") - if os.path.exists(blast_out): - raise Exception("Error, blast_out [%s] already exists", blast_out) + return self._blast_map[name] - self._pointfinder_blast_map.setdefault(file_name, {})[database_name] = blast_out + def _get_future_blasts_from_map(self, name): + if name not in self._future_blasts_map: + self._future_blasts_map[name] = [] - future_blast = self._thread_pool_executor.submit(self._launch_blast, database, file, blast_out) - self._pointfinder_future_blasts.append(future_blast) + return self._future_blasts_map[name] def is_pointfinder_configured(self): """ @@ -169,9 +162,9 @@ def get_resfinder_outputs(self): """ # Forces any exceptions to be thrown if error with blasts - for future_blast in self._resfinder_future_blasts: + for future_blast in self._get_future_blasts_from_map('resfinder'): future_blast.result() - return self._resfinder_blast_map + return self._get_blast_map('resfinder') def get_pointfinder_outputs(self): """ @@ -181,9 +174,9 @@ def get_pointfinder_outputs(self): """ if (self.is_pointfinder_configured()): # Forces any exceptions to be thrown if error with blasts - for future_blast in self._pointfinder_future_blasts: + for future_blast in self._get_future_blasts_from_map('pointfinder'): future_blast.result() - return self._pointfinder_blast_map + return self._get_blast_map('pointfinder') else: raise Exception("Error, pointfinder has not been configured") diff --git a/staramr/blast/pointfinder/PointfinderBlastDatabase.py b/staramr/blast/pointfinder/PointfinderBlastDatabase.py index 94ccdf56..14e2f80e 100644 --- a/staramr/blast/pointfinder/PointfinderBlastDatabase.py +++ b/staramr/blast/pointfinder/PointfinderBlastDatabase.py @@ -71,13 +71,16 @@ def get_organism(self): """ return self.organism + def get_name(self): + return 'pointfinder' + @classmethod def get_available_organisms(cls): """ A Class Method to get a list of organisms that are currently supported by staramr. :return: The list of organisms currently supported by staramr. """ - return ['salmonella'] + return ['salmonella','campylobacter'] @classmethod def get_organisms(cls, database_dir): @@ -86,7 +89,7 @@ def get_organisms(cls, database_dir): :param database_dir: The PointFinder database root directory. :return: A list of organisms. """ - config = pd.read_table(path.join(database_dir, 'config'), comment='#', header=None, + config = pd.read_csv(path.join(database_dir, 'config'), sep='\t', comment='#', header=None, names=['db_prefix', 'name', 'description']) return config['db_prefix'].tolist() diff --git a/staramr/blast/pointfinder/PointfinderDatabaseInfo.py b/staramr/blast/pointfinder/PointfinderDatabaseInfo.py index 25c9d681..f2c0bcca 100644 --- a/staramr/blast/pointfinder/PointfinderDatabaseInfo.py +++ b/staramr/blast/pointfinder/PointfinderDatabaseInfo.py @@ -1,18 +1,27 @@ import pandas as pd +import logging + +from os import path """ A Class storing information about the specific PointFinder database. """ +logger = logging.getLogger('PointfinderDatabaseInfo') + class PointfinderDatabaseInfo: - def __init__(self, database_info_dataframe): + def __init__(self, database_info_dataframe, file=None): """ Creates a new PointfinderDatabaseInfo. :param database_info_dataframe: A pd.DataFrame containing the information in PointFinder. + :param file: The file where the pointfinder database info originates from. """ self._pointfinder_info = database_info_dataframe + self._file = file + + self._resistance_table_hacks(self._pointfinder_info) @classmethod def from_file(cls, file): @@ -22,8 +31,8 @@ def from_file(cls, file): :param file: The file containing drug resistance mutations. :return: A new PointfinderDatabaseInfo. """ - pointfinder_info = pd.read_table(file, index_col=False) - return cls(pointfinder_info) + pointfinder_info = pd.read_csv(file, sep='\t', index_col=False) + return cls(pointfinder_info, file) @classmethod def from_pandas_table(cls, database_info_dataframe): @@ -34,6 +43,18 @@ def from_pandas_table(cls, database_info_dataframe): """ return cls(database_info_dataframe) + def _resistance_table_hacks(self, table): + """ + A function implementing some hacks to try and fix mismatched strings in the pointfinder databases. + These should be removed when the underlying database is corrected. + :param table: The pointfinder resistance table to fix. + :return: None, but modifies the passed table in place. + """ + if self._file and 'salmonella' in str(self._file) and path.exists( + path.join(path.dirname(self._file), '16S_rrsD.fsa')): + logger.debug("Replacing [16S] with [16S_rrsD] for pointfinder organism [salmonella]") + table[['#Gene_ID']] = table[['#Gene_ID']].replace('16S', '16S_rrsD') + def _get_resistance_codon_match(self, gene, codon_mutation): table = self._pointfinder_info diff --git a/staramr/blast/resfinder/ResfinderBlastDatabase.py b/staramr/blast/resfinder/ResfinderBlastDatabase.py index d0c22ec2..fa80ee24 100644 --- a/staramr/blast/resfinder/ResfinderBlastDatabase.py +++ b/staramr/blast/resfinder/ResfinderBlastDatabase.py @@ -25,3 +25,6 @@ def get_database_names(self): def get_path(self, database_name): return os.path.join(self.database_dir, database_name + self.fasta_suffix) + + def get_name(self): + return 'resfinder' diff --git a/staramr/blast/results/AMRHitHSP.py b/staramr/blast/results/AMRHitHSP.py index 82c44bbc..d3254391 100644 --- a/staramr/blast/results/AMRHitHSP.py +++ b/staramr/blast/results/AMRHitHSP.py @@ -101,14 +101,14 @@ def get_genome_contig_id(self): re_search = re.search(r'^(\S+)', self._blast_record['sseqid']) return re_search.group(1) - def get_genome_contig_start(self): + def get_genome_contig_start(self) -> int: """ Gets the start of the HSP in the genome input file. :return: The start of the HSP. """ return self._blast_record['sstart'] - def get_genome_contig_end(self): + def get_genome_contig_end(self) -> int: """ Gets the end of the HSP in the genome input file. :return: The end of the HSP. diff --git a/staramr/blast/results/BlastHitPartitions.py b/staramr/blast/results/BlastHitPartitions.py index ce3dee86..0c0a9b2b 100644 --- a/staramr/blast/results/BlastHitPartitions.py +++ b/staramr/blast/results/BlastHitPartitions.py @@ -3,7 +3,7 @@ from typing import List from typing import Optional from typing import Tuple -from typing import Union +from typing import Any from collections import OrderedDict logger = logging.getLogger('BlastHits') @@ -41,7 +41,7 @@ def append(self, hit: AMRHitHSP) -> None: else: self._add_hit_partition(hit, partition) - def _add_hit_partition(self, hit: AMRHitHSP, partition: Dict[str, Union[int, List[AMRHitHSP]]]) -> None: + def _add_hit_partition(self, hit: AMRHitHSP, partition: Dict[str, Any]) -> None: start, end = self._stranded_ends(hit) if start < partition['start']: @@ -52,7 +52,7 @@ def _add_hit_partition(self, hit: AMRHitHSP, partition: Dict[str, Union[int, Lis partition['hits'].append(hit) - def _get_existing_partition(self, hit: AMRHitHSP) -> Optional[Dict[str, Union[int, List[AMRHitHSP]]]]: + def _get_existing_partition(self, hit: AMRHitHSP) -> Optional[Dict[str, Any]]: partition_name = hit.get_genome_contig_id() if partition_name in self._partitions: @@ -63,7 +63,7 @@ def _get_existing_partition(self, hit: AMRHitHSP) -> Optional[Dict[str, Union[in return None - def _hit_in_parition(self, hit: AMRHitHSP, partition: Dict[str, Union[int, List[AMRHitHSP]]]) -> bool: + def _hit_in_parition(self, hit: AMRHitHSP, partition: Dict[str, Any]) -> bool: pstart, pend = partition['start'], partition['end'] start, end = self._stranded_ends(hit) diff --git a/staramr/blast/results/BlastResultsParser.py b/staramr/blast/results/BlastResultsParser.py index 031f4ee6..cce31ab5 100644 --- a/staramr/blast/results/BlastResultsParser.py +++ b/staramr/blast/results/BlastResultsParser.py @@ -1,10 +1,11 @@ import abc import logging import os +from typing import List import Bio.SeqIO -import pandas as pd import numpy as np +import pandas as pd from staramr.blast.BlastHandler import BlastHandler from staramr.blast.results.BlastHitPartitions import BlastHitPartitions @@ -17,14 +18,14 @@ class BlastResultsParser: - INDEX = 'Isolate ID' - COLUMNS = None - SORT_COLUMNS = None + INDEX = 'Isolate ID' # type: str + COLUMNS = [] # type: List[str] + SORT_COLUMNS = [] # type: List[str] BLAST_SORT_COLUMNS = [x.strip() for x in ''' plength pident sstart - '''.strip().split('\n')] + '''.strip().split('\n')] # type: List[str] def __init__(self, file_blast_map, blast_database, pid_threshold, plength_threshold, report_all=False, output_dir=None, genes_to_exclude=[]): @@ -85,7 +86,8 @@ def _get_out_file_name(self, in_file): pass def _handle_blast_hit(self, in_file, database_name, blast_file, results, hit_seq_records): - blast_table = pd.read_table(blast_file, header=None, names=BlastHandler.BLAST_COLUMNS, index_col=False).astype( + blast_table = pd.read_csv(blast_file, sep='\t', header=None, names=BlastHandler.BLAST_COLUMNS, + index_col=False).astype( dtype={'qseqid': np.unicode_, 'sseqid': np.unicode_}) partitions = BlastHitPartitions() diff --git a/staramr/blast/results/pointfinder/BlastResultsParserPointfinder.py b/staramr/blast/results/pointfinder/BlastResultsParserPointfinder.py index 021bf3a8..437f52a2 100644 --- a/staramr/blast/results/pointfinder/BlastResultsParserPointfinder.py +++ b/staramr/blast/results/pointfinder/BlastResultsParserPointfinder.py @@ -45,7 +45,7 @@ def __init__(self, file_blast_map, blast_database, pid_threshold, plength_thresh def _create_hit(self, file, database_name, blast_record): logger.debug("database_name=%s", database_name) - if database_name == '16S_rrsD': + if (database_name == '16S_rrsD') or (database_name == '23S'): return PointfinderHitHSPRNA(file, blast_record) else: return PointfinderHitHSP(file, blast_record) @@ -72,7 +72,7 @@ def _get_result_rows(self, hit, database_name): for x in database_mutations: logger.debug("database_mutations: position=%s, mutation=%s", x.get_mutation_position(), x.get_mutation_string()) - if database_name == '16S_rrsD': + if (database_name == '16S_rrsD') or (database_name == '23S'): database_resistance_mutations = self._blast_database.get_resistance_nucleotides(gene, database_mutations) else: database_resistance_mutations = self._blast_database.get_resistance_codons(gene, database_mutations) diff --git a/staramr/databases/AMRDatabaseHandler.py b/staramr/databases/AMRDatabaseHandler.py deleted file mode 100644 index be99c750..00000000 --- a/staramr/databases/AMRDatabaseHandler.py +++ /dev/null @@ -1,152 +0,0 @@ -import logging -import shutil -import time -from collections import OrderedDict -from os import path - -import git - -from staramr.exceptions.DatabaseErrorException import DatabaseErrorException -from staramr.exceptions.DatabaseNotFoundException import DatabaseNotFoundException - -logger = logging.getLogger('AMRDatabaseHandler') - -""" -A Class used to handle interactions with the ResFinder/PointFinder database files. -""" - - -class AMRDatabaseHandler: - TIME_FORMAT = "%a, %d %b %Y %H:%M" - - def __init__(self, database_dir): - """ - Creates a new AMRDatabaseHandler. - :param database_dir: The root directory for both the ResFinder/PointFinder databases. - """ - self._database_dir = database_dir - self._resfinder_dir = path.join(database_dir, 'resfinder') - self._pointfinder_dir = path.join(database_dir, 'pointfinder') - - self._resfinder_url = "https://bitbucket.org/genomicepidemiology/resfinder_db.git" - self._pointfinder_url = "https://bitbucket.org/genomicepidemiology/pointfinder_db.git" - - def build(self, resfinder_commit=None, pointfinder_commit=None): - """ - Downloads and builds a new ResFinder/PointFinder database. - :param resfinder_commit: The specific git commit for ResFinder. - :param pointfinder_commit: The specific git commit for PointFinder. - :return: None - """ - - try: - logger.info("Cloning resfinder db [%s] to [%s]", self._resfinder_url, self._resfinder_dir) - resfinder_repo = git.repo.base.Repo.clone_from(self._resfinder_url, self._resfinder_dir) - - if resfinder_commit is not None: - logger.info("Checking out resfinder commit %s", resfinder_commit) - resfinder_repo.git.checkout(resfinder_commit) - - logger.info("Cloning pointfinder db [%s] to [%s]", self._pointfinder_url, self._pointfinder_dir) - pointfinder_repo = git.repo.base.Repo.clone_from(self._pointfinder_url, self._pointfinder_dir) - - if pointfinder_commit is not None: - logger.info("Checking out pointfinder commit %s", pointfinder_commit) - pointfinder_repo.git.checkout(pointfinder_commit) - except Exception as e: - raise DatabaseErrorException("Could not build database in [" + self._database_dir + "]") from e - - def update(self, resfinder_commit=None, pointfinder_commit=None): - """ - Updates an existing ResFinder/PointFinder database to the latest revisions (or passed specific revisions). - :param resfinder_commit: The specific git commit for ResFinder. - :param pointfinder_commit: The specific git commit for PointFinder. - :return: None - """ - - if not path.exists(self._database_dir): - self.build(resfinder_commit=resfinder_commit, pointfinder_commit=pointfinder_commit) - else: - try: - resfinder_repo = git.Repo(self._resfinder_dir) - pointfinder_repo = git.Repo(self._pointfinder_dir) - - logger.info("Updating %s", self._resfinder_dir) - resfinder_repo.heads.master.checkout() - resfinder_repo.remotes.origin.pull() - - if resfinder_commit is not None: - logger.info("Checking out resfinder commit %s", resfinder_commit) - resfinder_repo.git.checkout(resfinder_commit) - - resfinder_repo.git.reset('--hard') - - logger.info("Updating %s", self._pointfinder_dir) - pointfinder_repo.heads.master.checkout() - pointfinder_repo.remotes.origin.pull() - - if pointfinder_commit is not None: - logger.info("Checking out pointfinder commit %s", pointfinder_commit) - pointfinder_repo.git.checkout(pointfinder_commit) - - resfinder_repo.git.reset('--hard') - except Exception as e: - raise DatabaseErrorException("Could not build database in [" + self._database_dir + "]") from e - - def remove(self): - """ - Removes the databases stored in this directory. - :return: None - """ - shutil.rmtree(self._database_dir) - - def info(self): - """ - Gets information on the ResFinder/PointFinder databases. - :return: Database information as a OrderedDict of key/value pairs. - """ - data = OrderedDict() - - try: - resfinder_repo = git.Repo(self._resfinder_dir) - resfinder_repo_head = resfinder_repo.commit('HEAD') - - data['resfinder_db_dir'] = self._resfinder_dir - data['resfinder_db_url'] = self._resfinder_url - data['resfinder_db_commit'] = str(resfinder_repo_head) - data['resfinder_db_date'] = time.strftime(self.TIME_FORMAT, time.gmtime(resfinder_repo_head.committed_date)) - - pointfinder_repo = git.Repo(self._pointfinder_dir) - pointfinder_repo_head = pointfinder_repo.commit('HEAD') - - data['pointfinder_db_dir'] = self._pointfinder_dir - data['pointfinder_db_url'] = self._pointfinder_url - data['pointfinder_db_commit'] = str(pointfinder_repo_head) - data['pointfinder_db_date'] = time.strftime(self.TIME_FORMAT, - time.gmtime(pointfinder_repo_head.committed_date)) - - except git.exc.NoSuchPathError as e: - raise DatabaseNotFoundException('Invalid database in [' + self._database_dir + ']') from e - - return data - - def get_database_dir(self): - """ - Gets the root database dir. - :return: The root database dir. - """ - return self._database_dir - - def get_resfinder_dir(self): - """ - Gets the ResFinder database directory. - :return: The ResFinder database directory. - """ - return self._resfinder_dir - - def get_pointfinder_dir(self): - """ - Gets the PointFinder database directory. - :return: The PointFinder database directory. - """ - return self._pointfinder_dir diff --git a/staramr/databases/AMRDatabaseHandlerStripGitDir.py b/staramr/databases/AMRDatabaseHandlerStripGitDir.py deleted file mode 100644 index 6210f605..00000000 --- a/staramr/databases/AMRDatabaseHandlerStripGitDir.py +++ /dev/null @@ -1,99 +0,0 @@ -import configparser -import logging -import shutil -from collections import OrderedDict -from os import path - -from staramr.databases.AMRDatabaseHandler import AMRDatabaseHandler -from staramr.exceptions.DatabaseNotFoundException import DatabaseNotFoundException - -logger = logging.getLogger('AMRDatabaseHandlerStripGitDir') - -""" -A Class used to handle interactions with the ResFinder/PointFinder database files, stripping out the .git directory. -""" - - -class AMRDatabaseHandlerStripGitDir(AMRDatabaseHandler): - GIT_INFO_SECTION = 'GitInfo' - - def __init__(self, database_dir): - """ - Creates a new AMRDatabaseHandlerStripGitDir. - :param database_dir: The root directory for both the ResFinder/PointFinder databases. - """ - super().__init__(database_dir) - - self._resfinder_dir_git = path.join(self._resfinder_dir, '.git') - self._pointfinder_dir_git = path.join(self._pointfinder_dir, '.git') - self._info_file = path.join(database_dir, 'info.ini') - - def build(self, resfinder_commit=None, pointfinder_commit=None): - """ - Downloads and builds a new ResFinder/PointFinder database. - :param resfinder_commit: The specific git commit for ResFinder. - :param pointfinder_commit: The specific git commit for PointFinder. - :return: None - """ - super().build(resfinder_commit=resfinder_commit, pointfinder_commit=pointfinder_commit) - - database_info = super().info() - - # remove directories from info as they are unimportant here - database_info_stripped = OrderedDict(database_info) - del database_info_stripped['resfinder_db_dir'] - del database_info_stripped['pointfinder_db_dir'] - - self._write_database_info_to_file(database_info_stripped, self._info_file) - - logger.info("Removing %s", self._resfinder_dir_git) - shutil.rmtree(self._resfinder_dir_git) - logger.info("Removing %s", self._pointfinder_dir_git) - shutil.rmtree(self._pointfinder_dir_git) - - def _write_database_info_to_file(self, database_info, file): - config = configparser.ConfigParser() - config[self.GIT_INFO_SECTION] = database_info - - with open(file, 'w') as file_handle: - config.write(file_handle) - - def _read_database_info_from_file(self, file): - config = configparser.ConfigParser() - config.read(file) - return OrderedDict(config[self.GIT_INFO_SECTION]) - - def update(self, resfinder_commit=None, pointfinder_commit=None): - """ - Updates an existing ResFinder/PointFinder database to the latest revisions (or passed specific revisions). - :param resfinder_commit: The specific git commit for ResFinder. - :param pointfinder_commit: The specific git commit for PointFinder. - :return: None - """ - raise Exception("Cannot update when .git directory has been removed") - - def info(self): - """ - Gets information on the ResFinder/PointFinder databases. - :return: Database information as a list containing key/value pairs. - """ - - try: - data = self._read_database_info_from_file(self._info_file) - data['resfinder_db_dir'] = self._resfinder_dir - data['pointfinder_db_dir'] = self._pointfinder_dir - - # re-order all fields - data.move_to_end('resfinder_db_dir', last=True) - data.move_to_end('resfinder_db_url', last=True) - data.move_to_end('resfinder_db_commit', last=True) - data.move_to_end('resfinder_db_date', last=True) - - data.move_to_end('pointfinder_db_dir', last=True) - data.move_to_end('pointfinder_db_url', last=True) - data.move_to_end('pointfinder_db_commit', last=True) - data.move_to_end('pointfinder_db_date', last=True) - - return data - except FileNotFoundError as e: - raise DatabaseNotFoundException('Database could not be found in [' + self._database_dir + ']') from e diff --git a/staramr/databases/AMRDatabasesManager.py b/staramr/databases/AMRDatabasesManager.py index ab634b56..dfa3436f 100644 --- a/staramr/databases/AMRDatabasesManager.py +++ b/staramr/databases/AMRDatabasesManager.py @@ -1,8 +1,7 @@ import logging from os import path -from staramr.databases.AMRDatabaseHandler import AMRDatabaseHandler -from staramr.databases.AMRDatabaseHandlerStripGitDir import AMRDatabaseHandlerStripGitDir +from staramr.databases.BlastDatabaseRepositories import BlastDatabaseRepositories logger = logging.getLogger('AMRDatabasesManager') @@ -12,10 +11,12 @@ class AMRDatabasesManager: - DEFAULT_RESFINDER_COMMIT = 'e8f1eb2585cd9610c4034a54ce7fc4f93aa95535' - DEFAULT_POINTFINDER_COMMIT = '8706a6363bb29e47e0e398c53043b037c24b99a7' + DEFAULT_COMMITS = { + 'resfinder': 'e8f1eb2585cd9610c4034a54ce7fc4f93aa95535', + 'pointfinder': '8706a6363bb29e47e0e398c53043b037c24b99a7' + } - def __init__(self, database_dir, sub_dirs=False): + def __init__(self, database_dir: str, sub_dirs: bool = False) -> None: """ Builds a new AMRDatabasesManager with the passed directory. :param database_dir: The directory containing the ResFinder/PointFinder databases. @@ -27,18 +28,18 @@ def __init__(self, database_dir, sub_dirs=False): self._git_strip_database_dir = path.join(database_dir, 'dist') self._sub_dirs = sub_dirs - def get_database_handler(self, force_use_git=False): + def get_database_repos(self, force_use_git: bool = False) -> BlastDatabaseRepositories: """ - Gets the appropriate database handler. - :param force_use_git: Force use of git database handler. - :return: The database handler. + Gets the appropriate database repositories. + :param force_use_git: Force use of git database repos. + :return: The database repos object. """ if self._sub_dirs and (force_use_git or path.exists(self._git_database_dir)): - return AMRDatabaseHandler(self._git_database_dir) + return BlastDatabaseRepositories.create_default_repositories(self._git_database_dir) elif self._sub_dirs: - return AMRDatabaseHandlerStripGitDir(self._git_strip_database_dir) + return BlastDatabaseRepositories.create_default_repositories(self._git_strip_database_dir, is_dist=True) else: - return AMRDatabaseHandler(self._database_dir) + return BlastDatabaseRepositories.create_default_repositories(self._git_database_dir) def setup_default(self): """ @@ -50,9 +51,9 @@ def setup_default(self): logger.warning("Default database already exists in [%s]", self._git_strip_database_dir) else: logger.info("Setting up default database in [%s]", self._git_strip_database_dir) - database_handler = AMRDatabaseHandlerStripGitDir(self._git_strip_database_dir) - database_handler.build(resfinder_commit=self.DEFAULT_RESFINDER_COMMIT, - pointfinder_commit=self.DEFAULT_POINTFINDER_COMMIT) + database_repos = BlastDatabaseRepositories.create_default_repositories(self._git_strip_database_dir, + is_dist=True) + database_repos.build(self.DEFAULT_COMMITS) def restore_default(self): """ @@ -62,8 +63,8 @@ def restore_default(self): if path.exists(self._git_database_dir): logger.info("Removing database in [%s]", self._git_database_dir) - database_handler = AMRDatabaseHandler(self._git_database_dir) - database_handler.remove() + database_repos = BlastDatabaseRepositories.create_default_repositories(self._git_database_dir) + database_repos.remove() if not path.exists(self._git_strip_database_dir): self.setup_default() @@ -76,19 +77,16 @@ def restore_default(self): logger.info("Default database already in use under directory [%s]", self._git_strip_database_dir) @classmethod - def is_handler_default_commits(self, amr_database_handler: AMRDatabaseHandler) -> bool: + def is_database_repos_default_commits(self, database_repos: BlastDatabaseRepositories) -> bool: """ - Checks whether the past database handler is linked to default commits of the database. - :param amr_database_handler: The database handler. + Checks whether the past database repos handler is linked to default commits of the database. + :param database_repos: The database repos handler. :return: True if it's setup with default commit versions, false otherwise. """ - database_info = amr_database_handler.info() - - return database_info['resfinder_db_commit'] == self.DEFAULT_RESFINDER_COMMIT and database_info[ - 'pointfinder_db_commit'] == self.DEFAULT_POINTFINDER_COMMIT + return database_repos.is_at_commits(self.DEFAULT_COMMITS) @classmethod - def get_default_database_directory(cls): + def get_default_database_directory(cls) -> str: """ Class method for getting the default database root directory. :return: The default database root directory. diff --git a/staramr/databases/BlastDatabaseRepositories.py b/staramr/databases/BlastDatabaseRepositories.py new file mode 100644 index 00000000..641bd4fe --- /dev/null +++ b/staramr/databases/BlastDatabaseRepositories.py @@ -0,0 +1,157 @@ +import logging +import shutil +from collections import OrderedDict +from typing import Dict + +from staramr.databases.BlastDatabaseRepository import BlastDatabaseRepository, BlastDatabaseRepositoryStripGitDir +from staramr.blast.AbstractBlastDatabase import AbstractBlastDatabase +from staramr.blast.resfinder.ResfinderBlastDatabase import ResfinderBlastDatabase +from staramr.blast.pointfinder.PointfinderBlastDatabase import PointfinderBlastDatabase + +logger = logging.getLogger('BlastDatabaseRepositories') + +""" +A Class used to handle interactions with blast database repository files. +""" + + +class BlastDatabaseRepositories: + + def __init__(self, database_dir: str, is_dist: bool = False) -> None: + """ + Creates a new AMRDatabaseHandler. + :param database_dir: The root directory for the databases. + :param is_dist: Whether or not we are building distributable versions of the blast database repositories + (that is, should we strip out the .git directories). + """ + self._database_dir = database_dir + self._database_repositories = {} # type: Dict[str,BlastDatabaseRepository] + self._is_dist = is_dist + + def register_database_repository(self, database_name: str, git_repository_url: str) -> None: + """ + Registers a new database repository. + :param database_name: The name of the database. + :param git_repository_url: The git repository url. + :param is_dist: True if this database should be interpreted as the distributable version (no .git directory). + :return: None + """ + database_repository = BlastDatabaseRepository(self._database_dir, database_name, + git_repository_url) # type: BlastDatabaseRepository + if self._is_dist: + database_repository = BlastDatabaseRepositoryStripGitDir(self._database_dir, database_name, + git_repository_url) + + if database_name in self._database_repositories: + raise Exception("A database with name [{}] already exists", database_name) + else: + self._database_repositories[database_name] = database_repository + + def build(self, commits: Dict[str, str] = None): + """ + Downloads and builds new databases. + :param commits: A map of {'database_name' : 'commit'} defining the particular commits to build. + :return: None + """ + for database_name in self._database_repositories: + commit = commits.get(database_name) if commits else None + self._database_repositories[database_name].build(commit) + + def update(self, commits: Dict[str, str] = None): + """ + Updates an existing database to the latest revisions (or passed specific revisions). + :param commits: A map of {'database_name' : 'commit'} defining the particular commits to update to. + :return: None + """ + for database_name in self._database_repositories: + commit = commits.get(database_name) if commits else None + self._database_repositories[database_name].update(commit) + + def remove(self): + """ + Removes the databases stored in this directory. + :return: None + """ + for name, repo in self._database_repositories.items(): + repo.remove() + + shutil.rmtree(self._database_dir) + + def info(self) -> Dict[str, str]: + """ + Gets information on the ResFinder/PointFinder databases. + :return: Database information as a OrderedDict of key/value pairs. + """ + info = OrderedDict() # type: Dict[str,str] + + for name, repo in self._database_repositories.items(): + info.update(repo.info()) + + return info + + def get_database_dir(self) -> str: + """ + Gets the root database dir. + :return: The root database dir. + """ + return self._database_dir + + def get_repo_dir(self, name: str) -> str: + """ + Gets database repo directory for the given database name. + :param name: The database name. + :return: The database dir for the given database name. + """ + return self._database_repositories[name].get_git_dir() + + def is_at_commits(self, commits: Dict[str, str]): + """ + Are the database repositories at the passed commits? + :param commits: A dict of the commits {'database_name': 'commit'}. + :return: True if the database repositories are at the passed commits (ignores repos not passed in dict). False otherwise. + """ + for name, repo in self._database_repositories.items(): + if name in commits and not repo.is_at_commit(commits[name]): + return False + + return True + + def is_dist(self): + """ + Whether or not we are building distributable versions of the blast database repositories (that is, should we strip out the .git directories). + :return: True if is_dist, False otherwise. + """ + return self._is_dist + + @classmethod + def create_default_repositories(cls, root_database_dir: str, is_dist: bool = False): + """ + Class method for creating a BlastDatabaseRepositories object configured with the default repositories. + :param database_dir: The root database directory. + :param is_dist: Whether or not we are building distributable versions of the blast database repositories + (that is, should we strip out the .git directories). + :return: The BlastDatabaseRepositories. + """ + repos = cls(root_database_dir, is_dist) + repos.register_database_repository('resfinder', 'https://bitbucket.org/genomicepidemiology/resfinder_db.git') + repos.register_database_repository('pointfinder', + 'https://bitbucket.org/genomicepidemiology/pointfinder_db.git') + + return repos + + def build_blast_database(self, database_name: str, options: Dict[str, str] = {}) -> AbstractBlastDatabase: + """ + Builds a staramr.blast.AbstractBlastDatabase from the given parameters. + :param database_name: The name of the database to build. + :param options: Options for the particular database in the form of a map {'key': 'value'} + :return: A new staramr.blast.AbstractBlastDatabase. + """ + if database_name not in self._database_repositories: + raise Exception("database_name={} not registered", database_name) + + if database_name == 'resfinder': + return ResfinderBlastDatabase(self.get_repo_dir(database_name)) + elif database_name == 'pointfinder': + return PointfinderBlastDatabase(self.get_repo_dir(database_name), options['organism']) + else: + raise Exception("Unknown database name [{}]", database_name) diff --git a/staramr/databases/BlastDatabaseRepository.py b/staramr/databases/BlastDatabaseRepository.py new file mode 100644 index 00000000..5758b7e0 --- /dev/null +++ b/staramr/databases/BlastDatabaseRepository.py @@ -0,0 +1,211 @@ +import configparser +import logging +import shutil +import time +from collections import OrderedDict +from os import path +from typing import Dict + +import git + +from staramr.exceptions.DatabaseErrorException import DatabaseErrorException +from staramr.exceptions.DatabaseNotFoundException import DatabaseNotFoundException + +""" +A Class used to handle interactions with the BLAST database repositories. +""" + + +class BlastDatabaseRepository: + TIME_FORMAT = "%a, %d %b %Y %H:%M" + LOGGER = logging.getLogger('BlastDatabaseRepository') + + def __init__(self, database_root_dir: str, database_name: str, git_repository_url: str) -> None: + """ + Creates a new BlastDatabaseRepository. + :param database_root_dir: The root directory for both the Blast databases. + :param database_name: A name for this database. + :param git_repository_url: A URL to the git repository managing the database files. + """ + self._database_dir = database_root_dir + self._database_name = database_name + self._git_repository_url = git_repository_url + + self._git_dir = path.join(database_root_dir, database_name) + + def build(self, commit: str = None) -> None: + """ + Downloads and builds a new Blast database. + :param commit: The specific git commit to download. Defaults to latest commit. + :return: None + """ + + try: + self.LOGGER.info("Cloning %s db [%s] to [%s]", self._database_name, self._git_repository_url, self._git_dir) + repo = git.repo.base.Repo.clone_from(self._git_repository_url, self._git_dir) + + if commit is not None: + self.LOGGER.info("Checking out %s commit %s", self._database_name, commit) + repo.git.checkout(commit) + except Exception as e: + raise DatabaseErrorException("Could not build database in [" + self._database_dir + "]") from e + + def update(self, commit: str = None) -> None: + """ + Updates an existing Blast database to the latest revisions (or passed specific revisions). + :param commit: The specific git commit to update to. Defaults to latest commit. + :return: None + """ + + if not path.exists(self._git_dir): + self.build(commit=commit) + else: + try: + repo = git.Repo(self._git_dir) + + self.LOGGER.info("Updating %s", self._git_dir) + repo.heads.master.checkout() + repo.remotes.origin.pull() + + if commit is not None: + self.LOGGER.info("Checking out %s commit %s", self._database_name, commit) + repo.git.checkout(commit) + + repo.git.reset('--hard') + except Exception as e: + raise DatabaseErrorException("Could not build database in [" + self._database_dir + "]") from e + + def remove(self): + """ + Removes the databases stored in this directory. + :return: None + """ + shutil.rmtree(self._git_dir) + + def is_at_commit(self, commit: str = None) -> bool: + """ + Determines whether this database repo is at the specified commit. + :param commit: The commit to check. + :return: True if the database is at the specified commit, otherwise False. + """ + return self.info()[self._get_info_name('commit')] == commit + + def info(self) -> Dict[str, str]: + """ + Gets information on the Blast databases. + :return: Database information as a OrderedDict of key/value pairs. + """ + data = OrderedDict() # type: Dict[str,str] + + try: + repo = git.Repo(self._git_dir) + repo_head = repo.commit('HEAD') + + data[self._get_info_name('dir')] = self._git_dir + data[self._get_info_name('url')] = self._git_repository_url + data[self._get_info_name('commit')] = str(repo_head) + data[self._get_info_name('date')] = time.strftime(self.TIME_FORMAT, + time.gmtime(repo_head.committed_date)) + + except git.exc.NoSuchPathError as e: + raise DatabaseNotFoundException('Invalid database in [' + self._database_dir + ']') from e + + return data + + def _get_info_name(self, info_type): + return self._database_name + '_db_' + info_type + + def get_database_dir(self) -> str: + """ + Gets the root database dir. + :return: The root database dir. + """ + return self._database_dir + + def get_git_dir(self) -> str: + """ + Gets the database git directory. + :return: The database git directory. + """ + return self._git_dir + + +""" +A Class used to handle interactions with the BLAST database repositories, stripping out the .git directory. +""" + + +class BlastDatabaseRepositoryStripGitDir(BlastDatabaseRepository): + GIT_INFO_SECTION = 'GitInfo' + LOGGER = logging.getLogger('BlastDatabaseRepositoryStripGitDir') + + def __init__(self, database_root_dir: str, database_name: str, git_repository_url: str) -> None: + """ + Creates a new BlastDatabaseRepositoryStripGitDir. + :param database_root_dir: The root directory for both the Blast databases. + :param database_name: A name for this database. + :param git_repository_url: A URL to the git repository managing the database files. + """ + super().__init__(database_root_dir, database_name, git_repository_url) + + self._git_dot_git_dir = path.join(self._git_dir, '.git') + self._info_file = path.join(database_root_dir, database_name + '-info.ini') + + def build(self, commit: str = None): + """ + Downloads and builds a new Blast database. + :param commit: The specific git commit to download. Defaults to latest commit. + :return: None + """ + super().build(commit=commit) + + database_info = super().info() + + # remove directories from info as they are unimportant here + database_info_stripped = OrderedDict(database_info) + del database_info_stripped[self._get_info_name('dir')] + + self._write_database_info_to_file(database_info_stripped, self._info_file) + + self.LOGGER.info("Removing %s", self._git_dot_git_dir) + shutil.rmtree(self._git_dot_git_dir) + + def _write_database_info_to_file(self, database_info, file): + config = configparser.ConfigParser() + config[self.GIT_INFO_SECTION] = database_info + + with open(file, 'w') as file_handle: + config.write(file_handle) + + def _read_database_info_from_file(self, file): + config = configparser.ConfigParser() + config.read(file) + return OrderedDict(config[self.GIT_INFO_SECTION]) + + def update(self, commit: str = None) -> None: + """ + Updates an existing Blast database to the latest revisions (or passed specific revisions). + :param commit: The commit to update to. + :return: None + """ + raise Exception("Cannot update when .git directory has been removed") + + def info(self) -> Dict[str, str]: + """ + Gets information on the ResFinder/PointFinder databases. + :return: Database information as a list containing key/value pairs. + """ + + try: + data = self._read_database_info_from_file(self._info_file) + data[self._get_info_name('dir')] = self._git_dir + + # re-order all fields + data.move_to_end(self._get_info_name('dir'), last=True) + data.move_to_end(self._get_info_name('url'), last=True) + data.move_to_end(self._get_info_name('commit'), last=True) + data.move_to_end(self._get_info_name('date'), last=True) + + return data + except FileNotFoundError as e: + raise DatabaseNotFoundException('Database could not be found in [' + self.get_database_dir() + ']') from e diff --git a/staramr/databases/exclude/ExcludeGenesList.py b/staramr/databases/exclude/ExcludeGenesList.py index e091ec27..9b7c8601 100644 --- a/staramr/databases/exclude/ExcludeGenesList.py +++ b/staramr/databases/exclude/ExcludeGenesList.py @@ -10,7 +10,7 @@ class ExcludeGenesList: DEFAULT_EXCLUDE_FILE = path.join(path.dirname(__file__), 'data', 'genes_to_exclude.tsv') def __init__(self, file=DEFAULT_EXCLUDE_FILE): - self._data = pd.read_table(file) + self._data = pd.read_csv(file, sep='\t') def tolist(self): """ diff --git a/staramr/databases/resistance/ARGDrugTable.py b/staramr/databases/resistance/ARGDrugTable.py index 8b710bfa..69c8b32f 100644 --- a/staramr/databases/resistance/ARGDrugTable.py +++ b/staramr/databases/resistance/ARGDrugTable.py @@ -23,7 +23,7 @@ def __init__(self, file=None, info_file=DEFAULT_INFO_FILE): self._file = file if file is not None: - self._data = pd.read_table(file) + self._data = pd.read_csv(file, sep='\t') def get_resistance_table_info(self): """ diff --git a/staramr/subcommand/Database.py b/staramr/subcommand/Database.py index 8527c2d3..38ca4ac1 100644 --- a/staramr/subcommand/Database.py +++ b/staramr/subcommand/Database.py @@ -101,11 +101,11 @@ def run(self, args): mkdir(args.destination) if args.destination == AMRDatabasesManager.get_default_database_directory(): - database_handler = AMRDatabasesManager.create_default_manager().get_database_handler() + database_repos = AMRDatabasesManager.create_default_manager().get_database_repos() else: - database_handler = AMRDatabasesManager(args.destination).get_database_handler() - database_handler.build(resfinder_commit=args.resfinder_commit, pointfinder_commit=args.pointfinder_commit) - if not AMRDatabasesManager.is_handler_default_commits(database_handler): + database_repos = AMRDatabasesManager(args.destination).get_database_repos() + database_repos.build({'resfinder': args.resfinder_commit, 'pointfinder': args.pointfinder_commit}) + if not AMRDatabasesManager.is_database_repos_default_commits(database_repos): logger.warning( "Built non-default ResFinder/PointFinder database version. This may lead to " + "differences in the detected AMR genes depending on how the database files are structured.") @@ -159,13 +159,13 @@ def run(self, args): print_help=True) else: try: - database_handler = AMRDatabasesManager.create_default_manager().get_database_handler( + database_repos = AMRDatabasesManager.create_default_manager().get_database_repos( force_use_git=True) - database_handler.update(resfinder_commit=args.resfinder_commit, - pointfinder_commit=args.pointfinder_commit) + database_repos.update( + {'resfinder': args.resfinder_commit, 'pointfinder': args.pointfinder_commit}) - if not AMRDatabasesManager.is_handler_default_commits(database_handler): + if not AMRDatabasesManager.is_database_repos_default_commits(database_repos): logger.warning( "Updated to non-default ResFinder/PointFinder database version. This may lead to " + "differences in the detected AMR genes depending on how the database files are structured.") @@ -175,10 +175,9 @@ def run(self, args): raise e else: for directory in args.directories: - database_handler = AMRDatabasesManager(directory).get_database_handler() - database_handler.update(resfinder_commit=args.resfinder_commit, - pointfinder_commit=args.pointfinder_commit) - if not AMRDatabasesManager.is_handler_default_commits(database_handler): + database_repos = AMRDatabasesManager(directory).get_database_repos() + database_repos.update({'resfinder': args.resfinder_commit, 'pointfinder': args.pointfinder_commit}) + if not AMRDatabasesManager.is_database_repos_default_commits(database_repos): logger.warning( "Updated to non-default ResFinder/PointFinder database version [%s]. This may lead to " + "differences in the detected AMR genes depending on how the database files are structured.", @@ -280,14 +279,14 @@ def run(self, args): arg_drug_table = ARGDrugTable() if len(args.directories) == 0: - database_handler = AMRDatabasesManager.create_default_manager().get_database_handler() - if not AMRDatabasesManager.is_handler_default_commits(database_handler): + database_repos = AMRDatabasesManager.create_default_manager().get_database_repos() + if not AMRDatabasesManager.is_database_repos_default_commits(database_repos): logger.warning( "Using non-default ResFinder/PointFinder database versions. This may lead to differences in the detected " + "AMR genes depending on how the database files are structured.") try: - database_info = database_handler.info() + database_info = database_repos.info() database_info.update(arg_drug_table.get_resistance_table_info()) sys.stdout.write(get_string_with_spacing(database_info)) except DatabaseNotFoundException as e: @@ -295,14 +294,14 @@ def run(self, args): else: for directory in args.directories: try: - database_handler = AMRDatabasesManager(directory).get_database_handler() - if not AMRDatabasesManager.is_handler_default_commits(database_handler): + database_repos = AMRDatabasesManager(directory).get_database_repos() + if not AMRDatabasesManager.is_database_repos_default_commits(database_repos): logger.warning( "Using non-default ResFinder/PointFinder database version [%s]. This may lead to " + "differences in the detected AMR genes depending on how the database files are structured.", directory) - database_info = database_handler.info() + database_info = database_repos.info() database_info.update(arg_drug_table.get_resistance_table_info()) sys.stdout.write(get_string_with_spacing(database_info)) except DatabaseNotFoundException as e: diff --git a/staramr/subcommand/Search.py b/staramr/subcommand/Search.py index df592af5..1ea9d042 100644 --- a/staramr/subcommand/Search.py +++ b/staramr/subcommand/Search.py @@ -15,10 +15,10 @@ from staramr.blast.pointfinder.PointfinderBlastDatabase import PointfinderBlastDatabase from staramr.blast.resfinder.ResfinderBlastDatabase import ResfinderBlastDatabase from staramr.databases.AMRDatabasesManager import AMRDatabasesManager +from staramr.databases.exclude.ExcludeGenesList import ExcludeGenesList from staramr.databases.resistance.ARGDrugTable import ARGDrugTable from staramr.detection.AMRDetectionFactory import AMRDetectionFactory from staramr.exceptions.CommandParseException import CommandParseException -from staramr.databases.exclude.ExcludeGenesList import ExcludeGenesList logger = logging.getLogger("Search") @@ -86,7 +86,8 @@ def _setup_args(self, arg_parser): help='Disable the default exclusion of some genes from ResFinder/PointFinder [False].', required=False) report_group.add_argument('--exclude-genes-file', action='store', dest='exclude_genes_file', - help='A containing a list of ResFinder/PointFinder gene names to exclude from results [{}].'.format(ExcludeGenesList.get_default_exclude_file()), + help='A containing a list of ResFinder/PointFinder gene names to exclude from results [{}].'.format( + ExcludeGenesList.get_default_exclude_file()), default=ExcludeGenesList.get_default_exclude_file(), required=False) report_group.add_argument('--exclude-negatives', action='store_true', dest='exclude_negatives', @@ -189,12 +190,12 @@ def _print_settings_to_file(self, settings, file): file_handle.write(get_string_with_spacing(settings)) file_handle.close() - def _generate_results(self, database_handler, resfinder_database, pointfinder_database, nprocs, include_negatives, + def _generate_results(self, database_repos, resfinder_database, pointfinder_database, nprocs, include_negatives, include_resistances, hits_output, pid_threshold, plength_threshold_resfinder, plength_threshold_pointfinder, report_all_blast, genes_to_exclude, files): """ Runs AMR detection and generates results. - :param database_handler: The database handler. + :param database_repos: The database repos object. :param resfinder_database: The resfinder database. :param pointfinder_database: The pointfinder database. :param nprocs: The number of processing cores to use for BLAST. @@ -214,7 +215,7 @@ def _generate_results(self, database_handler, resfinder_database, pointfinder_da with tempfile.TemporaryDirectory() as blast_out: start_time = datetime.datetime.now() - blast_handler = BlastHandler(resfinder_database, nprocs, blast_out, pointfinder_database) + blast_handler = BlastHandler({'resfinder': resfinder_database, 'pointfinder': pointfinder_database}, nprocs, blast_out) amr_detection_factory = AMRDetectionFactory() amr_detection = amr_detection_factory.build(resfinder_database, blast_handler, pointfinder_database, @@ -233,7 +234,7 @@ def _generate_results(self, database_handler, resfinder_database, pointfinder_da logger.info("Finished. Took %s minutes.", time_difference_minutes) - settings = database_handler.info() + settings = database_repos.info() settings['command_line'] = ' '.join(sys.argv) settings['version'] = self._version settings['start_time'] = start_time.strftime(self.TIME_FORMAT) @@ -280,24 +281,20 @@ def run(self, args): self._root_arg_parser) if args.database == AMRDatabasesManager.get_default_database_directory(): - database_handler = AMRDatabasesManager.create_default_manager().get_database_handler() + database_repos = AMRDatabasesManager.create_default_manager().get_database_repos() else: - database_handler = AMRDatabasesManager(args.database).get_database_handler() + database_repos = AMRDatabasesManager(args.database).get_database_repos() - if not AMRDatabasesManager.is_handler_default_commits(database_handler): + if not AMRDatabasesManager.is_database_repos_default_commits(database_repos): logger.warning("Using non-default ResFinder/PointFinder. This may lead to differences in the detected " + "AMR genes depending on how the database files are structured.") - resfinder_database_dir = database_handler.get_resfinder_dir() - pointfinder_database_dir = database_handler.get_pointfinder_dir() - - resfinder_database = ResfinderBlastDatabase(resfinder_database_dir) + resfinder_database = database_repos.build_blast_database('resfinder') if (args.pointfinder_organism): if args.pointfinder_organism not in PointfinderBlastDatabase.get_available_organisms(): raise CommandParseException("The only Pointfinder organism(s) currently supported are " + str( PointfinderBlastDatabase.get_available_organisms()), self._root_arg_parser) - pointfinder_database = PointfinderBlastDatabase(pointfinder_database_dir, - args.pointfinder_organism) + pointfinder_database = database_repos.build_blast_database('pointfinder', {'organism': args.pointfinder_organism}) else: logger.info("No --pointfinder-organism specified. Will not search the PointFinder databases") pointfinder_database = None @@ -360,10 +357,12 @@ def run(self, args): raise CommandParseException('--exclude-genes-file [{}] does not exist'.format(args.exclude_genes_file), self._root_arg_parser) else: - logger.info("Will exclude ResFinder/PointFinder genes listed in [%s]. Use --no-exclude-genes to disable",args.exclude_genes_file) - exclude_genes=ExcludeGenesList(args.exclude_genes_file).tolist() + logger.info( + "Will exclude ResFinder/PointFinder genes listed in [%s]. Use --no-exclude-genes to disable", + args.exclude_genes_file) + exclude_genes = ExcludeGenesList(args.exclude_genes_file).tolist() - results = self._generate_results(database_handler=database_handler, + results = self._generate_results(database_repos=database_repos, resfinder_database=resfinder_database, pointfinder_database=pointfinder_database, nprocs=args.nprocs, diff --git a/staramr/tests/integration/data/23S-A2075G.fsa b/staramr/tests/integration/data/23S-A2075G.fsa new file mode 100644 index 00000000..784ef0c6 --- /dev/null +++ b/staramr/tests/integration/data/23S-A2075G.fsa @@ -0,0 +1,43 @@ +>23S (A2075G) +AGCTACTAAGAGCGAATGGTGGATGCCTTGACTGGTAAAGGCGATGAAGGACGTACTAGACTGCGATAAG +CTACGGGGAGCTGTCAAGAAGCTTTGATCCGTAGATTTCCGAATGGGGCAACCCAATGTATAGAGATATA +CATTACCTATATAGGAGCGAACGAGGGGAATTGAAACATCTTAGTACCCTCAGGAAAAGAAATCAATAGA +GATTGCGTCAGTAGCGGCGAGCGAAAGCGCAAGAGGGCAAACCCAGTGCTTGCACTGGGGGTTGTAGGAC +TGCAATGTGCAAGAGCTGAGTTTAGCAGAACATTCTGGAAAGTATAGCCATAGAGGGTGATAGTCCCGTA +TGCGAAAAACAAAGCTTAGCTAGCAGTATCCTGAGTAGGGCGGGACACGAGGAATCCTGTCTGAATCCGG +GTCGACCACGATCCAACCCTAAATACTAATACCAGATCGATAGTGCACAAGTACCGTGAGGGAAAGGTGA +AAAGAACTGAGGTGATCAGAGTGAAATAGAACCTGAAACCATTTGCTTACAATCATTCAGAGCACTATGT +AGCAATACAGTGTGATGGACTGCCTTTTGCATAATGAGCCTGCGAGTTGTGGTGTCTGGCAAGGTTAAGC +AAACGCGAAGCCGTAGCGAAAGCGAGTCTGAATAGGGCGCTTAGTCAGATGCTGCAGACCCGAAACGAAG +TGATCTATCCATGAGCAAGTTGAAGCTAGTGTAAGAACTAGTGGAGGACTGAACCCATAGGCGTTGAAAA +GCCCCGGGATGACTTGTGGATAGGGGTGAAAGGCCAATCAAACTTCGTGATAGCTGGTTCTCTCCGAAAT +ATATTTAGGTATAGCGTTGTGTCGTAATATAAGGGGGTAGAGCACTGAATGGGCTAGGGCATACACCAAT +GTACCAAACCCTATCAAACTCCGAATACCTTATATGTAATCACAGCAGTCAGGCGGCGAGTGATAAAATC +CGTCGTCAAGAGGGAAACAACCCAGACTACCAGCTAAGGTCCCTAAATCTTACTTAAGTGGAAAACGATG +TGAAGTTACTTAAACAACCAGGAGGTTGGCTTAGAAGCAGCCATCCTTTAAAGAAAGCGTAATAGCTCAC +TGGTCTAGTGATTTTGCGCGGAAAATATAACGGGGCTAAAGTAAGTACCGAAGCTGTAGACTTAGTTTAC +TAAGTGGTAGGAGAGCGTTCTATTTGCGTCGAAGGTATACCGGTAAGGAGTGCTGGAGCGAATAGAAGTG +AGCATGCAGGCATGAGTAGCGATAATTAATGTGAGAATCATTAACGCCGTAAACCCAAGGTTTCCTACGC +GATGCTCGTCATCGTAGGGTTAGTCGGGTCCTAAGTCGAGTCCGAAAGGGGTAGACGATGGCAAATTGGT +TAATATTCCAATACCAACATTAGTGTGCGATGGAAGGACGCTTAGGGCTAAGGGGGCTAGCGGATGGAAG +TGCTAGTCTAAGGTCGTAGGAGGTTATACAGGCAAATCCGTATAACAATACTCCGAGAACTGAAAGGCTT +TTTGAAGTCTTCGGATGGATAGAAGAACCCCTGATGCCGTCGAGCCAAGAAAAGTTTCTAAGTTTAGCTA +ATGTTGCCCGTACCGTAAACCGACACAGGTGGGTGGGATGAGTATTCTAAGGCGCGTGGAAGAACTCTCT +TTAAGGAACTCTGCAAAATAGCACCGTATCTTCGGTATAAGGTGTGGTTAGCTTTGTATTAGGATTTACT +CTGAAAGCAAGGAAACTTACAACAAAGAGTCCCTCCCGACTGTTTACCAAAAACACAGCACTCTGCTAAC +TCGTAAGAGGATGTATAGGGTGTGACGCCTGCCCGGTGCTCGAAGGTTAATTGATGGGGTTAGCATTAGC +GAAGCTCTTGATCGAAGCCCGAGTAAACGGCGGCCGTAACTATAACGGTCCTAAGGTAGCGAAATTCCTT +GTCGGTTAAATACCGACCTGCATGAATGGCGTAACGAGATGGGAGCTGTCTCAAAGAGGGATCCAGTGAA +ATTGTAGTGGAGGTGAAAATTCCTCCTACCCGCGGCAAGACGGAgAGACCCCGTGGACCTTTACTACAGC +TTGACACTGCTACTTGGATAAGAATGTGCAGGATAGGTGGGAGGCTTTGAGTATATGACGCCAGTTGTAT +ATGAGCCATTGTTGAGATACCACTCTTTCTTATTTGGGTAGCTAACCAGCTTGAGTTATCCTCAAGTGGG +ACAATGTCTGGTGGGTAGTTTGACTGGGGCGGTCGCCTCCCAAATAATAACGGAGGCTTACAAAGGTTGG +CTCAGAACGGTTGGAAATCGTTCGTAGAGTATAAAGGTATAAGCCAGCTTAACTGCAAGACATACAAGTC +AAGCAGAGACGAAAGTCGGTCTTAGTGATCCGGTGGTTCTGTGTGGAAGGGCCATCGCTCAAAGGATAAA +AGGTACCCCGGGGATAACAGGCTGATCTCCCCCAAGAGCTCACATCGACGGGGAGGTTTGGCACCTCGAT +GTCGGCTCATCGCATCCTGGGGCTGGAGCAGGTCCCAAGGGTATGGCTGTTCGCCATTTAAAGCGGTACG +CGAGCTGGGTTCAGAACGTCGTGAGACAGTTCGGTCCCTATCTGCCGTGGGCGTAAGAAGATTGAAGAGA +TTTGACCCTAGTACGAGAGGACCGGGTTGAACAAACCACTGGTGTAGCTGTTGTTCTGCCAAGAGCATCG +CAGCGTAGCTAAGTTTGGAAAGGATAAACGCTGAAAGCATCTAAGCGTGAAGCCAACTCTAAGATGAATC +TTCTCTAAGCTCTCTAGAAGACTACTAGTTTGATAGGCTGGGTGTGTAATGGATGAAAGTCCTTTAGCTG +ACCAGTACTAATAGAGCGTTTGGCTTATCTTTAATAAAGCAT diff --git a/staramr/tests/integration/data/gyrA-A70T.fsa b/staramr/tests/integration/data/gyrA-A70T.fsa new file mode 100644 index 00000000..c5242d11 --- /dev/null +++ b/staramr/tests/integration/data/gyrA-A70T.fsa @@ -0,0 +1,39 @@ +>gyrA G208A A70T +ATGGAGAATATTTTTAGCAAAGATTCTGATATTGAACTTGTAGATATAGAAAATTCTATAAAAAGTAGTT +ATTTAGACTATTCTATGAGTGTTATTATAGGTCGTGCTTTGCCTGACGCAAGAGATGGTTTAAAGCCTGT +TCATAGAAGAATTTTATATGCTATGCAAAATGATGAGGCAAAAAGTAGAACAGATTTTGTCAAATCAaCC +CGTATAGTGGGTGCTGTTATAGGTCGTTATCACCCACATGGAGATACAGCAGTTTATGATGCTTTGGTTA +GAATGGCTCAAGATTTTTCTATGAGATATCCAAGTATTACAGGACAAGGCAACTTTGGATCTATAGATGG +TGATAGTGCCGCTGCGATGCGTTATACTGAAGCAAAAATGAGTAAACTTTCTCATGAGCTTTTAAAAGAT +ATAGATAAAGATACGGTCGATTTTGTTCCAAATTATGATGGTTCAGAAAGCGAACCTGATGTTTTACCTT +CTAGGGTTCCAAATTTATTATTAAATGGTTCAAGTGGTATAGCTGTAGGTATGGCGACAAACATCCCACC +TCATAGTTTAAATGAGTTGATAGATGGACTTTTATATTTGCTTGATAATAAAGATGCAAGCCTAGAAGAG +ATTATGCAGTTTATCAAAGGTCCAGATTTTCCAACAGGTGGAATAATTTATGGTAAAAAAGGTATTATAG +AAGCTTATCGCACAGGGCGTGGTCGCGTGAAAGTGCGAGCTAAAACTCATATTGAAAAAAAGACAAATAA +AGATGTTATTGTTATCGATGAGCTTCCTTATCAAACCAATAAAGCTAGGCTTATAGAGCAGATTGCAGAG +CTTGTTAAAGAAAGGCAAATTGAAGGAATATCTGAAGTAAGAGATGAGAGCAATAAAGAAGGAATCCGCG +TTGTTATAGAGCTTAAACGTGAGGCTATGAGTGAAATTGTTTTAAATAATCTATTTAAATCTACCACTAT +GGAAAGTACTTTTGGTGTGATTATGTTGGCAATTCATAATAAAGAACCTAAAATTTTCTCTTTGTTGGAA +CTTTTAAATCTTTTCTTAACTCATAGAAAAACAGTTATTATTAGAAGAACGATTTTTGAACTTCAAAAGG +CAAGAGCAAGAGCTCATATTTTAGAAGGTCTTAAAATTGCACTTGATAATATAGATGAAGTGATTGCTTT +AATTAAAAATAGTTCTGATAATAATACCGCAAGAGATTCTTTAGTAGCTAAATTTGGTCTTAGTGAGCTT +CAAGCCAATGCTATTTTAGATATGAAACTTGGTCGTTTAACAGGACTTGAAAGAGAAAAAATCGAAAATG +AACTTGCAGAATTAATGAAAGAAATTGCAAGACTTGAAGAAATTTTAAAAAGTGAAACCTTGCTTGAAAA +TTTAATTCGCGATGAATTAAAAGAAATTAGAAGTAAATTTGATGTGCCACGTATTACTCAAATTGAAGAT +GATTACGATGATATTGATATTGAAGATTTGATTCCTAATGAAAATATGGTTGTAACTATCACACATCGTG +GTTATATTAAGCGTGTGCCTAGTAAACAATATGAAAAACAAAAACGAGGTGGAAAAGGAAAATTAGCCGT +TACGACTTATGATGATGATTTTATAGAAAGTTTCTTTACGGCAAATACACATGATACGCTTATGTTTGTA +ACAGATCGTGGACAGCTTTATTGGCTTAAAGTTTATAAAATTCCTGAAGGCTCAAGAACGGCTAAAGGAA +AAGCAGTGGTAAATCTTATCAATTTACAAGCTGAAGAAAAAATCATGGCTATTATTCCAACCACGGATTT +TGATGAGAGCAAATCTTTATGTTTCTTTACTAAAAATGGTATTGTAAAGCGTACAAATTTGAGTGAATAT +CAAAATATCAGAAGTGTAGGAGTTAGAGCGATCAACTTGGATGAAAATGATGAGTTGGTAACTGCTATTA +TTGTTCAAAGAGATGAAGATGAAATTTTTGCCACTGGTGGTGAAGAAAATTTAGAAAATCAAGAAATTGA +AAATTTAGATGATGAAAATCTTGAAAATGAAGAAAGTGTAAGCACACAAGGTAAAATGCTCTTTGCAGTA +ACCAAAAAAGGTATGTGTATCAAATTCCCACTTGCTAAAGTGCGTGAAATCGGCCGTGTAAGTCGTGGGG +TGACGGCTATTAAGTTTAAAGAGAAAAATGACGAATTAGTAGGTGCAGTTGTTATAGAAAATGATGAGCA +AGAAATTTTAAGCATAAGTGCAAAAGGTATAGGAAAACGCACCAATGCTGGAGAATATAGATTGCAAAGC +AGAGGTGGTAAGGGTGTAATTTGTATGAAACTTACAGAAAAAACCAAAGATCTTATTAGCGTAGTTATAG +TAGATGAAACTATGGATTTAATGGCTCTTACAAGTTCAGGTAAGATGATACGTGTTGATATGCAAAGCAT +TAGAAAAGCAGGGCGTAATACGAGTGGTGTCATTGTAGTTAATGTGGAAAATGACGAGGTGGTTAGCATC +GCTAAGTGTCCTAAAGAGGAAAATGACGAGGATGAGTTAAGCGATGAAAACTTTGGTTTAGATTTGCAAT +AA diff --git a/staramr/tests/integration/databases/test_AMRDatabasesManager.py b/staramr/tests/integration/databases/test_AMRDatabasesManager.py index 158163eb..ed54d286 100644 --- a/staramr/tests/integration/databases/test_AMRDatabasesManager.py +++ b/staramr/tests/integration/databases/test_AMRDatabasesManager.py @@ -3,8 +3,6 @@ import unittest from os import path -from staramr.databases.AMRDatabaseHandler import AMRDatabaseHandler -from staramr.databases.AMRDatabaseHandlerStripGitDir import AMRDatabaseHandlerStripGitDir from staramr.databases.AMRDatabasesManager import AMRDatabasesManager @@ -19,44 +17,40 @@ def setUp(self): def tearDown(self): self.databases_dir.cleanup() - def testGetHandlerGitStripDir(self): - self.assertIsInstance(self.databases_manager.get_database_handler(), AMRDatabaseHandlerStripGitDir, - 'Invalid instance returned') - - def testGetHandlerGit(self): - self.assertIsInstance(self.databases_manager.get_database_handler(force_use_git=True), AMRDatabaseHandler, - 'Invalid instance returned') - def testSetupDefault(self): - database_handler = self.databases_manager.get_database_handler() + blast_database_repos = self.databases_manager.get_database_repos() # Verify that databases don't exist beforehand - self.assertFalse(path.exists(database_handler.get_resfinder_dir()), + self.assertFalse(path.exists(blast_database_repos.get_repo_dir('resfinder')), 'resfinder path exists before creation of database') - self.assertFalse(path.exists(database_handler.get_pointfinder_dir()), + self.assertFalse(path.exists(blast_database_repos.get_repo_dir('pointfinder')), 'pointfinder path exists before creation of database') # Setup default database self.databases_manager.setup_default() # Verify that resfinder/pointfinder paths exist - self.assertTrue(path.exists(database_handler.get_resfinder_dir()), 'resfinder path does not exist') - self.assertTrue(path.exists(database_handler.get_resfinder_dir()), 'pointfinder path does not exist') - self.assertTrue(path.exists(path.join(database_handler.get_database_dir(), 'info.ini')), - 'info file does not exist') + self.assertTrue(path.exists(blast_database_repos.get_repo_dir('resfinder')), 'resfinder path does not exist') + self.assertTrue(path.exists(blast_database_repos.get_repo_dir('resfinder')), 'pointfinder path does not exist') + self.assertTrue(path.exists(path.join(blast_database_repos.get_database_dir(), 'resfinder-info.ini')), + 'resfinder info file does not exist') + self.assertTrue(path.exists(path.join(blast_database_repos.get_database_dir(), 'pointfinder-info.ini')), + 'pointfinder info file does not exist') # Verify we've removed the .git directories - self.assertFalse(path.exists(path.join(database_handler.get_resfinder_dir(), '.git')), + self.assertFalse(path.exists(path.join(blast_database_repos.get_repo_dir('resfinder'), '.git')), 'resfinder .git directory was not removed') - self.assertFalse(path.exists(path.join(database_handler.get_pointfinder_dir(), '.git')), + self.assertFalse(path.exists(path.join(blast_database_repos.get_repo_dir('pointfinder'), '.git')), 'pointfinder .git directory was not removed') config = configparser.ConfigParser() - config.read(path.join(database_handler.get_database_dir(), 'info.ini')) + config.read(path.join(blast_database_repos.get_database_dir(), 'resfinder-info.ini')) # Verify that the info.ini file has correct git commits for default database self.assertEqual(config['GitInfo']['resfinder_db_commit'], self.RESFINDER_DEFAULT_COMMIT, 'invalid resfinder commit') + + config.read(path.join(blast_database_repos.get_database_dir(), 'pointfinder-info.ini')) self.assertEqual(config['GitInfo']['pointfinder_db_commit'], self.POINTFINDER_DEFAULT_COMMIT, 'invalid pointfinder commit') @@ -65,40 +59,44 @@ def testRestoreDefault(self): self.databases_manager.setup_default() # Build updated database - database_handler_git = self.databases_manager.get_database_handler(force_use_git=True) - database_handler_git.build(resfinder_commit=self.RESFINDER_DEFAULT_COMMIT, - pointfinder_commit=self.POINTFINDER_DEFAULT_COMMIT) + blast_database_repos_git = self.databases_manager.get_database_repos(force_use_git=True) + blast_database_repos_git.build( + {'resfinder': self.RESFINDER_DEFAULT_COMMIT, 'pointfinder': self.POINTFINDER_DEFAULT_COMMIT}) # Verify that updated database is the one that gets returned by get_database_handler() - database_handler = self.databases_manager.get_database_handler() - self.assertIsInstance(database_handler, AMRDatabaseHandler, 'Invalid instance returned') - self.assertTrue(path.exists(path.join(database_handler.get_resfinder_dir(), '.git')), + blast_database_repos = self.databases_manager.get_database_repos() + self.assertFalse(blast_database_repos.is_dist(), 'Invalid is_dist') + self.assertEqual(blast_database_repos.get_database_dir(), path.join(self.databases_dir.name, 'update'), + 'Invalid database directory') + self.assertTrue(path.exists(path.join(blast_database_repos.get_repo_dir('resfinder'), '.git')), 'Not using git version (updated version) of resfinder database') - self.assertTrue(path.exists(path.join(database_handler.get_pointfinder_dir(), '.git')), + self.assertTrue(path.exists(path.join(blast_database_repos.get_repo_dir('pointfinder'), '.git')), 'Not using git version (updated version) of pointfinder database') # Restore default database self.databases_manager.restore_default() # Verify that default database (git stripped version) is the one that gets returned by get_database_handler() - database_handler = self.databases_manager.get_database_handler() - self.assertIsInstance(database_handler, AMRDatabaseHandlerStripGitDir, 'Invalid instance returned') - self.assertFalse(path.exists(path.join(database_handler.get_resfinder_dir(), '.git')), + blast_database_repos = self.databases_manager.get_database_repos() + self.assertTrue(blast_database_repos.is_dist(), 'Invalid is_dist') + self.assertEqual(blast_database_repos.get_database_dir(), path.join(self.databases_dir.name, 'dist'), + 'Invalid database directory') + self.assertFalse(path.exists(path.join(blast_database_repos.get_repo_dir('resfinder'), '.git')), 'resfinder .git directory was not removed') - self.assertFalse(path.exists(path.join(database_handler.get_pointfinder_dir(), '.git')), + self.assertFalse(path.exists(path.join(blast_database_repos.get_repo_dir('pointfinder'), '.git')), 'pointfinder .git directory was not removed') def testIsHandlerDefaultCommitsTrue(self): # Setup default database self.databases_manager.setup_default() - database_handler = self.databases_manager.get_database_handler() + blast_database_repos = self.databases_manager.get_database_repos() - self.assertTrue(AMRDatabasesManager.is_handler_default_commits(database_handler), "Database is not default") + self.assertTrue(AMRDatabasesManager.is_database_repos_default_commits(blast_database_repos), "Database is not default") def testIsHandlerDefaultCommitsFalse(self): # Setup database - database_handler = self.databases_manager.get_database_handler(force_use_git=True) - database_handler.update(resfinder_commit='dc33e2f9ec2c420f99f77c5c33ae3faa79c999f2') + blast_database_repos = self.databases_manager.get_database_repos(force_use_git=True) + blast_database_repos.update({'resfinder': 'dc33e2f9ec2c420f99f77c5c33ae3faa79c999f2'}) - self.assertFalse(AMRDatabasesManager.is_handler_default_commits(database_handler), "Database is default") + self.assertFalse(AMRDatabasesManager.is_database_repos_default_commits(blast_database_repos), "Database is default") diff --git a/staramr/tests/integration/databases/test_AMRDatabaseHandler.py b/staramr/tests/integration/databases/test_BlastDatabaseRepositories.py similarity index 53% rename from staramr/tests/integration/databases/test_AMRDatabaseHandler.py rename to staramr/tests/integration/databases/test_BlastDatabaseRepositories.py index 52862cd8..61e6be66 100644 --- a/staramr/tests/integration/databases/test_AMRDatabaseHandler.py +++ b/staramr/tests/integration/databases/test_BlastDatabaseRepositories.py @@ -4,10 +4,10 @@ import git -from staramr.databases.AMRDatabaseHandler import AMRDatabaseHandler +from staramr.databases.BlastDatabaseRepositories import BlastDatabaseRepositories -class AMRDatabaseHandlerIT(unittest.TestCase): +class BlastDatabaseRepositoriesIT(unittest.TestCase): RESFINDER_VALID_COMMIT = 'dc33e2f9ec2c420f99f77c5c33ae3faa79c999f2' RESFINDER_VALID_COMMIT2 = 'a4a699f3d13974477c7120b98fb0c63a1b70bd16' POINTFINDER_VALID_COMMIT = 'ba65c4d175decdc841a0bef9f9be1c1589c0070a' @@ -15,59 +15,59 @@ class AMRDatabaseHandlerIT(unittest.TestCase): def setUp(self): self.databases_dir = tempfile.TemporaryDirectory() - self.database_handler = AMRDatabaseHandler(database_dir=self.databases_dir.name) + self.database_repositories = BlastDatabaseRepositories.create_default_repositories(self.databases_dir.name) def tearDown(self): self.databases_dir.cleanup() def testBuild(self): # Verify that databases don't exist beforehand - self.assertFalse(path.exists(self.database_handler.get_resfinder_dir()), + self.assertFalse(path.exists(self.database_repositories.get_repo_dir('resfinder')), 'resfinder path exists before creation of database') - self.assertFalse(path.exists(self.database_handler.get_pointfinder_dir()), + self.assertFalse(path.exists(self.database_repositories.get_repo_dir('pointfinder')), 'pointfinder path exists before creation of database') # Build database - self.database_handler.build(resfinder_commit=self.RESFINDER_VALID_COMMIT, - pointfinder_commit=self.POINTFINDER_VALID_COMMIT) + self.database_repositories.build( + {'resfinder': self.RESFINDER_VALID_COMMIT, 'pointfinder': self.POINTFINDER_VALID_COMMIT}) # Verify database is built properly - self.assertTrue(path.exists(self.database_handler.get_resfinder_dir()), + self.assertTrue(path.exists(self.database_repositories.get_repo_dir('resfinder')), 'No resfinder dir') - self.assertTrue(path.exists(self.database_handler.get_pointfinder_dir()), + self.assertTrue(path.exists(self.database_repositories.get_repo_dir('pointfinder')), 'No pointfinder dir') # Verify correct commits - resfinder_repo_head = git.Repo(self.database_handler.get_resfinder_dir()).commit('HEAD') + resfinder_repo_head = git.Repo(self.database_repositories.get_repo_dir('resfinder')).commit('HEAD') self.assertEqual(str(resfinder_repo_head), self.RESFINDER_VALID_COMMIT, 'Resfinder commits invalid') - pointfinder_repo_head = git.Repo(self.database_handler.get_pointfinder_dir()).commit('HEAD') + pointfinder_repo_head = git.Repo(self.database_repositories.get_repo_dir('pointfinder')).commit('HEAD') self.assertEqual(str(pointfinder_repo_head), self.POINTFINDER_VALID_COMMIT, 'Pointfinder commits invalid') def testUpdate(self): # Build database - self.database_handler.build(resfinder_commit=self.RESFINDER_VALID_COMMIT, - pointfinder_commit=self.POINTFINDER_VALID_COMMIT) + self.database_repositories.build( + {'resfinder': self.RESFINDER_VALID_COMMIT, 'pointfinder': self.POINTFINDER_VALID_COMMIT}) # Update database - self.database_handler.update(resfinder_commit=self.RESFINDER_VALID_COMMIT2, - pointfinder_commit=self.POINTFINDER_VALID_COMMIT2) + self.database_repositories.update( + {'resfinder': self.RESFINDER_VALID_COMMIT2, 'pointfinder': self.POINTFINDER_VALID_COMMIT2}) # Verify correct commits - resfinder_repo_head = git.Repo(self.database_handler.get_resfinder_dir()).commit('HEAD') + resfinder_repo_head = git.Repo(self.database_repositories.get_repo_dir('resfinder')).commit('HEAD') self.assertEqual(str(resfinder_repo_head), self.RESFINDER_VALID_COMMIT2, 'Resfinder commits invalid') - pointfinder_repo_head = git.Repo(self.database_handler.get_pointfinder_dir()).commit('HEAD') + pointfinder_repo_head = git.Repo(self.database_repositories.get_repo_dir('pointfinder')).commit('HEAD') self.assertEqual(str(pointfinder_repo_head), self.POINTFINDER_VALID_COMMIT2, 'Pointfinder commits invalid') def testInfo(self): # Build database - self.database_handler.build(resfinder_commit=self.RESFINDER_VALID_COMMIT, - pointfinder_commit=self.POINTFINDER_VALID_COMMIT) + self.database_repositories.build( + {'resfinder': self.RESFINDER_VALID_COMMIT, 'pointfinder': self.POINTFINDER_VALID_COMMIT}) - database_info = self.database_handler.info() + database_info = self.database_repositories.info() # Verify correct commits in info self.assertEqual(database_info['resfinder_db_commit'], self.RESFINDER_VALID_COMMIT, diff --git a/staramr/tests/integration/detection/test_AMRDetection.py b/staramr/tests/integration/detection/test_AMRDetection.py index f6b13bf3..a59ba5c2 100644 --- a/staramr/tests/integration/detection/test_AMRDetection.py +++ b/staramr/tests/integration/detection/test_AMRDetection.py @@ -22,16 +22,17 @@ class AMRDetectionIT(unittest.TestCase): def setUp(self): - database_handler = AMRDatabasesManager.create_default_manager().get_database_handler() - self.resfinder_dir = database_handler.get_resfinder_dir() - self.pointfinder_dir = database_handler.get_pointfinder_dir() + blast_databases_repositories = AMRDatabasesManager.create_default_manager().get_database_repos() + self.resfinder_dir = blast_databases_repositories.get_repo_dir('resfinder') + self.pointfinder_dir = blast_databases_repositories.get_repo_dir('pointfinder') self.resfinder_database = ResfinderBlastDatabase(self.resfinder_dir) self.resfinder_drug_table = ARGDrugTableResfinder() self.pointfinder_drug_table = ARGDrugTablePointfinder() self.pointfinder_database = None self.blast_out = tempfile.TemporaryDirectory() - self.blast_handler = BlastHandler(self.resfinder_database, 2, self.blast_out.name, self.pointfinder_database) + self.blast_handler = BlastHandler( + {'resfinder': self.resfinder_database, 'pointfinder': self.pointfinder_database}, 2, self.blast_out.name) self.outdir = tempfile.TemporaryDirectory() self.amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table, @@ -393,7 +394,8 @@ def testResfinderBetaLactamTwoCopiesOneReverseComplement(self): def testPointfinderSalmonellaA67PSuccess(self): pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'salmonella') - blast_handler = BlastHandler(self.resfinder_database, 2, self.blast_out.name, pointfinder_database) + blast_handler = BlastHandler({'resfinder': self.resfinder_database, 'pointfinder': pointfinder_database}, 2, + self.blast_out.name) amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table, blast_handler, self.pointfinder_drug_table, pointfinder_database, output_dir=self.outdir.name) @@ -427,7 +429,8 @@ def testPointfinderSalmonellaA67PSuccess(self): def testPointfinderSalmonellaA67PSuccessNoPhenotype(self): pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'salmonella') - blast_handler = BlastHandler(self.resfinder_database, 2, self.blast_out.name, pointfinder_database) + blast_handler = BlastHandler({'resfinder': self.resfinder_database, 'pointfinder': pointfinder_database}, 2, + self.blast_out.name) amr_detection = AMRDetection(self.resfinder_database, blast_handler, pointfinder_database, output_dir=self.outdir.name) @@ -459,7 +462,8 @@ def testPointfinderSalmonellaA67PSuccessNoPhenotype(self): def testPointfinderSalmonellaA67PDelEndSuccess(self): pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'salmonella') - blast_handler = BlastHandler(self.resfinder_database, 2, self.blast_out.name, pointfinder_database) + blast_handler = BlastHandler({'resfinder': self.resfinder_database, 'pointfinder': pointfinder_database}, 2, + self.blast_out.name) amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table, blast_handler, self.pointfinder_drug_table, pointfinder_database, output_dir=self.outdir.name) @@ -493,7 +497,8 @@ def testPointfinderSalmonellaA67PDelEndSuccess(self): def testPointfinderSalmonellaA67PDelEndFailPlength(self): pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'salmonella') - blast_handler = BlastHandler(self.resfinder_database, 2, self.blast_out.name, pointfinder_database) + blast_handler = BlastHandler({'resfinder': self.resfinder_database, 'pointfinder': pointfinder_database}, 2, + self.blast_out.name) amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table, blast_handler, self.pointfinder_drug_table, pointfinder_database, output_dir=self.outdir.name) @@ -509,7 +514,8 @@ def testPointfinderSalmonellaA67PDelEndFailPlength(self): def testPointfinderSalmonellaA67PFailPID(self): pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'salmonella') - blast_handler = BlastHandler(self.resfinder_database, 2, self.blast_out.name, pointfinder_database) + blast_handler = BlastHandler({'resfinder': self.resfinder_database, 'pointfinder': pointfinder_database}, 2, + self.blast_out.name) amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table, blast_handler, self.pointfinder_drug_table, pointfinder_database, output_dir=self.outdir.name) @@ -524,7 +530,8 @@ def testPointfinderSalmonellaA67PFailPID(self): def testPointfinderSalmonellaA67TFail(self): pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'salmonella') - blast_handler = BlastHandler(self.resfinder_database, 2, self.blast_out.name, pointfinder_database) + blast_handler = BlastHandler({'resfinder': self.resfinder_database, 'pointfinder': pointfinder_database}, 2, + self.blast_out.name) amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table, blast_handler, self.pointfinder_drug_table, pointfinder_database, output_dir=self.outdir.name) @@ -539,7 +546,8 @@ def testPointfinderSalmonellaA67TFail(self): def testPointfinderSalmonellaA67PReverseComplementSuccess(self): pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'salmonella') - blast_handler = BlastHandler(self.resfinder_database, 2, self.blast_out.name, pointfinder_database) + blast_handler = BlastHandler({'resfinder': self.resfinder_database, 'pointfinder': pointfinder_database}, 2, + self.blast_out.name) amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table, blast_handler, self.pointfinder_drug_table, pointfinder_database, output_dir=self.outdir.name) @@ -573,7 +581,8 @@ def testPointfinderSalmonellaA67PReverseComplementSuccess(self): def testPointfinderSalmonella_16S_rrSD_C1065T_Success(self): pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'salmonella') - blast_handler = BlastHandler(self.resfinder_database, 2, self.blast_out.name, pointfinder_database) + blast_handler = BlastHandler({'resfinder': self.resfinder_database, 'pointfinder': pointfinder_database}, 2, + self.blast_out.name) amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table, blast_handler, self.pointfinder_drug_table, pointfinder_database, output_dir=self.outdir.name) @@ -608,7 +617,8 @@ def testPointfinderSalmonella_16S_rrSD_C1065T_Success(self): def testResfinderPointfinderSalmonella_16S_C1065T_gyrA_A67_beta_lactam_Success(self): pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'salmonella') - blast_handler = BlastHandler(self.resfinder_database, 2, self.blast_out.name, pointfinder_database) + blast_handler = BlastHandler({'resfinder': self.resfinder_database, 'pointfinder': pointfinder_database}, 2, + self.blast_out.name) amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table, blast_handler, self.pointfinder_drug_table, pointfinder_database, output_dir=self.outdir.name) @@ -673,7 +683,8 @@ def testResfinderPointfinderSalmonella_16S_C1065T_gyrA_A67_beta_lactam_Success(s def testResfinderPointfinderSalmonellaExcludeGenesListSuccess(self): pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'salmonella') - blast_handler = BlastHandler(self.resfinder_database, 2, self.blast_out.name, pointfinder_database) + blast_handler = BlastHandler({'resfinder': self.resfinder_database, 'pointfinder': pointfinder_database}, 2, + self.blast_out.name) amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table, blast_handler, self.pointfinder_drug_table, pointfinder_database, output_dir=self.outdir.name, genes_to_exclude=['gyrA']) @@ -696,7 +707,8 @@ def testResfinderPointfinderSalmonellaExcludeGenesListSuccess(self): def testResfinderPointfinderSalmonella_16Src_C1065T_gyrArc_A67_beta_lactam_Success(self): pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'salmonella') - blast_handler = BlastHandler(self.resfinder_database, 2, self.blast_out.name, pointfinder_database) + blast_handler = BlastHandler({'resfinder': self.resfinder_database, 'pointfinder': pointfinder_database}, 2, + self.blast_out.name) amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table, blast_handler, self.pointfinder_drug_table, pointfinder_database, output_dir=self.outdir.name) @@ -825,6 +837,73 @@ def testNonMatches(self): self.assertEqual(len(os.listdir(self.outdir.name)), 0, 'File found where none should exist') + def testPointfinderCampylobacterA70TSuccess(self): + pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'campylobacter') + blast_handler = BlastHandler({'resfinder': self.resfinder_database, 'pointfinder': pointfinder_database}, 2, + self.blast_out.name) + amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table, blast_handler, + self.pointfinder_drug_table, pointfinder_database, + output_dir=self.outdir.name) + + file = path.join(self.test_data_dir, "gyrA-A70T.fsa") + files = [file] + amr_detection.run_amr_detection(files, 99, 99, 90) + + pointfinder_results = amr_detection.get_pointfinder_results() + self.assertEqual(len(pointfinder_results.index), 1, 'Wrong number of rows in result') + + result = pointfinder_results[pointfinder_results['Gene'] == 'gyrA (A70T)'] + self.assertEqual(len(result.index), 1, 'Wrong number of results detected') + self.assertEqual(result.index[0], 'gyrA-A70T', msg='Wrong file') + self.assertEqual(result['Type'].iloc[0], 'codon', msg='Wrong type') + self.assertEqual(result['Position'].iloc[0], 70, msg='Wrong codon position') + self.assertEqual(result['Mutation'].iloc[0], 'GCC -> ACC (A -> T)', msg='Wrong mutation') + self.assertAlmostEqual(result['%Identity'].iloc[0], 99.96, places=2, msg='Wrong pid') + self.assertAlmostEqual(result['%Overlap'].iloc[0], 100.00, places=2, msg='Wrong overlap') + self.assertEqual(result['HSP Length/Total Length'].iloc[0], '2592/2592', msg='Wrong lengths') + self.assertEqual(result['Predicted Phenotype'].iloc[0], 'ciprofloxacin I/R', 'Wrong phenotype') + + hit_file = path.join(self.outdir.name, 'pointfinder_gyrA-A70T.fsa') + records = SeqIO.to_dict(SeqIO.parse(hit_file, 'fasta')) + + self.assertEqual(len(records), 1, 'Wrong number of hit records') + + expected_records = SeqIO.to_dict(SeqIO.parse(file, 'fasta')) + self.assertEqual(expected_records['gyrA'].seq.upper(), records['gyrA'].seq.upper(), "records don't match") + + def testPointfinderCampylobacterA2075GSuccess(self): + pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'campylobacter') + blast_handler = BlastHandler({'resfinder': self.resfinder_database, 'pointfinder': pointfinder_database}, 2, + self.blast_out.name) + amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table, blast_handler, + self.pointfinder_drug_table, pointfinder_database, + output_dir=self.outdir.name) + + file = path.join(self.test_data_dir, "23S-A2075G.fsa") + files = [file] + amr_detection.run_amr_detection(files, 99, 99, 90) + + pointfinder_results = amr_detection.get_pointfinder_results() + self.assertEqual(len(pointfinder_results.index), 1, 'Wrong number of rows in result') + + result = pointfinder_results[pointfinder_results['Gene'] == '23S (A2075G)'] + self.assertEqual(len(result.index), 1, 'Wrong number of results detected') + self.assertEqual(result.index[0], '23S-A2075G', msg='Wrong file') + self.assertEqual(result['Type'].iloc[0], 'nucleotide', msg='Wrong type') + self.assertEqual(result['Position'].iloc[0], 2075, msg='Wrong codon position') + self.assertEqual(result['Mutation'].iloc[0], 'A -> G', msg='Wrong mutation') + self.assertAlmostEqual(result['%Identity'].iloc[0], 99.97, places=2, msg='Wrong pid') + self.assertAlmostEqual(result['%Overlap'].iloc[0], 100.00, places=2, msg='Wrong overlap') + self.assertEqual(result['HSP Length/Total Length'].iloc[0], '2912/2912', msg='Wrong lengths') + self.assertEqual(result['Predicted Phenotype'].iloc[0], 'erythromycin, azithromycin, telithromycin, clindamycin', 'Wrong phenotype') + + hit_file = path.join(self.outdir.name, 'pointfinder_23S-A2075G.fsa') + records = SeqIO.to_dict(SeqIO.parse(hit_file, 'fasta')) + + self.assertEqual(len(records), 1, 'Wrong number of hit records') + + expected_records = SeqIO.to_dict(SeqIO.parse(file, 'fasta')) + self.assertEqual(expected_records['23S'].seq.upper(), records['23S'].seq.upper(), "records don't match") if __name__ == '__main__': unittest.main()