diff --git a/tardis/visualization/plot_util.py b/tardis/visualization/plot_util.py index 4880ae94c71..1c15fd93cf6 100644 --- a/tardis/visualization/plot_util.py +++ b/tardis/visualization/plot_util.py @@ -163,77 +163,65 @@ def parse_species_list_util(species_list): Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) + Returns + ------- + dict + A dictionary containing: + - full_species_list: List of expanded species (e.g. Si I - V -> [Si I, Si II, ...]). + - species_mapped: Mapping of species ids to species names. + - keep_colour: List of atomic numbers to group elements with consistent colors. """ - if species_list is not None: - # check if there are any digits in the species list. If there are, then exit. - # species_list should only contain species in the Roman numeral - # format, e.g. Si II, and each ion must contain a space - if any(char.isdigit() for char in " ".join(species_list)) is True: - raise ValueError( - "All species must be in Roman numeral form, e.g. Si II" + if species_list is None: + return { + "species_mapped": None, + "keep_colour": None, + "species_list": None, + } + + + if any(char.isdigit() for char in " ".join(species_list)): + raise ValueError("All species must be in Roman numeral form, e.g., Si II") + + full_species_list = [] + species_mapped = {} + keep_colour = [] + + for species in species_list: + if "-" in species: + element = species.split(" ")[0] + first_ion_numeral = roman_to_int(species.split(" ")[-1].split("-")[0]) + second_ion_numeral = roman_to_int(species.split(" ")[-1].split("-")[-1]) + + for ion_number in range(first_ion_numeral, second_ion_numeral + 1): + full_species_list.append(f"{element} {int_to_roman(ion_number)}") + else: + full_species_list.append(species) + + requested_species_ids = [] + + for species in full_species_list: + if " " in species: + species_id = ( + species_string_to_tuple(species)[0] * 100 + + species_string_to_tuple(species)[1] ) + requested_species_ids.append([species_id]) + species_mapped[species_id] = [species_id] else: - full_species_list = [] - species_mapped = {} - for species in species_list: - # check if a hyphen is present. If it is, then it indicates a - # range of ions. Add each ion in that range to the list as a new entry - if "-" in species: - # split the string on spaces. First thing in the list is then the element - element = species.split(" ")[0] - # Next thing is the ion range - # convert the requested ions into numerals - first_ion_numeral = roman_to_int( - species.split(" ")[-1].split("-")[0] - ) - second_ion_numeral = roman_to_int( - species.split(" ")[-1].split("-")[-1] - ) - # add each ion between the two requested into the species list - for ion_number in np.arange( - first_ion_numeral, second_ion_numeral + 1 - ): - full_species_list.append( - f"{element} {int_to_roman(ion_number)}" - ) - else: - # Otherwise it's either an element or ion so just add to the list - full_species_list.append(species) - - # full_species_list is now a list containing each individual species requested - # e.g. it parses species_list = [Si I - V] into species_list = [Si I, Si II, Si III, Si IV, Si V] - - requested_species_ids = [] - keep_colour = [] - - # go through each of the requested species. Check whether it is - # an element or ion (ions have spaces). If it is an element, - # add all possible ions to the ions list. Otherwise just add - # the requested ion - for species in full_species_list: - if " " in species: - species_id = ( - species_string_to_tuple(species)[0] * 100 - + species_string_to_tuple(species)[1] - ) - requested_species_ids.append([species_id]) - species_mapped[species_id] = [species_id] - else: - atomic_number = element_symbol2atomic_number(species) - species_ids = [ - atomic_number * 100 + ion_number - for ion_number in np.arange(atomic_number) - ] - requested_species_ids.append(species_ids) - species_mapped[atomic_number * 100] = species_ids - # add the atomic number to a list so you know that this element should - # have all species in the same colour, i.e. it was requested like - # species_list = [Si] - keep_colour.append(atomic_number) - requested_species_ids = [ - species_id - for temp_list in requested_species_ids - for species_id in temp_list + atomic_number = element_symbol2atomic_number(species) + species_ids = [ + atomic_number * 100 + ion_number for ion_number in range(atomic_number) ] + requested_species_ids.append(species_ids) + species_mapped[atomic_number * 100] = species_ids + keep_colour.append(atomic_number) + + requested_species_ids = [ + species_id for temp_list in requested_species_ids for species_id in temp_list + ] - return requested_species_ids, species_mapped, keep_colour, full_species_list + return { + "species_mapped": species_mapped, + "keep_colour": keep_colour, + "species_list": requested_species_ids, + } diff --git a/tardis/visualization/tools/liv_plot.py b/tardis/visualization/tools/liv_plot.py index 73815a493d4..616a53281d9 100644 --- a/tardis/visualization/tools/liv_plot.py +++ b/tardis/visualization/tools/liv_plot.py @@ -88,12 +88,13 @@ def _parse_species_list(self, species_list, packets_mode, nelements=None): If species list contains invalid entries. """ - ( - self._species_list, - self._species_mapped, - self._keep_colour, - self._full_species_list, - ) = pu.parse_species_list_util(species_list) + parsed_species_data = pu.parse_species_list_util(species_list) + if parsed_species_data is None: + self._species_list = None + else: + self._species_mapped = parsed_species_data["species_mapped"] + self._keep_colour = parsed_species_data["keep_colour"] + self._species_list = parsed_species_data["species_list"] if nelements: interaction_counts = ( diff --git a/tardis/visualization/tools/sdec_plot.py b/tardis/visualization/tools/sdec_plot.py index 17ec7e348a7..bfbef62d4ef 100644 --- a/tardis/visualization/tools/sdec_plot.py +++ b/tardis/visualization/tools/sdec_plot.py @@ -92,12 +92,13 @@ def _parse_species_list(self, species_list): (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) """ - ( - self._species_list, - self._species_mapped, - self._keep_colour, - self._full_species_list, - ) = pu.parse_species_list_util(species_list) + parsed_species_data = pu.parse_species_list_util(species_list) + if parsed_species_data is None: + self._species_list = None + else: + self._species_mapped = parsed_species_data["species_mapped"] + self._keep_colour = parsed_species_data["keep_colour"] + self._species_list = parsed_species_data["species_list"] def _calculate_plotting_data( self, packets_mode, packet_wvl_range, distance, nelements