"""
Spectral element DEComposition (SDEC) Plot for TARDIS simulation models.
This plot is a spectral diagnostics plot similar to those originally
proposed by M. Kromer (see, for example, Kromer et al. 2013, figure 4).
"""
import logging
import astropy.units as u
import matplotlib.cm as cm
import matplotlib.colors as clr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from astropy.modeling.models import BlackBody
from tardis.util.base import (
atomic_number2element_symbol,
int_to_roman,
)
from tardis.visualization import plot_util as pu
logger = logging.getLogger(__name__)
[docs]
class SDECPlotter:
"""
Plotting interface for Spectral element DEComposition (SDEC) Plot.
It performs necessary calculations to generate SDEC Plot for a simulation
model, and allows to plot it in matplotlib and plotly.
"""
def __init__(self):
"""
Initialize the SDECPlotter with required data of simulation model.
"""
self.packet_data = {
"real": {"packets_df": None, "packets_df_line_interaction": None},
"virtual": {
"packets_df": None,
"packets_df_line_interaction": None,
},
}
self.spectrum = {"virtual": None, "real": None}
self.t_inner = None
self.r_inner = None
self.time_of_simulation = None
self._default_scatter_kwargs = {
"mode": "none",
"hovertemplate": "(%{x:.2f}, %{y:.3g})",
}
self._predefined_traces = {
"emission": {
"noint": {"name": "No interaction", "fillcolor": "#4C4C4C"},
"escatter": {
"name": "Electron Scatter Only",
"fillcolor": "#8F8F8F",
"hoverlabel": {"namelength": -1},
},
"other": {"name": "Other elements", "fillcolor": "#C2C2C2"},
},
"absorption": {
"other": {"name": "Other elements", "fillcolor": "#C2C2C2"},
},
}
[docs]
@classmethod
def from_simulation(cls, sim):
"""
Create an instance of SDECPlotter from a TARDIS simulation object.
Parameters
----------
sim : tardis.simulation.Simulation
TARDIS Simulation object produced by running a simulation
Returns
-------
SDECPlotter
"""
plotter = cls()
plotter.t_inner = sim.simulation_state.packet_source.temperature
plotter.r_inner = sim.simulation_state.geometry.r_inner_active
plotter.time_of_simulation = (
sim.transport.transport_state.packet_collection.time_of_simulation
* u.s
)
modes = ["real"]
if sim.transport.transport_state.virt_logging:
modes.append("virtual")
for mode in modes:
plotter.spectrum[mode] = pu.get_spectrum_data(mode, sim)
plotter.packet_data[mode] = pu.extract_and_process_packet_data(
sim, mode
)
return plotter
[docs]
@classmethod
def from_hdf(cls, hdf_fpath):
"""
Create an instance of SDECPlotter from a simulation HDF file.
Parameters
----------
hdf_fpath : str
Valid path to the HDF file where simulation is saved
packets_mode : {'virtual', 'real'}, optional
Mode of packets to be considered (default: 'virtual')
Returns
-------
SDECPlotter
"""
plotter = cls()
with pd.HDFStore(hdf_fpath, "r") as hdf:
plotter.r_inner = u.Quantity(
hdf["/simulation/simulation_state/r_inner"].to_numpy(), "cm"
)
plotter.t_inner = u.Quantity(
hdf["/simulation/simulation_state/scalars"].t_inner, "K"
)
transport_state_scalars = hdf["/simulation/transport/transport_state/scalars"]
plotter.time_of_simulation = u.Quantity(
transport_state_scalars.time_of_simulation,
"s",
)
has_virtual = bool(getattr(transport_state_scalars, "virt_logging", False))
modes = ["real"] + (["virtual"] if has_virtual else [])
for mode in modes:
plotter.spectrum[mode] = pu.extract_spectrum_data_hdf(hdf, mode)
plotter.packet_data[mode] = pu.extract_and_process_packet_data_hdf(hdf, mode)
return plotter
def _parse_species_list(self, species_list):
"""
Parse user requested species list and create list of species ids to be used.
Parameters
----------
species_list : list of species to plot
List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours.
Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions
(e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca])
"""
if species_list is not None:
(
species_mapped_tuples,
requested_species_ids_tuples,
keep_colour,
full_species_list,
) = pu.parse_species_list_util(species_list)
self._full_species_list = full_species_list
self._species_list = [
atomic_num * 100 + ion_num
for atomic_num, ion_num in requested_species_ids_tuples
]
self._species_mapped = {
(k[0] * 100 + k[1]): [v[0] * 100 + v[1] for v in values]
for k, values in species_mapped_tuples.items()
}
self._keep_colour = keep_colour
else:
self._full_species_list = None
self._species_list = None
self._species_mapped = None
self._keep_colour = None
def _calculate_plotting_data(
self, packets_mode, packet_wvl_range, distance, nelements
):
"""
Calculate data to be used in plotting based on parameters passed.
Parameters
----------
packets_mode : {'virtual', 'real'}
Mode of packets to be considered, either real or virtual
packet_wvl_range : astropy.Quantity
Wavelength range to restrict the analysis of escaped packets. It
should be a quantity having units of Angstrom, containing two
values - lower lambda and upper lambda i.e.
[lower_lambda, upper_lambda] * u.AA
distance : astropy.Quantity
Distance used to calculate flux instead of luminosity in the plot.
It should have a length unit like m, Mpc, etc.
nelements: int
Number of elements to include in plot. Determined by the
largest contribution to the total luminosity absorbed and emitted.
Other elements are shown in silver. Default value is
None, which displays all elements
Notes
-----
It doesn't return the calculated properties but save them in instance
itself. So it should be always called before starting plotting to
update the plotting data based on parameters passed.
"""
if packets_mode not in ["virtual", "real"]:
raise ValueError(
"Invalid value passed to packets_mode. Only "
"allowed values are 'virtual' or 'real'"
)
if (
packets_mode == "virtual"
and self.packet_data[packets_mode]["packets_df"] is None
):
raise ValueError(
"SDECPlotter doesn't have any data for virtual packets population and SDEC "
"plot for the same was requested. Either set virtual_packet_logging: True "
"in your configuration file to generate SDEC plot with virtual packets, or "
"pass packets_mode='real' in your function call to generate SDEC plot with "
"real packets."
)
# Store the plottable range of each spectrum property which is
# same as entire range, initially
self.plot_frequency_bins = self.spectrum[packets_mode][
"spectrum_frequency_bins"
]
self.plot_wavelength = self.spectrum[packets_mode][
"spectrum_wavelength"
]
self.plot_frequency = self.spectrum[packets_mode][
"spectrum_frequency_bins"
][:-1]
self.packet_wvl_range_mask = np.ones(
self.plot_wavelength.size, dtype=bool
) # default value
# Filter their plottable range based on packet_wvl_range specified
if packet_wvl_range is not None:
packet_nu_range = packet_wvl_range.to("Hz", u.spectral())
# Index of value just before the 1st value that is > packet_nu_range[1]
start_idx = (
np.argmax(self.plot_frequency_bins > packet_nu_range[1]) - 1
)
# Index of value just after the last value that is < packet_nu_range[0]
end_idx = np.argmin(self.plot_frequency_bins < packet_nu_range[0])
self.plot_frequency_bins = self.plot_frequency_bins[
start_idx : end_idx + 1
]
# Since spectrum frequency (& hence wavelength) were created from
# frequency_bins[:-1], so we exclude end_idx when creating the mask
self.packet_wvl_range_mask = np.zeros(
self.plot_wavelength.size, dtype=bool
)
self.packet_wvl_range_mask[start_idx:end_idx] = True
self.plot_wavelength = self.plot_wavelength[
self.packet_wvl_range_mask
]
self.plot_frequency = self.plot_frequency[
self.packet_wvl_range_mask
]
# Make sure number of bin edges are always one more than wavelengths
assert self.plot_frequency_bins.size == self.plot_wavelength.size + 1
# Calculate the area term to convert luminosity to flux
self.lum_to_flux = 1 # default to 1 if distance is none so that this term will have no effect
if distance is not None:
if distance <= 0:
raise ValueError(
"distance passed must be greater than 0. If you intended "
"to plot luminosities instead of flux, set distance=None "
"or don't specify distance parameter in the function call."
)
self.lum_to_flux = 4.0 * np.pi * (distance.to("cm")) ** 2
# Calculate luminosities to be shown in plot
(
self.emission_luminosities_df,
self.emission_species,
) = self._calculate_emission_luminosities(
packets_mode=packets_mode, packet_wvl_range=packet_wvl_range
)
(
self.absorption_luminosities_df,
self.absorption_species,
) = self._calculate_absorption_luminosities(
packets_mode=packets_mode, packet_wvl_range=packet_wvl_range
)
# Calculate the total contribution of elements
# by summing absorption and emission
# Only care about elements, so drop no interaction and electron scattering
# contributions from the emitted luminosities
self.total_luminosities_df = (
self.absorption_luminosities_df
+ self.emission_luminosities_df.drop(["noint", "escatter"], axis=1)
)
# Sort the element list based on the total contribution
sorted_list = self.total_luminosities_df.sum().sort_values(
ascending=False
)
if nelements is None and self._species_list is None:
self.species = np.array(list(self.total_luminosities_df.columns))
elif self._species_list is not None:
sorted_keys = list(sorted_list.keys())
keys_to_keep = [
key for key in sorted_keys if key in self._species_list
]
df_map = {
"total_luminosities_df": keys_to_keep,
"emission_luminosities_df": keys_to_keep
+ ["noint", "escatter"],
"absorption_luminosities_df": keys_to_keep,
}
for df_name, keys in df_map.items():
current_df = getattr(self, df_name)
columns_to_exclude = [
col for col in current_df.columns if col not in keys
]
other_pos = 2 if df_name != "total_luminosities_df" else 0
processed_df = self.process_luminosity_dataframe(
current_df,
columns_to_exclude,
other_column_position=other_pos,
)
setattr(self, df_name, processed_df)
self.species = np.sort(self.total_luminosities_df.columns[1:])
else: # nelements is not None
top_n_keys = sorted_list.keys()[:nelements]
always_keep = ["noint", "escatter"]
df_map = {
"total_luminosities_df": list(top_n_keys),
"emission_luminosities_df": list(top_n_keys) + always_keep,
"absorption_luminosities_df": list(top_n_keys),
}
for df_name, keys_to_keep in df_map.items():
current_df = getattr(self, df_name)
columns_to_exclude = [
col for col in current_df.columns if col not in keys_to_keep
]
other_pos = 2 if df_name != "total_luminosities_df" else 0
processed_df = self.process_luminosity_dataframe(
current_df,
columns_to_exclude,
other_column_position=other_pos,
)
setattr(self, df_name, processed_df)
self.species = np.sort(self.total_luminosities_df.columns[1:])
# Final calculations
self.photosphere_luminosity = self._calculate_photosphere_luminosity()
self.modeled_spectrum_luminosity = (
self.spectrum[packets_mode]["spectrum_luminosity_density_lambda"][
self.packet_wvl_range_mask
]
/ self.lum_to_flux
)
[docs]
def process_luminosity_dataframe(
self, df, keys_to_exclude, other_column_position=0
):
"""
Process a luminosity DataFrame by aggregating specified columns into an 'other' column.
Parameters
----------
df : pandas.DataFrame
The DataFrame containing luminosity data to be processed.
keys_to_exclude : list of str
Column names in `df` whose data should be summed into the 'other' column and removed.
other_column_position : int, optional
The integer location (0-indexed) at which to insert the new 'other' column. Defaults to 0.
Returns
-------
pandas.DataFrame
A new DataFrame with the excluded columns summed into 'other' and removed from the original.
"""
mask = np.isin(df.columns, keys_to_exclude)
excluded_keys = df.columns[mask]
if len(excluded_keys) > 0:
df.insert(
loc=other_column_position,
column="other",
value=df[excluded_keys].sum(axis=1),
)
df = df.drop(columns=excluded_keys)
return df
def _calculate_grouped_luminosities(
self, packets_mode, mask, nu_column, luminosities_df
):
"""
Calculate luminosities for element interactions and populate DataFrame.
Parameters
----------
packets_mode : str
'virtual' or 'real' packets mode
mask : np.array
Boolean mask for wavelength filtering
nu_column : str
Column name containing frequency values
luminosities_df : pd.DataFrame
DataFrame to store results
Returns
-------
tuple
(updated luminosities_df, array of species identifiers)
"""
# Group packets_df by atomic number of elements with which packets
# had their last emission (interaction out)
# or if species_list is requested then group by species id
groupby_column = (
"last_line_interaction_atom"
if self._species_list is None
else "last_line_interaction_species"
)
grouped = (
self.packet_data[packets_mode]["packets_df_line_interaction"]
.loc[mask]
.groupby(by=groupby_column)
)
for identifier, group in grouped:
weights = (
group["energies"] / self.lum_to_flux / self.time_of_simulation
)
hist = np.histogram(
group[nu_column],
bins=self.plot_frequency_bins.value,
weights=weights,
density=False,
)
L_nu = (
hist[0]
* u.erg
/ u.s
/ self.spectrum[packets_mode]["spectrum_delta_frequency"]
)
luminosities_df[identifier] = (
L_nu * self.plot_frequency / self.plot_wavelength
).value
return luminosities_df, np.array(list(grouped.groups.keys()))
def _calculate_luminosity_contribution(
self, packets_mode, mask, contribution_name, luminosities_df
):
"""Calculate luminosity contribution for packets matching the specified mask."""
# Histogram weights are packet luminosities or flux
weights = (
self.packet_data[packets_mode]["packets_df"]["energies"][
self.packet_nu_range_mask
]
/ self.lum_to_flux
) / self.time_of_simulation
hist = np.histogram(
self.packet_data[packets_mode]["packets_df"]["nus"][
self.packet_nu_range_mask
][mask],
bins=self.plot_frequency_bins.value,
weights=weights[mask],
density=False,
)
L_nu = (
hist[0]
* u.erg
/ u.s
/ self.spectrum[packets_mode]["spectrum_delta_frequency"]
)
L_lambda = L_nu * self.plot_frequency / self.plot_wavelength
luminosities_df[contribution_name] = L_lambda.value
def _plot_traces(self, df, group_name, predefined_traces, invert_y=False):
"""Generic helper to plot traces."""
# By specifying a common stackgroup, plotly will itself add up
# luminosities, in order, to created stacked area chart
base_kwargs = {
"x": df.index,
"stackgroup": group_name,
**self._default_scatter_kwargs,
}
# Plot predefined traces first
for colname, trace_info in predefined_traces.items():
if colname in df.columns:
y_data = df[colname] * (-1 if invert_y else 1)
self.fig.add_trace(
go.Scatter(
y=y_data,
**trace_info,
**base_kwargs,
)
)
# Plot species-specific traces
for (species_counter, identifier), species_name in zip(
enumerate(self.species), self._species_name
):
try:
y_data = df[identifier] * (-1 if invert_y else 1)
self.fig.add_trace(
go.Scatter(
y=y_data,
name=f"{species_name} {'Absorption' if invert_y else 'Emission'}",
fillcolor=pu.to_rgb255_string(
self._color_list[species_counter]
),
hoverlabel={"namelength": -1},
**base_kwargs,
)
)
except KeyError:
self._log_missing_species(
identifier, "absorbed" if invert_y else "emitted"
)
def _calculate_emission_luminosities(self, packets_mode, packet_wvl_range):
"""
Calculate luminosities for the emission part of SDEC plot.
Parameters
----------
packets_mode : {'virtual', 'real'}
Mode of packets to be considered, either real or virtual
packet_wvl_range : astropy.Quantity
Wavelength range to restrict the analysis of escaped packets. It
should be a quantity having units of Angstrom, containing two
values - lower lambda and upper lambda i.e.
[lower_lambda, upper_lambda] * u.AA
Returns
-------
luminosities_df : pd.DataFrame
Dataframe containing luminosities contributed by no interaction,
only e-scattering and emission with each element present
elements_present: np.array
Atomic numbers of the elements with which packets of specified
wavelength range interacted
"""
self.packet_nu_range_mask = pu.create_wavelength_mask(
self.packet_data,
packets_mode,
packet_wvl_range,
df_key="packets_df",
column_name="nus",
)
self.packet_nu_line_range_mask = pu.create_wavelength_mask(
self.packet_data,
packets_mode,
packet_wvl_range,
df_key="packets_df_line_interaction",
column_name="nus",
)
luminosities_df = pd.DataFrame(index=self.plot_wavelength)
# Contribution of packets which experienced no interaction
# Mask to select packets with no interaction
mask_noint = (
self.packet_data[packets_mode]["packets_df"][
"last_interaction_type"
][self.packet_nu_range_mask]
== -1
)
self._calculate_luminosity_contribution(
packets_mode, mask_noint, "noint", luminosities_df
)
# Contribution of packets which only experienced electron scattering ---
mask_escatter = (
self.packet_data[packets_mode]["packets_df"][
"last_interaction_type"
][self.packet_nu_range_mask]
== 1
) & (
self.packet_data[packets_mode]["packets_df"][
"last_line_interaction_in_id"
][self.packet_nu_range_mask]
== -1
)
self._calculate_luminosity_contribution(
packets_mode, mask_escatter, "escatter", luminosities_df
)
return self._calculate_grouped_luminosities(
packets_mode=packets_mode,
mask=self.packet_nu_line_range_mask,
nu_column="nus",
luminosities_df=luminosities_df,
)
def _calculate_absorption_luminosities(
self, packets_mode, packet_wvl_range
):
"""
Calculate luminosities for the absorption part of SDEC plot.
Parameters
----------
packets_mode : {'virtual', 'real'}
Mode of packets to be considered, either real or virtual
packet_wvl_range : astropy.Quantity
Wavelength range to restrict the analysis of escaped packets. It
should be a quantity having units of Angstrom, containing two
values - lower lambda and upper lambda i.e.
[lower_lambda, upper_lambda] * u.AA
Returns
-------
pd.DataFrame
Dataframe containing luminosities contributed by absorption with
each element present
"""
self.packet_nu_line_range_mask = pu.create_wavelength_mask(
self.packet_data,
packets_mode,
packet_wvl_range,
df_key="packets_df_line_interaction",
column_name="last_line_interaction_in_nu",
)
luminosities_df = pd.DataFrame(index=self.plot_wavelength)
return self._calculate_grouped_luminosities(
packets_mode=packets_mode,
mask=self.packet_nu_line_range_mask,
nu_column="last_line_interaction_in_nu",
luminosities_df=luminosities_df,
)
def _calculate_photosphere_luminosity(self):
"""
Calculate blackbody luminosity of the photosphere.
Returns
-------
astropy.Quantity
Luminosity density lambda (or Flux) of photosphere (inner boundary
of TARDIS simulation)
"""
bb_lam = BlackBody(
self.t_inner,
scale=1.0 * u.erg / (u.cm**2 * u.AA * u.s * u.sr),
)
L_lambda_ph = (
bb_lam(self.plot_wavelength)
* 4
* np.pi**2
* self.r_inner[0] ** 2
* u.sr
).to("erg / (AA s)")
return L_lambda_ph / self.lum_to_flux
[docs]
def generate_plot_mpl(
self,
packets_mode="virtual",
packet_wvl_range=None,
distance=None,
observed_spectrum=None,
show_modeled_spectrum=True,
ax=None,
figsize=(12, 7),
cmapname="jet",
nelements=None,
species_list=None,
blackbody_photosphere=True,
):
"""
Generate Spectral element DEComposition (SDEC) Plot using matplotlib.
Parameters
----------
packets_mode : {'virtual', 'real'}, optional
Mode of packets to be considered, either real or virtual. Default
value is 'virtual'
packet_wvl_range : astropy.Quantity or None, optional
Wavelength range to restrict the analysis of escaped packets. It
should be a quantity having units of Angstrom, containing two
values - lower lambda and upper lambda i.e.
[lower_lambda, upper_lambda] * u.AA. Default value is None
distance : astropy.Quantity or None, optional
Distance used to calculate flux instead of luminosity in the plot.
It should have a length unit like m, Mpc, etc. Default value is None
observed_spectrum : tuple or list of astropy.Quantity, optional
Option to plot an observed spectrum in the SDEC plot. If given, the first element
should be the wavelength and the second element should be flux,
i.e. (wavelength, flux). The assumed units for wavelength and flux are
angstroms and erg/(angstroms * s * cm^2), respectively. Default value is None.
show_modeled_spectrum : bool, optional
Whether to show modeled spectrum in SDEC Plot. Default value is
True
ax : matplotlib.axes._subplots.AxesSubplot or None, optional
Axis on which to create plot. Default value is None which will
create plot on a new figure's axis.
figsize : tuple, optional
Size of the matplotlib figure to display. Default value is (12, 7)
cmapname : str, optional
Name of matplotlib colormap to be used for showing elements.
Default value is "jet"
nelements: int
Number of elements to include in plot. Determined by the
largest contribution to total luminosity absorbed and emitted.
Other elements are shown in silver. Default value is
None, which displays all elements
species_list: list of strings or None
list of strings containing the names of species that should be included in the SDEC plots.
Must be given in Roman numeral format. Can include specific ions, a range of ions,
individual elements, or any combination of these:
e.g. ['Si II', 'Ca II', 'C', 'Fe I-V']
blackbody_photosphere: bool
Whether to include the blackbody photosphere in the plot. Default value is True
Returns
-------
matplotlib.axes._subplots.AxesSubplot
Axis on which SDEC Plot is created
"""
# If species_list and nelements requested, tell user that nelements is ignored
if species_list is not None and nelements is not None:
logger.info(
"Both nelements and species_list were requested. Species_list takes priority; nelements is ignored"
)
# Parse the requested species list
self._parse_species_list(species_list=species_list)
# Calculate data attributes required for plotting
# and save them in instance itself
self._calculate_plotting_data(
packets_mode=packets_mode,
packet_wvl_range=packet_wvl_range,
distance=distance,
nelements=nelements,
)
if ax is None:
self.ax = plt.figure(figsize=figsize).add_subplot(111)
else:
self.ax = ax
# Get the labels in the color bar. This determines the number of unique colors
self._make_colorbar_labels()
# Set colormap to be used in elements of emission and absorption plots
self.cmap = plt.get_cmap(cmapname, len(self._species_name))
# Get the number of unqie colors
self._make_colorbar_colors()
self._show_colorbar_mpl()
# Plot emission and absorption components
self._plot_emission_mpl()
self._plot_absorption_mpl()
# Plot modeled spectrum
if show_modeled_spectrum:
self.ax.plot(
self.plot_wavelength.value,
self.modeled_spectrum_luminosity.value,
"--b",
label=f"{packets_mode.capitalize()} Spectrum",
linewidth=1,
)
# Plot observed spectrum
if observed_spectrum:
if distance is None:
raise ValueError(
"""
Distance must be specified if an observed_spectrum is given
so that the model spectrum can be converted into flux space correctly.
"""
)
# Convert to wavelength and luminosity units
observed_spectrum_wavelength = observed_spectrum[0].to(u.AA)
observed_spectrum_flux = observed_spectrum[1].to("erg/(s cm**2 AA)")
self.ax.plot(
observed_spectrum_wavelength.value,
observed_spectrum_flux.value,
"-k",
label="Observed Spectrum",
linewidth=1,
)
# Plot photosphere
if blackbody_photosphere:
self.ax.plot(
self.plot_wavelength.value,
self.photosphere_luminosity.value,
"--r",
label="Blackbody Photosphere",
)
# Set legends and labels
xlabel = pu.axis_label_in_latex("Wavelength", u.AA)
if distance is not None: # Set y-axis label for flux
ylabel = pu.axis_label_in_latex(
"F_{\\lambda}", u.Unit("erg/(s cm**2 AA)"), only_text=False
)
else: # Set y-axis label for luminosity
ylabel = pu.axis_label_in_latex(
"L_{\\lambda}", u.Unit("erg/(s AA)"), only_text=False
)
self.ax.legend(fontsize=12)
self.ax.set_xlabel(xlabel, fontsize=12)
self.ax.set_ylabel(
ylabel,
fontsize=12,
)
return plt.gca()
def _plot_emission_mpl(self):
"""Plot emission part of the SDEC Plot using matplotlib."""
# To create stacked area chart in matplotlib, we will start with zero
# lower level and will keep adding luminosities to it (upper level)
lower_level = np.zeros(self.emission_luminosities_df.shape[0])
upper_level = (
lower_level + self.emission_luminosities_df.noint.to_numpy()
)
self.ax.fill_between(
self.plot_wavelength.value,
lower_level,
upper_level,
color="#4C4C4C",
label="No interaction",
)
lower_level = upper_level
upper_level = (
lower_level + self.emission_luminosities_df.escatter.to_numpy()
)
self.ax.fill_between(
self.plot_wavelength.value,
lower_level,
upper_level,
color="#8F8F8F",
label="Electron Scatter Only",
)
# If the 'other' column exists then plot it as silver
if "other" in self.emission_luminosities_df.keys():
lower_level = upper_level
upper_level = (
lower_level + self.emission_luminosities_df.other.to_numpy()
)
self.ax.fill_between(
self.plot_wavelength.value,
lower_level,
upper_level,
color="#C2C2C2",
label="Other elements",
)
# Contribution from each element
for species_counter, identifier in enumerate(self.species):
try:
lower_level = upper_level
upper_level = (
lower_level
+ self.emission_luminosities_df[identifier].to_numpy()
)
self.ax.fill_between(
self.plot_wavelength.value,
lower_level,
upper_level,
color=self._color_list[species_counter],
cmap=self.cmap,
linewidth=0,
)
except KeyError:
# Add notifications that this species was not in the emission df
self._log_missing_species(identifier, "emitted")
def _plot_absorption_mpl(self):
"""Plot absorption part of the SDEC Plot using matplotlib."""
lower_level = np.zeros(self.absorption_luminosities_df.shape[0])
# To plot absorption part along -ve X-axis, we will start with
# zero upper level and keep subtracting luminosities to it (lower
# level) - fill from upper to lower level
# If the 'other' column exists then plot it as silver
if "other" in self.absorption_luminosities_df.keys():
upper_level = lower_level
lower_level = (
upper_level - self.absorption_luminosities_df.other.to_numpy()
)
self.ax.fill_between(
self.plot_wavelength.value,
upper_level,
lower_level,
color="silver",
)
for species_counter, identifier in enumerate(self.species):
try:
upper_level = lower_level
lower_level = (
upper_level
- self.absorption_luminosities_df[identifier].to_numpy()
)
self.ax.fill_between(
self.plot_wavelength.value,
upper_level,
lower_level,
color=self._color_list[species_counter],
cmap=self.cmap,
linewidth=0,
)
except KeyError:
# Add notifications that this species was not in the emission df
self._log_missing_species(identifier, "absorbed")
def _show_colorbar_mpl(self):
"""Show matplotlib colorbar with labels of elements mapped to colors."""
color_values = [
self.cmap(species_counter / len(self._species_name))
for species_counter in range(len(self._species_name))
]
custcmap = clr.ListedColormap(color_values)
norm = clr.Normalize(vmin=0, vmax=len(self._species_name))
mappable = cm.ScalarMappable(norm=norm, cmap=custcmap)
mappable.set_array(np.linspace(1, len(self._species_name) + 1, 256))
cbar = plt.colorbar(mappable, ax=self.ax)
bounds = np.arange(len(self._species_name)) + 0.5
cbar.set_ticks(bounds)
cbar.set_ticklabels(self._species_name)
def _make_colorbar_labels(self):
"""Get the labels for the species in the colorbar."""
if self._species_list is None:
# If species_list is none then the labels are just elements
species_name = [
atomic_number2element_symbol(atomic_num)
for atomic_num in self.species
]
else:
species_name = []
for element in self.species:
# Go through each species requested
atomic_number, ion_number = divmod(element, 100)
ion_numeral = int_to_roman(ion_number + 1)
atomic_symbol = atomic_number2element_symbol(atomic_number)
# if the element was requested, and not a specific ion, then
# add the element symbol to the label list
if (atomic_number in self._keep_colour) and (
atomic_symbol not in species_name
):
# compiling the label, and adding it to the list
label = atomic_symbol
species_name.append(label)
elif atomic_number not in self._keep_colour:
# otherwise add the ion to the label list
label = f"{atomic_symbol} {ion_numeral}"
species_name.append(label)
self._species_name = species_name
def _make_colorbar_colors(self):
"""Get the colours for the species to be plotted."""
color_list = []
# - For elements in self._keep_colour, all ionization states share the same color
# (e.g., Si I, Si II, Si III all get the same color if Si's atomic number is in self._keep_colour)
# - For elements not in self._keep_colour, each ionization state gets a new color
for i, identifier in enumerate(self.species):
if self._species_list is not None:
color_counter = 0
atomic_number = identifier // 100
# For any element after the first one
if i > 0:
previous_atomic_number = self.species[i - 1] // 100
# Increment color when:
# 1. There is a new element, OR
# 2. The previous element isn't in the keep_colour list
if (
previous_atomic_number != atomic_number
or previous_atomic_number not in self._keep_colour
):
color_counter += 1
color = self.cmap(color_counter / len(self._species_name))
else:
color = self.cmap(i / len(self.species))
color_list.append(color)
self._color_list = color_list
[docs]
def generate_plot_ply(
self,
packets_mode="virtual",
packet_wvl_range=None,
distance=None,
observed_spectrum=None,
show_modeled_spectrum=True,
fig=None,
graph_height=600,
cmapname="jet",
nelements=None,
species_list=None,
blackbody_photosphere=True,
):
"""
Generate interactive Spectral element DEComposition (SDEC) Plot using plotly.
Parameters
----------
packets_mode : {'virtual', 'real'}, optional
Mode of packets to be considered, either real or virtual. Default
value is 'virtual'
packet_wvl_range : astropy.Quantity or None, optional
Wavelength range to restrict the analysis of escaped packets. It
should be a quantity having units of Angstrom, containing two
values - lower lambda and upper lambda i.e.
[lower_lambda, upper_lambda] * u.AA. Default value is None
distance : astropy.Quantity or None, optional
Distance used to calculate flux instead of luminosity in the plot.
It should have a length unit like m, Mpc, etc. Default value is None
observed_spectrum : tuple or list of astropy.Quantity, optional
Option to plot an observed spectrum in the SDEC plot. If given, the first element
should be the wavelength and the second element should be flux,
i.e. (wavelength, flux). The assumed units for wavelength and flux are
angstroms and erg/(angstroms * s * cm^2), respectively. Default value is None.
show_modeled_spectrum : bool, optional
Whether to show modeled spectrum in SDEC Plot. Default value is
True
fig : plotly.graph_objs._figure.Figure or None, optional
Figure object on which to create plot. Default value is None which
will create plot on a new Figure object.
graph_height : int, optional
Height (in px) of the plotly graph to display. Default value is 600
cmapname : str, optional
Name of the colormap to be used for showing elements.
Default value is "jet"
nelements: int
Number of elements to include in plot. Determined by the
largest contribution to total luminosity absorbed and emitted.
Other elements are shown in silver. Default value is
None, which displays all elements
species_list: list of strings or None
list of strings containing the names of species that should be included in the SDEC plots.
Must be given in Roman numeral format. Can include specific ions, a range of ions,
individual elements, or any combination of these:
e.g. ['Si II', 'Ca II', 'C', 'Fe I-V']
blackbody_photosphere: bool
Whether to include the blackbody photosphere in the plot. Default value is True
Returns
-------
plotly.graph_objs._figure.Figure
Figure object on which SDEC Plot is created
"""
# If species_list and nelements requested, tell user that nelements is ignored
if species_list is not None and nelements is not None:
logger.info(
"Both nelements and species_list were requested. Species_list takes priority; nelements is ignored"
)
hover_props = {
"hoverlabel": {"namelength": -1},
"hovertemplate": "(%{x:.2f}, %{y:.3g})",
}
# Parse the requested species list
self._parse_species_list(species_list=species_list)
# Calculate data attributes required for plotting
# and save them in instance itself
self._calculate_plotting_data(
packets_mode=packets_mode,
packet_wvl_range=packet_wvl_range,
distance=distance,
nelements=nelements,
)
if fig is None:
self.fig = go.Figure()
else:
self.fig = fig
# Get the labels in the color bar. This determines the number of unique colors
self._make_colorbar_labels()
# Set colormap to be used in elements of emission and absorption plots
self.cmap = plt.get_cmap(cmapname, len(self._species_name))
# Get the number of unique colors
self._make_colorbar_colors()
# Plot absorption and emission components
self._plot_emission_ply()
self._plot_absorption_ply()
# Plot modeled spectrum
if show_modeled_spectrum:
self.fig.add_trace(
go.Scatter(
x=self.plot_wavelength.value,
y=self.modeled_spectrum_luminosity.value,
mode="lines",
line={
"color": "blue",
"width": 1,
},
name=f"{packets_mode.capitalize()} Spectrum",
**hover_props,
)
)
# Plot observed spectrum
if observed_spectrum:
if distance is None:
raise ValueError(
"""
Distance must be specified if an observed_spectrum is given
so that the model spectrum can be converted into flux space correctly.
"""
)
# Convert to wavelength and luminosity units
observed_spectrum_wavelength = observed_spectrum[0].to(u.AA)
observed_spectrum_flux = observed_spectrum[1].to("erg/(s cm**2 AA)")
self.fig.add_scatter(
x=observed_spectrum_wavelength.value,
y=observed_spectrum_flux.value,
name="Observed Spectrum",
line={"color": "black", "width": 1.2},
**hover_props,
)
# Plot photosphere
if blackbody_photosphere:
self.fig.add_trace(
go.Scatter(
x=self.plot_wavelength.value,
y=self.photosphere_luminosity.value,
mode="lines",
line={"width": 1.5, "color": "red", "dash": "dash"},
name="Blackbody Photosphere",
**hover_props,
)
)
self._show_colorbar_ply()
# Set label and other layout options
xlabel = pu.axis_label_in_latex("Wavelength", u.AA)
if distance is not None: # Set y-axis label for flux
ylabel = pu.axis_label_in_latex(
"F_{\\lambda}", u.Unit("erg/(s cm**2 AA)"), only_text=False
)
else: # Set y-axis label for luminosity
ylabel = pu.axis_label_in_latex(
"L_{\\lambda}", u.Unit("erg/(s AA)"), only_text=False
)
self.fig.update_layout(
xaxis={
"title": xlabel,
"exponentformat": "none",
},
yaxis={"title": ylabel, "exponentformat": "e"},
height=graph_height,
)
return self.fig
def _plot_emission_ply(self):
"""Plot emission part of the SDEC Plot using plotly."""
self._plot_traces(
df=self.emission_luminosities_df,
group_name="emission",
predefined_traces=self._predefined_traces["emission"],
invert_y=False,
)
def _plot_absorption_ply(self):
"""Plot absorption part of the SDEC Plot using plotly."""
self._plot_traces(
df=self.absorption_luminosities_df,
group_name="absorption",
predefined_traces=self._predefined_traces["absorption"],
invert_y=True,
)
def _show_colorbar_ply(self):
"""Show plotly colorbar with labels of elements mapped to colors."""
# Interpolate [0, 1] range to create bins equal to number of elements
colorscale_bins = np.linspace(0, 1, num=len(self._species_name) + 1)
# Create a categorical colorscale [a list of (reference point, color)]
# by mapping same reference points (excluding 1st and last bin edge)
# twice in a row (https://plotly.com/python/colorscales/#constructing-a-discrete-or-discontinuous-color-scale)
categorical_colorscale = []
for species_counter in range(len(self._species_name)):
color = pu.to_rgb255_string(
self.cmap(colorscale_bins[species_counter])
)
categorical_colorscale.append(
(colorscale_bins[species_counter], color)
)
categorical_colorscale.append(
(colorscale_bins[species_counter + 1], color)
)
coloraxis_options = {
"colorscale": categorical_colorscale,
"showscale": True,
"cmin": 0,
"cmax": len(self._species_name),
"colorbar": {
"title": "Elements",
"tickvals": np.arange(0, len(self._species_name)) + 0.5,
"ticktext": self._species_name,
# to change length and position of colorbar
"len": 0.75,
"yanchor": "top",
"y": 0.75,
},
}
# Plot an invisible one point scatter trace, to make colorbar show up
scatter_point_idx = pu.get_mid_point_idx(self.plot_wavelength)
self.fig.add_trace(
go.Scatter(
x=[self.plot_wavelength[scatter_point_idx].value],
y=[0],
mode="markers",
name="Colorbar",
showlegend=False,
hoverinfo="skip",
marker=dict(color=[0], opacity=0, **coloraxis_options),
)
)
def _log_missing_species(self, identifier, is_absorption):
"""
Log an informational message when a species is missing from interaction data.
Parameters
----------
identifier : int
Species identifier, atomic number or combined atomic and ion number (Z*100+ion).
is_absorption : bool
True if checking absorption species, False for emission species.
"""
interaction_type = "absorbed" if is_absorption else "emitted"
if self._species_list is None:
info_msg = (
f"{atomic_number2element_symbol(identifier)}"
f" is not in the {interaction_type} packets; skipping"
)
else:
# Get the ion number and atomic number for each species
atomic_number, ion_number = divmod(
identifier, 100
) # (quotient, remainder)
info_msg = (
f"{atomic_number2element_symbol(atomic_number)}"
f"{int_to_roman(ion_number + 1)}"
f" is not in the {interaction_type} packets; skipping"
)
logger.info(info_msg)