Commit

Permalink
Merge pull request #3 from HudsonAlpha/dnascope
Dnascope
holtjma authored Jul 26, 2022
2 parents 8261d3e + 9f8d190 commit b05858d
Showing 14 changed files with 567 additions and 189 deletions.
9 changes: 7 additions & 2 deletions README.md
@@ -39,7 +39,12 @@ We have currently configured the pipeline to run using an LSF cluster. If a dif
2. If running locally, configuration changes to the snakemake command itself may be necessary. These are located in the variable `snakemakeFrags`; contact the repo owner or submit a ticket for assistance.

### Training Configuration (Optional)
There are several options available that adjust the way training of models is performed (e.g. models used, hyperparameters, training method, etc.). These options are available in `TrainingConfig.py`. Details on each will be in a future release.
There are several options available that adjust the way training of models is performed (e.g. models used, hyperparameters, training method, etc.).
These options are available in `TrainingConfig.py` and are generally described in-line with each parameter.
However, some are of critical importance and should be considered before training:

1. `ENABLE_AUTO_TARGET` - If set to `True`, this option enables an alternate method for determining the target recall: it is automatically calculated from the observed precision and the desired global precision (another parameter, `GLOBAL_AUTO_TARGET_PRECISION`). *Global precision* in this context is the combined precision of the upstream pipeline followed by STEVE identifying false positive calls from that pipeline. For example, if the desired global precision is 99% and the upstream pipeline achieves 98% precision by itself, then the models trained by STEVE need to capture 1% out of the 2% false positives to achieve 99% global precision. This means the target recall will be set to 0.5000, indicating that half of all false positives need to be identified to reach the desired global precision. This approach gives pipelines that _already_ have high precision lower/easier targets in STEVE while achieving the same global precision. In practice, this allows you to swap upstream pipelines without recalculating/adjusting the target/accepted recall values to account for a more or less precise upstream pipeline.
2. `ENABLE_FEATURE_SELECTION` - If set to `True`, this option enables feature selection prior to model training. Numerous parameters adjust exactly how the feature selection is performed. In general, enabling feature selection leads to longer training times, but it may improve the results by systematically removing unnecessary and/or redundant features.
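The auto-target arithmetic from item 1 can be sketched as follows (a minimal illustration of the calculation described above, not part of the pipeline; the function name `derive_target_recall` is hypothetical):

```python
def derive_target_recall(base_tp, base_fp, global_precision):
    """Derive the recall STEVE's models must achieve so the combined
    pipeline reaches the desired global precision."""
    # precision of the upstream pipeline by itself
    base_precision = base_tp / (base_tp + base_fp)
    # how far the upstream pipeline is from the global goal
    delta_precision = global_precision - base_precision
    assert delta_precision > 0, "upstream pipeline already meets the goal"
    # total imprecision available to be captured
    remainder_precision = 1.0 - base_precision
    # fraction of upstream false positives STEVE must flag
    return delta_precision / remainder_precision

# Example from the text: 98% upstream precision, 99% global goal
# -> STEVE must capture half of the false positives (target recall ~0.5)
target = derive_target_recall(9800, 200, 0.99)
```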

### Setting up conda environment
Assuming conda or miniconda is installed, use the following two commands to create and then activate a conda environment for the pipeline:
@@ -53,7 +58,7 @@ conda activate steve
Assuming the above configuration steps were successful, all that remains is to run the pipeline itself. Here is an example execution of the pipeline used in the paper:

```
python3 scripts/RunTrainingPipeline.py -dest -x ./GIAB_v1.json
python3 scripts/RunTrainingPipeline.py -dest -x ./sample_metadata/GIAB_v1.json
```

For details on each option, run `python3 scripts/RunTrainingPipeline.py -h`. The following sections describe what the pipeline is actually doing.
149 changes: 107 additions & 42 deletions scripts/EvaluateVariants.py
@@ -1,24 +1,52 @@

import argparse as ap
import cyvcf2
import csv
import json
import numpy as np
import pickle
import re
import vcf
#import vcf

from ExtractFeatures import ALL_METRICS, getVariantFeatures, GT_TRANSLATE, VAR_TRANSLATE
from TrainingConfig import ENABLE_AUTO_TARGET, GLOBAL_AUTO_TARGET_PRECISION

def getClinicalModel(stats, acceptedRecall, targetRecall):
def getClinicalModel(stats, acceptedRecall, targetRecall, global_precision):
'''
@param stats - the stats dictionary from training
@param acceptedRecall - the minimum acceptable recall
@param targetRecall - the target we are aiming for
@return - a dictionary with the eval recall, model name, and harmonic mean score
@param targetRecall - the target we are aiming for, must be greater than or equal to accepted
@param global_precision - float value indicating target global precision; if set,
then it will dynamically figure out the target recalls based on the data
@return - a dictionary with the eval recall, model name, and harmonic mean score; more info is added if global_precision is used
'''
bestModelTargetRecall = None
bestModelName = None
bestHM = 0.0

if global_precision != None:
#we need to override the accepted and target recall values
#print(stats.keys())
lookup_key = list(stats.keys())[0]
base_total_tp = stats[lookup_key]['RAW_TOTAL_TP']
base_total_fp = stats[lookup_key]['RAW_TOTAL_FP']
base_precision = base_total_tp / (base_total_fp + base_total_tp)

#figure out how far from the goal we are
delta_precision = global_precision - base_precision
assert(delta_precision > 0)
remainder_precision = 1.0 - base_precision

#now derive our target recall from that difference
derived_recall = delta_precision / remainder_precision

#for now, just set both accepted and target to that same value
acceptedRecall = derived_recall
targetRecall = derived_recall

#check these after we potentially override any passed in values
assert(targetRecall >= acceptedRecall)

for mn in stats.keys():
for tr in stats[mn]['ALL_SUMMARY']:
#CM = confusion matrix
@@ -40,21 +68,34 @@ def getClinicalModel(stats, acceptedRecall, targetRecall):
else:
#in clinical, best is harmonic mean of our adjusted recall and our TNR
modelTNR = modelCM[0, 0] / (modelCM[0, 0] + modelCM[0, 1])
adjRecall = (modelRecall - acceptedRecall) / (float(targetRecall) - acceptedRecall)
if adjRecall > 1.0:
adjRecall = 1.0
if targetRecall == acceptedRecall:
#target and accepted are identical, so this is a binary kill switch that's either 1.0 or 0.0
if modelRecall >= acceptedRecall:
adjRecall = 1.0
else:
adjRecall = 0.0
else:
#otherwise, there's a scale
adjRecall = (modelRecall - acceptedRecall) / (targetRecall - acceptedRecall)
if adjRecall > 1.0:
adjRecall = 1.0
modelHM = 2 * adjRecall * modelTNR / (adjRecall+modelTNR)

if modelHM > bestHM:
bestModelTargetRecall = tr
bestModelName = mn
bestHM = modelHM

return {
ret = {
'eval_recall' : bestModelTargetRecall,
'model_name' : bestModelName,
'hm_score' : bestHM
}
if global_precision != None:
ret['base_precision'] = base_precision
ret['derived_recall'] = derived_recall

return ret

def evaluateVariants(args):
'''
@@ -127,7 +168,7 @@ def runSubType(variantType, args, stats, models, statKey):

#make sure our recall is in the list
availableRecalls = stats[list(stats.keys())[0]]['ALL_SUMMARY'].keys()
if targetRecall not in availableRecalls:
if targetRecall not in availableRecalls and not ENABLE_AUTO_TARGET:
raise Exception('Invalid target recall, available options are: %s' % (availableRecalls, ))

#figure out which models we will actually be using
@@ -152,17 +193,25 @@
evalList = [bestModelName]

elif modelName == 'clinical':
#TODO: make this a script parameter instead of hard-coding
#this contains minimum acceptable recalls for a given target
targetThresholds = {
'0.995' : 0.99
}
if targetRecall not in targetThresholds:
raise Exception('"clinical" mode has no defined threshold for target recall "%s"' % targetRecall)
acceptedRecall = targetThresholds[targetRecall]
if ENABLE_AUTO_TARGET:
global_precision = GLOBAL_AUTO_TARGET_PRECISION

#dummy values
acceptedRecall = 0.0
targetRecall = 0.0
else:
global_precision = None
#TODO: make this a script parameter instead of hard-coding
#this contains minimum acceptable recalls for a given target
targetThresholds = {
'0.995' : 0.99
}
if targetRecall not in targetThresholds:
raise Exception('"clinical" mode has no defined threshold for target recall "%s"' % targetRecall)
acceptedRecall = targetThresholds[targetRecall]

#get the clinical model
clinicalModelDict = getClinicalModel(stats, acceptedRecall, targetRecall)
clinicalModelDict = getClinicalModel(stats, acceptedRecall, targetRecall, global_precision)
bestModelName = clinicalModelDict['model_name']
bestModelTargetRecall = clinicalModelDict['eval_recall']

@@ -213,23 +262,24 @@
allVariants += loadCodicemVariants(args.codicem)

#now load the VCF file
vcfReader = vcf.Reader(filename=args.sample_vcf, compressed=True)
rawReader = vcf.Reader(filename=args.sample_vcf, compressed=True)
assert(len(vcfReader.samples) == 1)
chromList = vcfReader.contigs.keys()
vcf_reader = cyvcf2.VCF(args.sample_vcf)
raw_reader = cyvcf2.VCF(args.sample_vcf)
assert(len(vcf_reader.samples) == 1)
sample_index = 0
chrom_list = vcf_reader.seqnames

#go through each variant and extract the features into a shared set
varIndex = []
rawVariants = []
rawGT = []
varFeatures = []
for i, (chrom, start, end, ref, alt) in enumerate(allVariants):
if (chrom not in chromList and
'chr'+chrom in chromList):
if (chrom not in chrom_list and
'chr'+chrom in chrom_list):
chrom = 'chr'+chrom

if chrom in chromList:
variantList = [variant for variant in vcfReader.fetch(chrom, start, end)]
if chrom in chrom_list:
variantList = [variant for variant in vcf_reader(f'{chrom}:{start}-{end}')]
else:
print('WARNING: Chromosome "%s" not found' % (chrom, ))
variantList = []
@@ -240,7 +290,7 @@

#now go through each variant and pull out the features for it
for variant in variantList:
featureVals = getVariantFeatures(variant, vcfReader.samples[0], fields, rawReader, allowHomRef=True)
featureVals = getVariantFeatures(variant, sample_index, fields, raw_reader, allowHomRef=True)
varFeatures.append(featureVals)
rawGT.append(featureVals[gtIndex])

@@ -283,7 +333,8 @@
#if it's not found, we DON'T put it in the dictionary list
else:
for ind in foundVarIndices:
vals = [chrom, start, end, ref, alt, rawVariants[ind], rawGT[ind]]
raw_variant_str = generateCyvcf2RecordStr(rawVariants[ind])
vals = [chrom, start, end, ref, alt, raw_variant_str, rawGT[ind]]
modelResultDict = {}
resultsFound = False
for mn in evalList:
@@ -307,7 +358,7 @@
'end' : end,
'ref' : ref,
'alt' : alt,
'call_variant' : rawVariants[ind],
'call_variant' : raw_variant_str,
'call_gt' : rawGT[ind],
'predictions' : modelResultDict
}
@@ -347,23 +398,24 @@ def runReferenceCalls(variantType, args, acceptedVT, acceptedGT):
allVariants += loadCodicemVariants(args.codicem)

#now load the VCF file
vcfReader = vcf.Reader(filename=args.sample_vcf, compressed=True)
rawReader = vcf.Reader(filename=args.sample_vcf, compressed=True)
assert(len(vcfReader.samples) == 1)
chromList = vcfReader.contigs.keys()
vcf_reader = cyvcf2.VCF(args.sample_vcf)
raw_reader = cyvcf2.VCF(args.sample_vcf)
assert(len(vcf_reader.samples) == 1)
sample_index = 0
chrom_list = vcf_reader.seqnames

#go through each variant and extract the features into a shared set
varIndex = []
rawVariants = []
rawGT = []
varFeatures = []
for i, (chrom, start, end, ref, alt) in enumerate(allVariants):
if (chrom not in chromList and
'chr'+chrom in chromList):
if (chrom not in chrom_list and
'chr'+chrom in chrom_list):
chrom = 'chr'+chrom

if chrom in chromList:
variantList = [variant for variant in vcfReader.fetch(chrom, start, end)]
if chrom in chrom_list:
variantList = [variant for variant in vcf_reader(f'{chrom}:{start}-{end}')]
else:
print('WARNING: Chromosome "%s" not found' % (chrom, ))
variantList = []
@@ -374,7 +426,7 @@ def runReferenceCalls(variantType, args, acceptedVT, acceptedGT):

#now go through each variant and pull out the features for it
for variant in variantList:
featureVals = getVariantFeatures(variant, vcfReader.samples[0], fields, rawReader, allowHomRef=True)
featureVals = getVariantFeatures(variant, sample_index, fields, raw_reader, allowHomRef=True)
varFeatures.append(featureVals)
rawGT.append(featureVals[gtIndex])

@@ -397,7 +449,8 @@ def runReferenceCalls(variantType, args, acceptedVT, acceptedGT):
valList.append(vals)
else:
for ind in foundVarIndices:
vals = [chrom, start, end, ref, alt, rawVariants[ind], rawGT[ind]]
raw_variant_str = generateCyvcf2RecordStr(rawVariants[ind])
vals = [chrom, start, end, ref, alt, raw_variant_str, rawGT[ind]]
modelResultDict = {}
resultsFound = False
if filtersEnabled and (acceptedVT != varFeatures[ind][0] or
@@ -420,7 +473,7 @@
'end' : end,
'ref' : ref,
'alt' : alt,
'call_variant' : rawVariants[ind],
'call_variant' : raw_variant_str,
'call_gt' : rawGT[ind],
'predictions' : modelResultDict
}
@@ -480,6 +533,18 @@ def parseCLIVariants(csVar):
ret.append(var)
return ret

def generateCyvcf2RecordStr(variant):
'''
Simple cyvcf2 variant reformatter; produces a string similar to the old `vcf` module's record representation
@param variant - a variant from cyvcf2
@return - a string formatted as a "record"
'''
chrom = variant.CHROM
pos = variant.POS
ref = variant.REF
alts = variant.ALT
return f'cyvcf2.record(CHROM={chrom}, POS={pos}, REF={ref}, ALT={alts})'

def loadCodicemVariants(csvFN):
'''
Loads a Codicem sanger CSV and finds variants that need confirmation
@@ -509,7 +574,7 @@ def loadCodicemVariants(csvFN):
#optional arguments with default
p.add_argument('-c', '--codicem', dest='codicem', default=None, help='a Codicem CSV file with variants to evaluate (default: None)')
p.add_argument('-v', '--variants', dest='variants', default=None, help='variant coordinates to evaluate (default: None)')
p.add_argument('-m', '--model', dest='model', default='best', help='the model name to use (default: best)')
p.add_argument('-m', '--model', dest='model', default='clinical', help='the model name to use (default: clinical)')
p.add_argument('-r', '--recall', dest='recall', default='0.99', help='the target recall value from training (default: 0.99)')
p.add_argument('-o', '--output', dest='outFN', default=None, help='the place to send output to (default: stdout)')

Expand All @@ -520,4 +585,4 @@ def loadCodicemVariants(csvFN):
#parse the arguments
args = p.parse_args()

evaluateVariants(args)
evaluateVariants(args)
