Merge pull request #27 from haessar/feature-input-validation
Validate compatibility of GFF and BAM input files
haessar authored Jan 15, 2024
2 parents e8fae87 + 7a1d5bc commit c1b4c07
Showing 7 changed files with 103 additions and 16 deletions.
12 changes: 12 additions & 0 deletions peaks2utr/__init__.py
@@ -50,6 +50,7 @@ def prepare_argparser():
     parser.add_argument('-f', '-force', '--force', action="store_true", help="overwrite outputs if they exist")
     parser.add_argument('-o', '--output', help="output filename. Defaults to <GFF_IN basename>.new.<ext>")
     parser.add_argument('--gtf', dest="gtf_out", action="store_true", help="output in GTF format (rather than default GFF3)")
+    parser.add_argument('--skip-validation', action="store_true", help="skip validation of input files")
     parser.add_argument('--keep-cache', action="store_true", help="keep cached files on run completion")
     parser.add_argument('--version', action='version', version='%(prog)s {version}'.format(version=version(__package__)))
     return parser
@@ -99,6 +100,7 @@ async def _main(args):
     from .utils import cached, yield_from_process
     from .preprocess import BAMSplitter, call_peaks, create_db
     from .postprocess import merge_annotations, gt_gff3_sort, write_summary_stats
+    from .validation import matching_chr
 
     try:
         ###################
@@ -171,6 +173,16 @@ async def _main(args):
BroadPeaksList(broadpeak_fn=cached("forward_peaks.broadPeak"), strand="forward") + \
BroadPeaksList(broadpeak_fn=cached("reverse_peaks.broadPeak"), strand="reverse")

###################
# Validation #
###################

if not args.skip_validation:
logging.info("Performing input file validation.")
if not matching_chr(db, args):
logging.error("No chromosome shared between GFF_IN and BAM_IN. Aborting.")
sys.exit(1)

###################
# Process peaks #
###################
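The new flag surfaces in the CLI alongside the existing options. A minimal sketch of the resulting argparse behaviour (the input file names here are illustrative, not from the repository):

    from peaks2utr import prepare_argparser

    parser = prepare_argparser()
    # Validation runs by default...
    args = parser.parse_args(["annotation.gff", "alignments.bam"])
    assert args.skip_validation is False
    # ...and is bypassed with the new flag.
    args = parser.parse_args(["annotation.gff", "alignments.bam", "--skip-validation"])
    assert args.skip_validation is True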
11 changes: 3 additions & 8 deletions peaks2utr/annotations.py
@@ -1,16 +1,15 @@
 import logging
 import math
 import multiprocessing
-import sqlite3
 
 from tqdm import tqdm
 
 from . import constants, criteria
 from .constants import AnnotationColour, STRAND_MAP
 from .collections import SPATTruncationPointsDict, ZeroCoverageIntervalsDict
 from .exceptions import AnnotationsError
-from .models import UTR, FeatureDB
-from .utils import Counter, Falsey, cached, features_dict_for_gene, iter_batches
+from .models import UTR
+from .utils import Counter, Falsey, cached, connect_db, features_dict_for_gene, iter_batches


class NoNearbyFeatures(Falsey):
@@ -46,10 +45,6 @@ def __enter__(self):
     def __exit__(self, type, value, traceback):
         self.pbar.close()
 
-    def _connect_db(self):
-        db = sqlite3.connect(self.db_path, check_same_thread=False)
-        return FeatureDB(db)
-
     def _batch_annotate_strand(self, peaks_batch):
         """
         Create multiprocessing Process to handle batch of peaks. Connect to sqlite3 db for each batch to prevent
@@ -60,7 +55,7 @@ def _batch_annotate_strand(self, peaks_batch):
         for strand, symbol in STRAND_MAP.items():
             truncation_points[symbol] = SPATTruncationPointsDict(json_fn=cached(strand + "_unmapped.json"))
             coverage_gaps[symbol] = ZeroCoverageIntervalsDict(bed_fn=cached(strand + "_coverage_gaps.bed"))
-        db = self._connect_db()
+        db = connect_db(self.db_path)
         return multiprocessing.Process(target=self._iter_peaks, args=(db, peaks_batch, truncation_points, coverage_gaps))
 
     def _iter_peaks(self, db, peaks_batch, truncation_points, coverage_gaps):
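The docstring above motivates the refactor: sqlite3 connections should not be shared across process boundaries, so each worker batch opens its own. A minimal, self-contained sketch of that pattern (the database path is illustrative):

    import multiprocessing
    import sqlite3

    DB_PATH = "features.db"  # illustrative; sqlite3.connect will create it if absent

    def worker(db_path):
        # Each child process opens a fresh connection rather than
        # inheriting one from the parent.
        conn = sqlite3.connect(db_path, check_same_thread=False)
        conn.execute("SELECT 1").fetchone()
        conn.close()

    if __name__ == "__main__":
        procs = [multiprocessing.Process(target=worker, args=(DB_PATH,)) for _ in range(4)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()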
9 changes: 2 additions & 7 deletions peaks2utr/preprocess.py
@@ -13,7 +13,7 @@

 from .exceptions import EXCEPTIONS_MAP
 from .models import SoftClippedRead
-from .utils import cached, consume_lines, filter_nested_dict, sum_nested_dicts, multiprocess_over_dict
+from .utils import cached, consume_lines, filter_nested_dict, index_bam_file, sum_nested_dicts, multiprocess_over_dict
 from .constants import CACHE_DIR, LOG_DIR, STRAND_PYSAM_ARGS


@@ -74,16 +74,11 @@ def split_read_groups(self):
                              for bf in self.read_group_bams}
         self.spat_outputs_to_process = self.spat_outputs.copy()
 
-    def index_bam_file(self, bam_file):
-        if not os.path.isfile(cached(bam_file + '.bai')):
-            logging.info("Indexing %s." % bam_file)
-            pysam.index("-@", str(self.args.processors), bam_file)
-
     def _get_max_reads_for_pbar(self):
         max_reads = 0
         for bf in self.read_group_bams:
             if not os.path.isfile(self.spat_outputs[bf]):
-                self.index_bam_file(bf)
+                index_bam_file(bf, self.args.processors)
             idxstats = pysam.idxstats(bf).split('\n')
             num_reads = sum([int(chr.split("\t")[2]) + int(chr.split("\t")[3]) for chr in idxstats[:-1]])
             if num_reads > max_reads:
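The read-counting line above parses pysam.idxstats output, whose rows are tab-separated seqname/length/#mapped/#unmapped; splitting on newlines leaves a trailing empty element, which the [:-1] slice drops. A standalone sketch of the same logic (the BAM path is illustrative and must already be indexed):

    import pysam

    def total_reads(bam_path):
        # Each idxstats row: seqname, length, #mapped, #unmapped.
        # str.split('\n') yields a trailing empty string; [:-1] drops it.
        rows = pysam.idxstats(bam_path).split('\n')
        return sum(int(row.split('\t')[2]) + int(row.split('\t')[3]) for row in rows[:-1])

    # e.g. total_reads("alignments.bam")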
16 changes: 16 additions & 0 deletions peaks2utr/utils.py
@@ -1,10 +1,15 @@
 import logging
 import multiprocessing
 import os.path
 from queue import Empty
 import resource
+import sqlite3
 
+import pysam
 
 from .constants import CACHE_DIR
 from .exceptions import EXCEPTIONS_MAP
+from .models import FeatureDB


class Falsey:
@@ -42,6 +47,17 @@ def cached(filename):
     return os.path.join(CACHE_DIR, filename)
 
 
+def connect_db(db_path):
+    db = sqlite3.connect(db_path, check_same_thread=False)
+    return FeatureDB(db)
+
+
+def index_bam_file(bam_file, processors):
+    if not os.path.isfile(cached(bam_file + '.bai')):
+        logging.info("Indexing %s." % bam_file)
+        pysam.index("-@", str(processors), bam_file)
+
+
 async def consume_lines(pipe, log_file):
"""
Asynchronously write lines in pipe to log file.
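With both helpers promoted to utils.py, any module in the package can share them. A usage sketch (paths are illustrative; note that index_bam_file only calls pysam.index when no cached .bai is found):

    from peaks2utr.utils import connect_db, index_bam_file

    index_bam_file("alignments.bam", 4)  # indexes with 4 threads unless a cached .bai exists
    db = connect_db("annotations.db")    # sqlite3 connection wrapped in the package's FeatureDB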
33 changes: 33 additions & 0 deletions peaks2utr/validation.py
@@ -0,0 +1,33 @@
+import logging
+
+import pysam
+
+from .utils import connect_db, index_bam_file
+
+
+def matching_chr(db_path, args):
+    """
+    Check seqids in BAM and GFF input files to ensure at least one matches. Returns bool.
+    """
+    db = connect_db(db_path)
+    gff_chrs = {f.seqid for f in db.all_features()}
+    bam_chrs = set()
+
+    index_bam_file(args.BAM_IN, args.processors)
+    samfile = pysam.AlignmentFile(args.BAM_IN, "rb", require_index=True)
+
+    for chr in gff_chrs:
+        try:
+            samfile.fetch(chr)
+        except ValueError:
+            logging.warning("Chromosome {} from GFF_IN not found in BAM_IN.".format(chr))
+        else:
+            bam_chrs.add(str(chr))
+    if len(bam_chrs) > 1:
+        logging.warning(
+            """
+            Chromosomes {} are present in both GFF_IN and BAM_IN.
+            Consider reducing both to a single chromosome to improve performance.
+            """.format(', '.join(bam_chrs))
+        )
+    return bool(bam_chrs)
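matching_chr leans on a pysam behaviour worth spelling out: fetch() on an indexed AlignmentFile raises ValueError when the named reference is absent. A minimal illustration (file and chromosome names are placeholders):

    import pysam

    samfile = pysam.AlignmentFile("alignments.bam", "rb", require_index=True)
    try:
        samfile.fetch("Chr1")
    except ValueError:
        print("Chr1 is not a reference in alignments.bam")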
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = peaks2utr
-version = 1.2.3
+version = 1.2.4
 author = William Haese-Hill
 author_email = [email protected]
 description = A robust, parallelized Python CLI for annotating three_prime_UTR
36 changes: 36 additions & 0 deletions tests/test_validation.py
@@ -0,0 +1,36 @@
+import os
+import os.path
+import unittest
+from unittest.mock import MagicMock, patch
+
+import gffutils
+
+from peaks2utr import prepare_argparser
+from peaks2utr.validation import matching_chr
+
+TEST_DIR = os.path.dirname(__file__)
+
+
+class TestValidation(unittest.TestCase):
+    def setUp(self):
+        argparser = prepare_argparser()
+        self.args = argparser.parse_args(["Chr1.gtf", ""])
+        self.db_path = os.path.join(TEST_DIR, "Chr1.db")
+        gffutils.create_db(os.path.join(TEST_DIR, self.args.GFF_IN), self.db_path, force=True)
+
+    def tearDown(self):
+        os.remove(self.db_path)
+
+    def test_matching_chr(self):
+        mock_af = MagicMock()
+        mock_af.fetch.return_value = object
+        with patch("peaks2utr.validation.index_bam_file") as mock_index:
+            with patch("pysam.AlignmentFile", return_value=mock_af):
+                self.assertTrue(matching_chr(self.db_path, self.args))
+                mock_af.fetch.side_effect = ValueError()
+                self.assertFalse(matching_chr(self.db_path, self.args))
+        self.assertEqual(mock_index.call_count, 2)
+
+
+if __name__ == '__main__':
+    unittest.main()
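The new suite follows the repository's existing unittest layout and can be run directly with python -m unittest tests.test_validation (assuming a tests/Chr1.gtf fixture is present, as setUp implies).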
