diff --git a/peaks2utr/__init__.py b/peaks2utr/__init__.py
index 566ea7b..7e38dad 100644
--- a/peaks2utr/__init__.py
+++ b/peaks2utr/__init__.py
@@ -50,6 +50,7 @@ def prepare_argparser():
     parser.add_argument('-f', '-force', '--force', action="store_true", help="overwrite outputs if they exist")
     parser.add_argument('-o', '--output', help="output filename. Defaults to .new.")
     parser.add_argument('--gtf', dest="gtf_out", action="store_true", help="output in GTF format (rather than default GFF3)")
+    parser.add_argument('--skip-validation', action="store_true", help="skip validation of input files")
     parser.add_argument('--keep-cache', action="store_true", help="keep cached files on run completion")
     parser.add_argument('--version', action='version', version='%(prog)s {version}'.format(version=version(__package__)))
     return parser
@@ -99,6 +100,7 @@ async def _main(args):
     from .utils import cached, yield_from_process
     from .preprocess import BAMSplitter, call_peaks, create_db
     from .postprocess import merge_annotations, gt_gff3_sort, write_summary_stats
+    from .validation import matching_chr

     try:
         ###################
@@ -171,6 +173,16 @@ async def _main(args):
             BroadPeaksList(broadpeak_fn=cached("forward_peaks.broadPeak"), strand="forward") + \
             BroadPeaksList(broadpeak_fn=cached("reverse_peaks.broadPeak"), strand="reverse")

+        ###################
+        #   Validation    #
+        ###################
+
+        if not args.skip_validation:
+            logging.info("Performing input file validation.")
+            if not matching_chr(db, args):
+                logging.error("No chromosome shared between GFF_IN and BAM_IN. Aborting.")
+                sys.exit(1)
+
         ###################
         #  Process peaks  #
         ###################
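Usage sketch (not part of the diff): the new --skip-validation flag is a plain store_true option, so validation runs by default and is only bypassed when the flag is given. Assuming an installed peaks2utr and hypothetical input file names:

    from peaks2utr import prepare_argparser

    # store_true options default to False, so validation runs unless opted out.
    args = prepare_argparser().parse_args(["annotation.gff", "alignments.bam"])
    assert args.skip_validation is False

    # Passing the flag flips it to True, which skips the new block in _main.
    args = prepare_argparser().parse_args(["--skip-validation", "annotation.gff", "alignments.bam"])
    assert args.skip_validation is True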
diff --git a/peaks2utr/annotations.py b/peaks2utr/annotations.py
index 53407b2..cdb97a5 100644
--- a/peaks2utr/annotations.py
+++ b/peaks2utr/annotations.py
@@ -1,7 +1,6 @@
 import logging
 import math
 import multiprocessing
-import sqlite3

 from tqdm import tqdm

@@ -9,8 +8,8 @@
 from .constants import AnnotationColour, STRAND_MAP
 from .collections import SPATTruncationPointsDict, ZeroCoverageIntervalsDict
 from .exceptions import AnnotationsError
-from .models import UTR, FeatureDB
-from .utils import Counter, Falsey, cached, features_dict_for_gene, iter_batches
+from .models import UTR
+from .utils import Counter, Falsey, cached, connect_db, features_dict_for_gene, iter_batches


 class NoNearbyFeatures(Falsey):
@@ -46,10 +45,6 @@ def __enter__(self):
     def __exit__(self, type, value, traceback):
         self.pbar.close()

-    def _connect_db(self):
-        db = sqlite3.connect(self.db_path, check_same_thread=False)
-        return FeatureDB(db)
-
     def _batch_annotate_strand(self, peaks_batch):
         """
         Create multiprocessing Process to handle batch of peaks. Connect to sqlite3 db for each batch to prevent
@@ -60,7 +55,7 @@ def _batch_annotate_strand(self, peaks_batch):
         for strand, symbol in STRAND_MAP.items():
             truncation_points[symbol] = SPATTruncationPointsDict(json_fn=cached(strand + "_unmapped.json"))
             coverage_gaps[symbol] = ZeroCoverageIntervalsDict(bed_fn=cached(strand + "_coverage_gaps.bed"))
-        db = self._connect_db()
+        db = connect_db(self.db_path)
         return multiprocessing.Process(target=self._iter_peaks, args=(db, peaks_batch, truncation_points, coverage_gaps))

     def _iter_peaks(self, db, peaks_batch, truncation_points, coverage_gaps):
diff --git a/peaks2utr/preprocess.py b/peaks2utr/preprocess.py
index 0772f36..04567a8 100644
--- a/peaks2utr/preprocess.py
+++ b/peaks2utr/preprocess.py
@@ -13,7 +13,7 @@

 from .exceptions import EXCEPTIONS_MAP
 from .models import SoftClippedRead
-from .utils import cached, consume_lines, filter_nested_dict, sum_nested_dicts, multiprocess_over_dict
+from .utils import cached, consume_lines, filter_nested_dict, index_bam_file, sum_nested_dicts, multiprocess_over_dict
 from .constants import CACHE_DIR, LOG_DIR, STRAND_PYSAM_ARGS


@@ -74,16 +74,11 @@ def split_read_groups(self):
                         for bf in self.read_group_bams}
         self.spat_outputs_to_process = self.spat_outputs.copy()

-    def index_bam_file(self, bam_file):
-        if not os.path.isfile(cached(bam_file + '.bai')):
-            logging.info("Indexing %s." % bam_file)
-            pysam.index("-@", str(self.args.processors), bam_file)
-
     def _get_max_reads_for_pbar(self):
         max_reads = 0
         for bf in self.read_group_bams:
             if not os.path.isfile(self.spat_outputs[bf]):
-                self.index_bam_file(bf)
+                index_bam_file(bf, self.args.processors)
                 idxstats = pysam.idxstats(bf).split('\n')
                 num_reads = sum([int(chr.split("\t")[2]) + int(chr.split("\t")[3]) for chr in idxstats[:-1]])
                 if num_reads > max_reads:
diff --git a/peaks2utr/utils.py b/peaks2utr/utils.py
index 8cfcc33..75ac399 100644
--- a/peaks2utr/utils.py
+++ b/peaks2utr/utils.py
@@ -1,10 +1,15 @@
+import logging
 import multiprocessing
 import os.path
 from queue import Empty
 import resource
+import sqlite3
+
+import pysam

 from .constants import CACHE_DIR
 from .exceptions import EXCEPTIONS_MAP
+from .models import FeatureDB


 class Falsey:
@@ -42,6 +47,17 @@ def cached(filename):
     return os.path.join(CACHE_DIR, filename)


+def connect_db(db_path):
+    db = sqlite3.connect(db_path, check_same_thread=False)
+    return FeatureDB(db)
+
+
+def index_bam_file(bam_file, processors):
+    if not os.path.isfile(cached(bam_file + '.bai')):
+        logging.info("Indexing %s." % bam_file)
+        pysam.index("-@", str(processors), bam_file)
+
+
 async def consume_lines(pipe, log_file):
     """
     Asynchronously write lines in pipe to log file.
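Usage sketch (not part of the diff): with connect_db and index_bam_file promoted from BAMSplitter methods to module-level helpers, any module can reuse them. File names here are hypothetical; connect_db opens the gffutils database with check_same_thread=False so the handle can be passed to multiprocessing workers, as _batch_annotate_strand does above:

    from peaks2utr.utils import connect_db, index_bam_file

    # Index the BAM (skipped when the cache already holds a .bai), then open
    # the gffutils FeatureDB and collect the seqids it annotates.
    index_bam_file("alignments.bam", 4)  # second argument: indexing threads
    db = connect_db("annotation.db")
    seqids = {feature.seqid for feature in db.all_features()}
    print(sorted(seqids))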
+ """ + db = connect_db(db_path) + gff_chrs = {f.seqid for f in db.all_features()} + bam_chrs = set() + + index_bam_file(args.BAM_IN, args.processors) + samfile = pysam.AlignmentFile(args.BAM_IN, "rb", require_index=True) + + for chr in gff_chrs: + try: + samfile.fetch(chr) + except ValueError: + logging.warning("Chromosome {} from GFF_IN not found in BAM_IN.".format(chr)) + else: + bam_chrs.add(str(chr)) + if len(bam_chrs) > 1: + logging.warning( + """ + Chromosomes {} are present in both GFF_IN and BAM_IN. + Consider reducing both to a single chromosome to improve performance. + """.format(', '.join(bam_chrs)) + ) + return bool(bam_chrs) diff --git a/setup.cfg b/setup.cfg index 4557dc1..6d92015 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = peaks2utr -version = 1.2.3 +version = 1.2.4 author = William Haese-Hill author_email = william.haese-hill@glasgow.ac.uk description = A robust, parallelized Python CLI for annotating three_prime_UTR diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..08b5b46 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,36 @@ +import os +import os.path +import unittest +from unittest.mock import MagicMock, patch + +import gffutils + +from peaks2utr import prepare_argparser +from peaks2utr.validation import matching_chr + +TEST_DIR = os.path.dirname(__file__) + + +class TestValidation(unittest.TestCase): + def setUp(self): + argparser = prepare_argparser() + self.args = argparser.parse_args(["Chr1.gtf", ""]) + self.db_path = os.path.join(TEST_DIR, "Chr1.db") + gffutils.create_db(os.path.join(TEST_DIR, self.args.GFF_IN), self.db_path, force=True) + + def tearDown(self): + os.remove(self.db_path) + + def test_matching_chr(self): + mock_af = MagicMock() + mock_af.fetch.return_value = object + with patch("peaks2utr.validation.index_bam_file") as mock_index: + with patch("pysam.AlignmentFile", return_value=mock_af): + self.assertTrue(matching_chr(self.db_path, self.args)) + mock_af.fetch.side_effect = ValueError() + self.assertFalse(matching_chr(self.db_path, self.args)) + self.assertEqual(mock_index.call_count, 2) + + +if __name__ == '__main__': + unittest.main()