Skip to content

Commit

Permalink
allow multiple cuts on same parameter (gwastro#4149)
Browse files Browse the repository at this point in the history
* Amend cuts module to allow multiple cuts on same parameter, and to check for duplicated cuts

* CC

* incorrect function name attribute for numpy ufuncs

* CC

* add some unit tests for the cuts module

* add a couple of tests on the same parameter/type of cut

* dont redo a conversion

* small fixes
  • Loading branch information
GarethCabournDavies authored and acorreia61201 committed Apr 4, 2024
1 parent 6b7393e commit 61de4f2
Show file tree
Hide file tree
Showing 2 changed files with 321 additions and 27 deletions.
101 changes: 74 additions & 27 deletions pycbc/events/cuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,55 @@ def convert_inputstr(inputstr, choices):
try:
cut_value = float(cut_value_str)
except ValueError as value_e:
logging.warning("Error: Cut value must be convertible into a float, "
"got '%s', a %s.", cut_value, type(cut_value))
logging.warning("ERROR: Cut value must be convertible into a float, "
"got '%s'.", cut_value_str)
raise value_e

return {cut_param: (ineq_functions[cut_limit],
float(cut_value_str))}
return {(cut_param, ineq_functions[cut_limit]): cut_value}


def check_update_cuts(cut_dict, new_cut):
"""
Update a cuts dictionary, but check whether the cut exists already,
warn and only apply the strictest cuts
Parameters
----------
cut_dict: dictionary
Dictionary containing the cuts to be checked, will be updated
new_cut: single-entry dictionary
dictionary to define the new cut which is being considered to add
"""
new_cut_key = list(new_cut.keys())[0]
if new_cut_key in cut_dict:
# The cut has already been called
logging.warning("WARNING: Cut parameter %s and function %s have "
"already been used. Utilising the strictest cut.",
new_cut_key[0], new_cut_key[1].__name__)
# Extract the function and work out which is strictest
cut_function = new_cut_key[1]
value_new = list(new_cut.values())[0]
value_old = cut_dict[new_cut_key]
if cut_function(value_new, value_old):
# The new threshold would survive the cut of the
# old threshold, therefore the new threshold is stricter
# - update it
logging.warning("WARNING: New threshold of %.3f is "
"stricter than old threshold %.3f, "
"using cut at %.3f.",
value_new, value_old, value_new)
cut_dict.update(new_cut)
else:
# New cut would not make a difference, ignore it
logging.warning("WARNING: New threshold of %.3f is less "
"strict than old threshold %.3f, using "
"cut at %.3f.",
value_new, value_old, value_old)
else:
# This is a new cut - add it
cut_dict.update(new_cut)


def ingest_cuts_option_group(args):
Expand All @@ -133,13 +176,14 @@ def ingest_cuts_option_group(args):
# Handle trigger cuts
trigger_cut_dict = {}
for inputstr in trigger_cut_strs:
trigger_cut_dict.update(convert_inputstr(inputstr,
trigger_param_choices))
new_trigger_cut = convert_inputstr(inputstr, trigger_param_choices)
check_update_cuts(trigger_cut_dict, new_trigger_cut)

# Handle template cuts
template_cut_dict = {}
for inputstr in template_cut_strs:
template_cut_dict.update(convert_inputstr(inputstr,
template_param_choices))
new_template_cut = convert_inputstr(inputstr, template_param_choices)
check_update_cuts(template_cut_dict, new_template_cut)

return trigger_cut_dict, template_cut_dict

Expand All @@ -151,14 +195,14 @@ def apply_trigger_cuts(triggers, trigger_cut_dict):
Parameters
----------
triggers: ReadByTemplate object
triggers: ReadByTemplate object or dictionary
The triggers in this particular template. This
must have the correct datasets required to calculate
the values we cut on.
trigger_cut_dict: dictionary
Dictionary with parameters as keys, and tuples of
(cut_function, cut_threshold) as values
Dictionary with tuples of (parameter, cut_function)
as keys, cut_thresholds as values
made using ingest_cuts_option_group function
Returns
Expand All @@ -170,9 +214,9 @@ def apply_trigger_cuts(triggers, trigger_cut_dict):
idx_out = np.arange(len(triggers['snr']))

# Loop through the different cuts, and apply them
for parameter, cut_function_thresh in trigger_cut_dict.items():
for parameter_cut_function, cut_thresh in trigger_cut_dict.items():
# The function and threshold are stored as a tuple so unpack it
cut_function, cut_thresh = cut_function_thresh
parameter, cut_function = parameter_cut_function

# What kind of parameter is it?
if parameter.endswith('_chisq'):
Expand All @@ -182,8 +226,10 @@ def apply_trigger_cuts(triggers, trigger_cut_dict):
value = get_chisq_from_file_choice(triggers, chisq_choice)
# Apply any previous cuts to the value for comparison
value = value[idx_out]
elif parameter in triggers.file[triggers.ifo]:
# parameter can be read direct from the trigger file
elif (parameter in triggers
or (hasattr(triggers, "file")
and parameter in triggers.file[triggers.ifo])):
# parameter can be read direct from the trigger dictionary / file
value = triggers[parameter]
# Apply any previous cuts to the value for comparison
value = value[idx_out]
Expand All @@ -203,7 +249,7 @@ def apply_trigger_cuts(triggers, trigger_cut_dict):
return idx_out


def apply_template_fit_cut(statistic, ifos, parameter, cut_function_thresh,
def apply_template_fit_cut(statistic, ifos, parameter_cut_function, cut_thresh,
template_ids):
"""
Apply cuts to template fit parameters, these have a few more checks
Expand All @@ -221,11 +267,12 @@ def apply_template_fit_cut(statistic, ifos, parameter, cut_function_thresh,
List of IFOS used in this findtrigs instance.
Templates must pass cuts in all IFOs.
parameter: string
Which parameter is being used for the cut?
parameter_cut_function: thresh
First entry: Which parameter is being used for the cut?
Second entry: Cut function
cut_function_thresh: tuple
tuple of the cut function and cut threshold
cut_thresh: float or int
Cut threshold to the parameter according to the cut function
template_ids: numpy array
Array of template_ids which have passed previous cuts
Expand All @@ -236,7 +283,7 @@ def apply_template_fit_cut(statistic, ifos, parameter, cut_function_thresh,
tids_out: numpy array
Array of template_ids which have passed this cut
"""
cut_function, cut_thresh = cut_function_thresh
parameter, cut_function = parameter_cut_function
statistic_classname = statistic.__class__.__name__

# We can only apply template fit cuts if template fits have been done
Expand Down Expand Up @@ -279,8 +326,8 @@ def apply_template_cuts(bank, template_cut_dict, template_ids=None,
Must contain the usual template bank datasets
template_cut_dict: dictionary
Dictionary with parameters as keys, and tuples of
(cut_function, cut_threshold) as values
Dictionary with tuples of (parameter, cut_function)
as keys, cut_thresholds as values
made using ingest_cuts_option_group function
Optional Parameters
Expand Down Expand Up @@ -321,9 +368,9 @@ def apply_template_cuts(bank, template_cut_dict, template_ids=None,
return tids_out

# Loop through the different cuts, and apply them
for parameter, cut_function_thresh in template_cut_dict.items():
for parameter_cut_function, cut_thresh in template_cut_dict.items():
# The function and threshold are stored as a tuple so unpack it
cut_function, cut_thresh = cut_function_thresh
parameter, cut_function = parameter_cut_function

if parameter in bank_conv.conversion_options:
# Calculate the parameter values using the bank property helper
Expand All @@ -334,8 +381,8 @@ def apply_template_cuts(bank, template_cut_dict, template_ids=None,
if statistic and ifos:
tids_out = apply_template_fit_cut(statistic,
ifos,
parameter,
cut_function_thresh,
parameter_cut_function,
cut_thresh,
tids_out)
else:
raise ValueError("Cut parameter " + parameter + " not recognised."
Expand Down
Loading

0 comments on commit 61de4f2

Please sign in to comment.