Source code for tardis.visualization.tools.liv_plot

import logging
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import plotly.graph_objects as go
import numpy as np
import pandas as pd
import astropy.units as u

from tardis.util.base import (
    atomic_number2element_symbol,
    int_to_roman,
)
import tardis.visualization.tools.sdec_plot as sdec
from tardis.visualization import plot_util as pu

logger = logging.getLogger(__name__)


[docs] class LIVPlotter: """ Plotting interface for the last interaction velocity plot. """ def __init__(self, data, time_explosion, velocity): """ Initialize the plotter with required data from the simulation. Parameters ---------- data : dict of SDECData Dictionary to store data required for last interaction velocity plot, for both packet modes (real, virtual). time_explosion : astropy.units.Quantity Time of the explosion. velocity : astropy.units.Quantity Velocity array from the simulation. """ self.data = data self.time_explosion = time_explosion self.velocity = velocity self.sdec_plotter = sdec.SDECPlotter(data)
[docs] @classmethod def from_simulation(cls, sim): """ Create an instance of the plotter from a TARDIS simulation object. Parameters ---------- sim : tardis.simulation.Simulation TARDIS simulation object produced by running a simulation. Returns ------- LIVPlotter """ return cls( dict( virtual=sdec.SDECData.from_simulation(sim, "virtual"), real=sdec.SDECData.from_simulation(sim, "real"), ), sim.plasma.time_explosion, sim.simulation_state.velocity, )
[docs] @classmethod def from_hdf(cls, hdf_fpath): """ Create an instance of the Plotter from a simulation HDF file. Parameters ---------- hdf_fpath : str Valid path to the HDF file where simulation is saved. Returns ------- LIVPlotter """ with pd.HDFStore(hdf_fpath, "r") as hdf: time_explosion = ( hdf["/simulation/plasma/scalars"]["time_explosion"] * u.s ) v_inner = hdf["/simulation/simulation_state/v_inner"] * (u.cm / u.s) v_outer = hdf["/simulation/simulation_state/v_outer"] * (u.cm / u.s) velocity = pd.concat( [v_inner, pd.Series([v_outer.iloc[-1]])], ignore_index=True ).tolist() * (u.cm / u.s) return cls( dict( virtual=sdec.SDECData.from_hdf(hdf_fpath, "virtual"), real=sdec.SDECData.from_hdf(hdf_fpath, "real"), ), time_explosion, velocity, )
def _parse_species_list(self, species_list, packets_mode, nelements=None): """ Parse user requested species list and create list of species ids to be used. Parameters ---------- species_list : list of species to plot List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours. Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) packets_mode : str, optional Packet mode, either 'virtual' or 'real'. Default is 'virtual'. nelements : int, optional Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. Raises ------ ValueError If species list contains invalid entries. """ self.sdec_plotter._parse_species_list(species_list) self._species_list = self.sdec_plotter._species_list self._species_mapped = self.sdec_plotter._species_mapped self._keep_colour = self.sdec_plotter._keep_colour if nelements: interaction_counts = ( self.data[packets_mode] .packets_df_line_interaction["last_line_interaction_species"] .value_counts() ) interaction_counts.index = interaction_counts.index // 100 element_counts = interaction_counts.groupby( interaction_counts.index ).sum() top_elements = element_counts.nlargest(nelements).index top_species_list = [ atomic_number2element_symbol(element) for element in top_elements ] self._parse_species_list(top_species_list, packets_mode) def _make_colorbar_labels(self): """ Generate labels for the colorbar based on species. If a species list is provided, uses that to generate labels. Otherwise, generates labels from the species in the model. """ if self._species_list is None: species_name = [ atomic_number2element_symbol(atomic_num) for atomic_num in self.species ] else: species_name = [] for species_key, species_ids in self._species_mapped.items(): if any(species in self.species for species in species_ids): if species_key % 100 == 0: label = atomic_number2element_symbol(species_key // 100) else: atomic_number = species_key // 100 ion_number = species_key % 100 ion_numeral = int_to_roman(ion_number + 1) label = f"{atomic_number2element_symbol(atomic_number)} {ion_numeral}" species_name.append(label) self._species_name = species_name def _make_colorbar_colors(self): """ Generate colors for the species to be plotted. This method creates a list of colors corresponding to the species names. The colors are generated based on the species present in the model and the requested species list. """ color_list = [] species_keys = list(self._species_mapped.keys()) num_species = len(species_keys) for species_counter, species_key in enumerate(species_keys): if any( species in self.species for species in self._species_mapped[species_key] ): color = self.cmap(species_counter / num_species) color_list.append(color) self._color_list = color_list def _generate_plot_data(self, packets_mode): """ Generate plot data and colors for species in the model. Parameters ---------- packets_mode : str Packet mode, either 'virtual' or 'real'. """ groups = ( self.data[packets_mode] .packets_df_line_interaction.loc[self.packet_nu_line_range_mask] .groupby(by="last_line_interaction_species") ) self.plot_colors = [] self.plot_data = [] species_not_wvl_range = [] species_counter = 0 for specie_list in self._species_mapped.values(): full_v_last = [] for specie in specie_list: if specie in self.species: if specie not in groups.groups: atomic_number = specie // 100 ion_number = specie % 100 ion_numeral = int_to_roman(ion_number + 1) label = f"{atomic_number2element_symbol(atomic_number)} {ion_numeral}" species_not_wvl_range.append(label) continue g_df = groups.get_group(specie) r_last_interaction = ( g_df["last_interaction_in_r"].values * u.cm ) v_last_interaction = ( r_last_interaction / self.time_explosion ).to("km/s") full_v_last.extend(v_last_interaction) if full_v_last: self.plot_data.append(full_v_last) self.plot_colors.append(self._color_list[species_counter]) species_counter += 1 if species_not_wvl_range: logger.info( "%s were not found in the provided wavelength range.", species_not_wvl_range, ) def _prepare_plot_data( self, packets_mode, packet_wvl_range, species_list, cmapname, num_bins, nelements, ): """ Prepare data and settings required for generating a plot. This method handles the common logic for preparing data and settings needed to generate both matplotlib and plotly plots. It parses the species list, generates color labels and colormap, and bins the velocity data. Parameters ---------- packets_mode : str Packet mode, either 'virtual' or 'real'. packet_wvl_range : astropy.Quantity Wavelength range to restrict the analysis of escaped packets. It should be a quantity having units of Angstrom, containing two values - lower lambda and upper lambda i.e. [lower_lambda, upper_lambda] * u.AA species_list : list of str List of species to plot. Species can be specified as an ion (e.g., Si II), an element (e.g., Si), a range of ions (e.g., Si I-V), or any combination of these. cmapname : str Name of the colormap to use. A specific colormap can be chosen, such as "jet", "viridis", "plasma", etc. num_bins : int, optional Number of bins for regrouping within the same range. If None, no regrouping is done. nelements : int, optional Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. Raises ------ ValueError If no species are provided for plotting, or if no valid species are found in the model. """ if species_list is None: # Extract all unique elements from the packets data species_in_model = np.unique( self.data[packets_mode] .packets_df_line_interaction["last_line_interaction_species"] .values ) species_list = [ f"{atomic_number2element_symbol(specie // 100)}" for specie in species_in_model ] self._parse_species_list(species_list, packets_mode, nelements) species_in_model = np.unique( self.data[packets_mode] .packets_df_line_interaction["last_line_interaction_species"] .values ) if self._species_list is None or not self._species_list: raise ValueError("No species provided for plotting.") msk = np.isin(self._species_list, species_in_model) self.species = np.array(self._species_list)[msk] if len(self.species) == 0: raise ValueError("No valid species found for plotting.") self._make_colorbar_labels() self.cmap = plt.get_cmap(cmapname, len(self._species_name)) self._make_colorbar_colors() if packet_wvl_range is None: self.packet_nu_line_range_mask = np.ones( self.data[packets_mode].packets_df_line_interaction.shape[0], dtype=bool, ) else: packet_nu_range = [ value.to("Hz", equivalencies=u.spectral()) for value in packet_wvl_range ] self.packet_nu_line_range_mask = ( self.data[packets_mode].packets_df_line_interaction["nus"] >= packet_nu_range[1] ) & ( self.data[packets_mode].packets_df_line_interaction["nus"] <= packet_nu_range[0] ) self._generate_plot_data(packets_mode) bin_edges = (self.velocity).to("km/s") if num_bins: if num_bins < 1: raise ValueError("Number of bins must be positive") elif num_bins > len(bin_edges) - 1: logger.warning( "Number of bins must be less than or equal to number of shells. Plotting with number of bins equals to number of shells." ) self.new_bin_edges = bin_edges else: self.new_bin_edges = np.linspace( bin_edges[0], bin_edges[-1], num_bins + 1 ) else: self.new_bin_edges = bin_edges def _get_step_plot_data(self, data, bin_edges): """ Generate step plot data from histogram data. Parameters ---------- data : array-like Data to be binned into a histogram. bin_edges : array-like Edges of the bins for the histogram. """ hist, _ = np.histogram(data, bins=bin_edges) self.step_x = np.repeat(bin_edges, 2)[1:-1] self.step_y = np.repeat(hist, 2)
[docs] def generate_plot_mpl( self, species_list=None, nelements=None, packets_mode="virtual", packet_wvl_range=None, ax=None, figsize=(11, 5), cmapname="jet", xlog_scale=False, ylog_scale=False, num_bins=None, velocity_range=None, ): """ Generate the last interaction velocity distribution plot using matplotlib. Parameters ---------- species_list : list of str, optional List of species to plot. Default is None which plots all species in the model. nelements : int, optional Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. packets_mode : str, optional Packet mode, either 'virtual' or 'real'. Default is 'virtual'. packet_wvl_range : astropy.Quantity Wavelength range to restrict the analysis of escaped packets. It should be a quantity having units of Angstrom, containing two values - lower lambda and upper lambda i.e. [lower_lambda, upper_lambda] * u.AA ax : matplotlib.axes.Axes, optional Axes object to plot on. If None, creates a new figure. figsize : tuple, optional Size of the figure. Default is (11, 5). cmapname : str, optional Colormap name. Default is 'jet'. A specific colormap can be chosen, such as "jet", "viridis", "plasma", etc. xlog_scale : bool, optional If True, x-axis is scaled logarithmically. Default is False. ylog_scale : bool, optional If True, y-axis is scaled logarithmically. Default is False. num_bins : int, optional Number of bins for regrouping within the same range. Default is None. velocity_range : tuple, optional Limits for the x-axis. If specified, overrides any automatically determined limits. Returns ------- matplotlib.axes.Axes Axes object with the plot. """ # If species_list and nelements requested, tell user that nelements is ignored if species_list is not None and nelements is not None: logger.info( "Both nelements and species_list were requested. Species_list takes priority; nelements is ignored" ) nelements = None self._prepare_plot_data( packets_mode, packet_wvl_range, species_list, cmapname, num_bins, nelements, ) bin_edges = self.new_bin_edges if ax is None: self.ax = plt.figure(figsize=figsize).add_subplot(111) else: self.ax = ax for data, color, name in zip( self.plot_data, self.plot_colors, self._species_name ): self._get_step_plot_data(data, bin_edges) self.ax.plot( self.step_x, self.step_y, label=name, color=color, linewidth=2.5, drawstyle="steps-post", alpha=0.75, ) self.ax.ticklabel_format(axis="y", scilimits=(0, 0)) self.ax.tick_params("both", labelsize=15) self.ax.set_xlabel("Last Interaction Velocity (km/s)", fontsize=14) self.ax.set_ylabel("Packet Count", fontsize=15) self.ax.legend(fontsize=15, bbox_to_anchor=(1.0, 1.0), loc="upper left") self.ax.figure.tight_layout() if xlog_scale: self.ax.set_xscale("log") if ylog_scale: self.ax.set_yscale("log") if velocity_range: self.ax.set_xlim(velocity_range[0], velocity_range[1]) return self.ax
[docs] def generate_plot_ply( self, species_list=None, nelements=None, packets_mode="virtual", packet_wvl_range=None, fig=None, graph_height=600, cmapname="jet", xlog_scale=False, ylog_scale=False, num_bins=None, velocity_range=None, ): """ Generate the last interaction velocity distribution plot using plotly. Parameters ---------- species_list : list of str, optional List of species to plot. Default is None which plots all species in the model. nelements : int, optional Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. packets_mode : str, optional Packet mode, either 'virtual' or 'real'. Default is 'virtual'. packet_wvl_range : astropy.Quantity Wavelength range to restrict the analysis of escaped packets. It should be a quantity having units of Angstrom, containing two values - lower lambda and upper lambda i.e. [lower_lambda, upper_lambda] * u.AA fig : plotly.graph_objects.Figure, optional Plotly figure object to add the plot to. If None, creates a new figure. graph_height : int, optional Height (in px) of the plotly graph to display. Default value is 600. cmapname : str, optional Colormap name. Default is 'jet'. A specific colormap can be chosen, such as "jet", "viridis", "plasma", etc. xlog_scale : bool, optional If True, x-axis is scaled logarithmically. Default is False. ylog_scale : bool, optional If True, y-axis is scaled logarithmically. Default is False. num_bins : int, optional Number of bins for regrouping within the same range. Default is None. velocity_range : tuple, optional Limits for the x-axis. If specified, overrides any automatically determined limits. Returns ------- plotly.graph_objects.Figure Plotly figure object with the plot. """ # If species_list and nelements requested, tell user that nelements is ignored if species_list is not None and nelements is not None: logger.info( "Both nelements and species_list were requested. Species_list takes priority; nelements is ignored" ) nelements = None self._prepare_plot_data( packets_mode, packet_wvl_range, species_list, cmapname, num_bins, nelements, ) bin_edges = self.new_bin_edges if fig is None: self.fig = go.Figure() else: self.fig = fig for data, color, name in zip( self.plot_data, self.plot_colors, self._species_name ): self._get_step_plot_data(data, bin_edges) self.fig.add_trace( go.Scatter( x=self.step_x, y=self.step_y, mode="lines", line=dict( color=pu.to_rgb255_string(color), width=2.5, shape="hv", ), name=name, opacity=0.75, ) ) self.fig.update_layout( height=graph_height, xaxis_title="Last Interaction Velocity (km/s)", yaxis_title="Packet Count", font=dict(size=15), yaxis=dict(exponentformat="power" if ylog_scale else "e"), xaxis=dict(exponentformat="power" if xlog_scale else "none"), ) if xlog_scale: self.fig.update_xaxes(type="log") if ylog_scale: self.fig.update_yaxes(type="log", dtick=1) if velocity_range: self.fig.update_xaxes(range=velocity_range) return self.fig