"""
Grotrian Diagram Widget for TARDIS simulation models.
This widget displays a Grotrian Diagram of the last line interactions of the simulation packets
"""
from tardis.analysis import LastLineInteraction
from tardis.util.base import species_tuple_to_string, species_string_to_tuple
from tardis.util.base import int_to_roman
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from astropy import units as u
import ipywidgets as ipw
ANGSTROM_SYMBOL = "\u212B"
[docs]
def is_zero_defined(transform):
"""
Utility function to decide if a certain transform is defined at zero
Parameters
----------
transform : function
Returns
-------
bool
True if transform is defined at 0 else False
"""
if transform in [np.log, np.log10]:
return True
return False
[docs]
def standardize(
values,
transform=lambda x: x,
min_value=None,
max_value=None,
zero_undefined_offset=0,
):
"""
Utility function to standardize displayed values like wavelengths, num_packets, levels populations to the range [0, 1]
This helps in computing visual elements like widths, colors, etc.
Parameters
----------
values : pandas.Series
The data to standardize
transform : function, optional
Transformations like np.log, np.exp, etc. to apply on the data. Defaults to identity
min_value : float, optional
The lower bound of the range
max_value : float, optional
The upper bound of the range
zero_undefined_offset : int, optional
This is useful for log transformation because log(0) is -inf.
Hence, value=0 gives y=0 while the
output for other values start at `zero_undefined_offset` (y = log(value) + zero_undefined_offset)
Default value is 0
Returns
-------
pandas.Series
Values after standardization
"""
zero_undefined = is_zero_defined(transform) # Is function defined at 0?
if zero_undefined and zero_undefined_offset == 0:
raise ValueError(
"If zero of the transformation is undefined, then provide an offset greater than 0"
)
# Compute lower and upper bounds of values
if min_value is None:
if zero_undefined:
min_value = (
values[values > 0].min() if len(values[values > 0]) > 0 else 0
)
else:
min_value = values.min() if len(values) > 0 else 0
if max_value is None:
if zero_undefined:
max_value = (
values[values > 0].max() if len(values[values > 0]) > 0 else 0
)
else:
max_value = values.max() if len(values) > 0 else 0
# Apply transformation if given
transformed_min_value = (
transform(min_value) if (min_value > 0 or not zero_undefined) else 0
)
transformed_max_value = (
transform(max_value) if (max_value > 0 or not zero_undefined) else 0
)
transformed_values = transform(values)
# Compute range
value_range = transformed_max_value - transformed_min_value
# Apply standardization
if value_range > 0:
transformed_values = (
transformed_values - transformed_min_value
) / value_range
if zero_undefined:
transformed_values = transformed_values + zero_undefined_offset
transformed_values = np.where(values == 0, 0, transformed_values)
else:
# If only single value present in table, then place it at 0
transformed_values = 0 * values
return transformed_values
[docs]
class GrotrianPlot:
"""
Class for the Grotrian Diagram
Parameters
----------
atom_data : pandas.DataFrame
Mapping from atomic number to symbol and name
level_energy_data : pandas.Series
Level energies (in eV) indexed by (atomic_number, ion_number, level_number)
level_population_data : pandas.DataFrame
Level populations indexed by (atomic_number, ion_number, level_number)
and each column representing the supernova shell
line_interaction_analysis : tardis.analysis.LastLineInteraction
LastLineInteraction object with the appropriate filters
Configurable Attributes
-----------------------
atomic_number : int
Atomic number of the ion for which the diagram is plotted
Note: User should set the atomic_number and ion_number together using set_ion function.
ion_number : int
Ion number of the ion for which the diagram is plotted
Note: User should set the atomic_number and ion_number together using set_ion function.
shell : int or None
The supernova shell to filter on.
If None, the level populations are averaged across all shells,
and all last line interaction are considered
Default value is None
max_levels : int
The maximum number of levels to plot.
Default value is 10
level_diff_threshold : float
The percentage threshold under which levels are merged
Default value is 1% (0.01)
min_wavelength : float
The minimum wavelength allowed for the transitions
max_wavelength : float
The maximum wavelength allowed for the transitions
filter_mode : {"packet_out_nu", "packet_in_nu"}
The type of wavelength to apply wavelength range filter on
Default value is packet_out_nu
y_scale : {"Log", "Linear"}
The scale to plot the energy levels on the y-axis
Default value is Log
cmapname : str
The name of the colormap used to denote wavelengths. Default value is "rainbow"
level_width_scale : float
The multiplier to convert standardized level populations to level widths
Default value is 3
level_width_offset : float
The offset for level widths (to add to the scaled standardized level populations)
Default value is 1
transition_width_scale : float
The multiplier to convert standardized packet count to transition widths
Default value is 2
transition_width_offset : float
The offset for transition widths (to add to the scaled standardized packet counts)
Default value is 1
"""
FILTER_MODES = ("packet_out_nu", "packet_in_nu")
FILTER_MODES_DESC = ("Emitted Wavelength", "Absorbed Wavelength")
Y_SCALE_OPTION = {"Linear": (lambda x: x), "Log": np.log}
[docs]
@classmethod
def from_simulation(cls, sim, **kwargs):
"""
Creates a GrotrianPlot object from a Simulation object
Parameters
----------
sim : tardis.simulation.Simulation
TARDIS simulation object
Returns
-------
tardis.visualization.widgets.grotrian.GrotrianPlot
GrotrianPlot object
"""
atom_data = sim.plasma.atomic_data.atom_data
level_energy_data = pd.Series(
sim.plasma.atomic_data.levels.energy * u.erg.to(u.electronvolt),
name="energy",
)
level_population_data = sim.plasma.level_number_density
line_interaction_analysis = {
filter_mode: LastLineInteraction.from_simulation(sim, filter_mode)
for filter_mode in cls.FILTER_MODES
}
return cls(
atom_data=atom_data,
level_energy_data=level_energy_data,
level_population_data=level_population_data,
line_interaction_analysis=line_interaction_analysis,
**kwargs,
)
def __init__(
self,
atom_data,
level_energy_data,
level_population_data,
line_interaction_analysis,
):
# Set data members
self._atom_data = atom_data
self._level_energy_data = level_energy_data
self._level_population_data = level_population_data
self._line_interaction_analysis = line_interaction_analysis
# Max number of levels to display
self._max_levels = 10
# Energy difference threshold below which levels are merged
self._level_diff_threshold = 0.01
# Filter mode for the wavelength range
self._min_wavelength = None
self._max_wavelength = None
self._filter_mode = self.FILTER_MODES[0]
# Selected Species
self._atomic_number = None
self._ion_number = None
self._shell = None
### Define default parameters for visual elements related to energy levels
self.level_width_scale, self.level_width_offset = 3, 1
self._level_width_transform = np.log # Scale of the level widths
self._population_spacer = np.geomspace # To space width bar counts
### Scale of the y-axis
self._y_scale = "Log"
self._y_coord_transform = self.Y_SCALE_OPTION[self._y_scale]
### Define default parameters for visual elements related to transitions
self.transition_width_scale, self.transition_width_offset = 2, 1
self._transition_width_transform = np.log # Scale of the arrow widths
self._transition_count_spacer = (
np.geomspace
) # To space width bar counts
self.arrowhead_size = 9
### Define default parameters for visual elements related to wavelengths
self.cmapname = "rainbow"
self._wavelength_color_transform = np.log # Scale of wavelength color
self._wavelength_spacer = np.geomspace # To space colorbar wavelengths
# Coordinate end points of levels
self.x_min, self.x_max = 0, 1
@property
def min_wavelength(self):
return self._min_wavelength
@min_wavelength.setter
def min_wavelength(self, value):
self._min_wavelength = value
self._compute_transitions()
@property
def max_wavelength(self):
return self._max_wavelength
@max_wavelength.setter
def max_wavelength(self, value):
self._max_wavelength = value
self._compute_transitions()
[docs]
def reset_selected_plot_wavelength_range(self):
"""
Resets the wavelength range of the selected plot
"""
self.min_wavelength = None
self.max_wavelength = None
@property
def max_levels(self):
return self._max_levels
@max_levels.setter
def max_levels(self, value):
assert type(value) is int
self._max_levels = value
self._compute_level_data()
self.reset_selected_plot_wavelength_range() # calls _compute_transitions() as well
@property
def level_diff_threshold(self):
return self._level_diff_threshold
@level_diff_threshold.setter
def level_diff_threshold(self, value):
assert 0 >= value and value < 1
self._level_diff_threshold = value
self._compute_level_data()
self._compute_transitions()
@property
def filter_mode(self):
return self._filter_mode
@filter_mode.setter
def filter_mode(self, value):
assert value in self.FILTER_MODES
# Set the atomic_number and ion_number in the appropriate analysis object
self._line_interaction_analysis[value].set_ion(
self.atomic_number, self.ion_number
)
self._line_interaction_analysis[value].shell = self.shell
self._filter_mode = value
self._compute_transitions()
@property
def atomic_number(self):
if self._atomic_number is None:
raise ValueError("Atomic number is not set")
return self._atomic_number
[docs]
def set_ion(self, atomic_number, ion_number):
"""
Sets the atomic number and ion number
"""
assert type(atomic_number) is int and type(ion_number) is int
if (atomic_number, ion_number) not in self._level_energy_data.index or (
atomic_number,
ion_number,
) not in self._level_population_data.index:
raise ValueError(
"The (atomic_number, ion_number) pair doesn't exist in model"
)
self._line_interaction_analysis[self.filter_mode].set_ion(
atomic_number, ion_number
)
self._atomic_number = atomic_number
self._ion_number = ion_number
self._compute_level_data()
# Reset any custom wavelengths if user changes ion
self.reset_selected_plot_wavelength_range() # Also computes transition lines so we don't need to call it "_compute_transitions()" explicitly
@property
def ion_number(self):
if self._ion_number is None:
raise ValueError("Ion number is not set")
return self._ion_number
@property
def atomic_name(self):
return self._atom_data.loc[self.atomic_number]["name"]
@property
def atomic_symbol(self):
return self._atom_data.loc[self.atomic_number]["symbol"]
@property
def shell(self):
return self._shell
@shell.setter
def shell(self, value):
assert value is None or type(value) is int
self._line_interaction_analysis[self.filter_mode].shell = value
self._shell = value
self._compute_level_data()
self._compute_transitions()
@property
def y_scale(self):
return self._y_scale
@y_scale.setter
def y_scale(self, value):
assert value in self.Y_SCALE_OPTION
self._y_scale = value
self._y_coord_transform = self.Y_SCALE_OPTION[self._y_scale]
def _compute_transitions(self):
"""
Computes the excitation/de-excitation line transition data for the arrows in the widget
"""
### Get the excitation/de-excitation transitions from LastLineInteraction object
excite_lines = (
self._line_interaction_analysis[self.filter_mode]
.last_line_in.reset_index()
.groupby(["level_number_lower", "level_number_upper"])
.agg(
num_electrons=("line_id", "count"), # Take count of lines
wavelength=("wavelength", "first"), # Take first of wavelengths
)
.reset_index()
)
deexcite_lines = (
self._line_interaction_analysis[self.filter_mode]
.last_line_out.reset_index()
.groupby(["level_number_lower", "level_number_upper"])
.agg(
num_electrons=("line_id", "count"), # Take count of lines
wavelength=("wavelength", "first"), # Take first of wavelengths
)
.reset_index()
)
### Filter transitions to only include transitions up to the self.max_levels
excite_lines = excite_lines.loc[
excite_lines.level_number_upper <= self.max_levels
]
deexcite_lines = deexcite_lines.loc[
deexcite_lines.level_number_upper <= self.max_levels
]
### Map the levels to merged levels
excite_lines[
"merged_level_number_lower"
] = excite_lines.level_number_lower.map(self.level_mapping)
excite_lines[
"merged_level_number_upper"
] = excite_lines.level_number_upper.map(self.level_mapping)
deexcite_lines[
"merged_level_number_lower"
] = deexcite_lines.level_number_lower.map(self.level_mapping)
deexcite_lines[
"merged_level_number_upper"
] = deexcite_lines.level_number_upper.map(self.level_mapping)
### Group by level pairs
excite_lines = (
excite_lines.groupby(
["merged_level_number_lower", "merged_level_number_upper"]
)
.agg(
wavelength=("wavelength", "mean"), # Take mean of wavelength
num_electrons=("num_electrons", "sum"), # Take sum of counts
)
.reset_index()
)
deexcite_lines = (
deexcite_lines.groupby(
["merged_level_number_lower", "merged_level_number_upper"]
)
.agg(
wavelength=("wavelength", "mean"), # Take mean of wavelength
num_electrons=("num_electrons", "sum"), # Take sum of counts
)
.reset_index()
)
### Remove the rows where start and end (merged) level is the same
excite_lines = excite_lines.loc[
excite_lines.merged_level_number_lower
!= excite_lines.merged_level_number_upper
]
deexcite_lines = deexcite_lines.loc[
deexcite_lines.merged_level_number_lower
!= deexcite_lines.merged_level_number_upper
]
### Compute default wavelengths if not set by user
if len(excite_lines) + len(deexcite_lines) > 0:
if self.min_wavelength is None: # Compute default wavelength
self._min_wavelength = np.min(
np.concatenate(
(excite_lines.wavelength, deexcite_lines.wavelength)
)
)
if self.max_wavelength is None: # Compute default wavelength
self._max_wavelength = np.max(
np.concatenate(
(excite_lines.wavelength, deexcite_lines.wavelength)
)
)
### Remove the rows outside the wavelength range for the plot
excite_lines = excite_lines.loc[
(excite_lines.wavelength >= self.min_wavelength)
& (excite_lines.wavelength <= self.max_wavelength)
]
deexcite_lines = deexcite_lines.loc[
(deexcite_lines.wavelength >= self.min_wavelength)
& (deexcite_lines.wavelength <= self.max_wavelength)
]
### Compute the standardized log number of electrons for arrow line width
transition_width_coefficient = standardize(
np.concatenate(
(excite_lines.num_electrons, deexcite_lines.num_electrons)
),
transform=self._transition_width_transform,
zero_undefined_offset=1e-3,
)
excite_lines[
"transition_width_coefficient"
] = transition_width_coefficient[: len(excite_lines)]
deexcite_lines[
"transition_width_coefficient"
] = transition_width_coefficient[len(excite_lines) :]
self.excite_lines = excite_lines
self.deexcite_lines = deexcite_lines
def _compute_level_data(self):
"""
Computes the level population data for the horizontal platforms in the widget
"""
### Get energy levels and convert to eV
raw_energy_levels = self._level_energy_data.loc[
self.atomic_number, self.ion_number
].loc[0 : self.max_levels]
### Get level populations
raw_level_populations = self._level_population_data.loc[
self.atomic_number, self.ion_number
].loc[0 : self.max_levels]
### Average out the level populations across all zones, if zone not selected
if self.shell is None:
raw_level_populations = raw_level_populations.mean(axis=1)
else:
raw_level_populations = raw_level_populations[self.shell]
raw_level_populations = pd.Series(
raw_level_populations, name="population"
)
### Join level populations and energy values
raw_level_data = pd.merge(
raw_energy_levels,
raw_level_populations,
left_index=True,
right_index=True,
)
### Merge the levels if energy difference is less than threshold
# Get new level numbers
# TODO: Find a better way to find close levels (less than 0.03 diff in y-coord)
raw_level_data["merged_level_number"] = (
(raw_level_data["energy"] + 1).pct_change().abs()
> self.level_diff_threshold
).cumsum()
# Group data with new level numbers
self.level_data = (
raw_level_data.reset_index()
.groupby("merged_level_number")
.agg(
energy=(
"energy",
"mean",
), # Set energy as mean of merged levels
population=("population", "sum"),
)
) # Add the populations of merged levels
### Standardize the level populations to get width coefficient of levels
self.level_data["level_width_coefficient"] = standardize(
self.level_data.population,
transform=self._level_width_transform,
zero_undefined_offset=1e-3,
)
### Create a mapping from original levels to merged levels
self.level_mapping = raw_level_data.merged_level_number
def _draw_energy_levels(self):
"""
Draws the horizontal energy levels on the widget
"""
# Transform energies and standardize result to get y-coordinate in range [0, 1]
self.level_data["y_coord"] = standardize(
self.level_data.energy,
transform=self._y_coord_transform,
zero_undefined_offset=0.1,
)
### Create the energy levels from level data
for level_number, level_info in self.level_data.iterrows():
# Add the horizontal line
self.fig.add_trace(
go.Scatter(
x=np.linspace(self.x_min - 0.05, self.x_max + 0.05, 10),
y=level_info.y_coord * np.ones(10),
mode="lines",
hovertemplate=f"Energy: {level_info.energy:.2e} eV<br>"
+ f"Population: {level_info.population:.2e}"
+ "<extra></extra>",
line=dict(
color="black",
width=level_info.level_width_coefficient
* self.level_width_scale
+ self.level_width_offset,
)
if level_info.population > 0
else dict(color="grey", dash="dash"),
showlegend=False,
),
row=1,
col=2,
)
# Add label for energy
self.fig.add_annotation(
x=self.x_max + 0.1,
y=level_info.y_coord,
text=f"{level_number}",
showarrow=False,
xref="x2",
yref="y2",
)
def _draw_population_width_scale(self):
"""
Displays the level population width reference bar
"""
### Create width scale
### Find lower and upper bounds of populations and corresponding widths
min_population_idx = self.level_data.population[
self.level_data.population > 0
].idxmin()
max_population_idx = self.level_data.population.idxmax()
min_population = self.level_data.population[min_population_idx]
max_population = self.level_data.population[max_population_idx]
min_width = (
self.level_data.level_width_coefficient[min_population_idx]
* self.level_width_scale
+ self.level_width_offset
)
max_width = (
self.level_data.level_width_coefficient[max_population_idx]
* self.level_width_scale
+ self.level_width_offset
)
### Space the populations (log) and corresponding widths (linear) equally
scale_granularity = 10 # Number of scale ticks to display
population_ticks = self._population_spacer(
min_population, max_population, scale_granularity
)
width_ticks = np.linspace(min_width, max_width, scale_granularity)
y_positions = np.linspace(0, 1, scale_granularity)
### Draw the scale lines
for population, width, y_pos in zip(
population_ticks, width_ticks, y_positions
):
self.fig.add_shape(
type="line",
line_width=width,
x0=0.1,
x1=0.2,
y0=y_pos,
y1=y_pos,
xref="x1",
yref="y1",
)
self.fig.add_annotation(
x=0.35,
y=y_pos,
text=f"{population:.1e}",
showarrow=False,
xref="x1",
yref="y1",
)
# Add title of the width bar
self.fig.add_annotation(
x=0.28,
y=-0.08,
text="Populations",
showarrow=False,
xref="x1",
yref="y1",
)
def _draw_transitions(self, is_excitation):
"""
Draws the transition arrows on the widget
"""
lines = self.excite_lines if is_excitation else self.deexcite_lines
lines["color_coefficient"] = standardize(
lines.wavelength,
transform=self._wavelength_color_transform,
zero_undefined_offset=1e-5,
min_value=self.min_wavelength,
max_value=self.max_wavelength,
)
self._cmap = plt.get_cmap(self.cmapname) # Float to color map
### Plot excitation transitions
for _, line_info in lines.iterrows():
lower, upper = (
line_info.merged_level_number_lower,
line_info.merged_level_number_upper,
)
wavelength, transition_width_coefficient = (
line_info.wavelength,
line_info.transition_width_coefficient,
)
energy_lower, energy_upper = (
self.level_data.loc[lower].energy,
self.level_data.loc[upper].energy,
)
# Get the end x-coordinate (proportional to energy difference between levels)
merged_max_energy_level = self.level_data.energy.max()
x_end = (
(energy_upper - energy_lower)
* (self.x_max - self.x_min)
/ (merged_max_energy_level - energy_lower)
)
# Get the appropriate y-coordinate (computed in _draw_energy_levels)
y_lower = self.level_data.loc[lower].y_coord
y_upper = self.level_data.loc[upper].y_coord
# Get the end arrow color (proportional to log wavelength)
color_coef = line_info.color_coefficient
color = matplotlib.colors.rgb2hex(self._cmap(color_coef)[:3])
# Draw arrow
self.fig.add_trace(
go.Scatter(
x=[self.x_min, x_end],
y=[y_lower, y_upper]
if is_excitation
else [y_upper, y_lower],
hovertemplate=f"Count: {int(line_info.num_electrons)}<br>"
+ f"Wavelength: {wavelength:.2e} {ANGSTROM_SYMBOL}"
+ "<extra></extra>",
marker=dict(
size=self.arrowhead_size,
color=color,
symbol="arrow-up",
angleref="previous",
),
line=dict(
color=color,
width=transition_width_coefficient
* self.transition_width_scale
+ self.transition_width_offset,
),
),
row=1,
col=2,
)
def _draw_transition_width_scale(self):
"""
Displays the transition count width reference bar
"""
### Find lower and upper bounds of num_electrons and corresponding widths
max_num_electrons = np.max(
np.concatenate(
(
self.excite_lines.num_electrons,
self.deexcite_lines.num_electrons,
)
)
)
min_num_electrons = np.min(
np.concatenate(
(
self.excite_lines.num_electrons,
self.deexcite_lines.num_electrons,
)
)
)
max_width_coefficient = np.max(
np.concatenate(
(
self.excite_lines.transition_width_coefficient,
self.deexcite_lines.transition_width_coefficient,
)
)
)
min_width_coefficient = np.min(
np.concatenate(
(
self.excite_lines.transition_width_coefficient,
self.deexcite_lines.transition_width_coefficient,
)
)
)
min_width = (
min_width_coefficient * self.transition_width_scale
+ self.transition_width_offset
)
max_width = (
max_width_coefficient * self.transition_width_scale
+ self.transition_width_offset
)
### Space the num_electrons (log) and corresponding widths (linear) equally
scale_granularity = 10
num_electrons_ticks = self._transition_count_spacer(
min_num_electrons, max_num_electrons, scale_granularity
)
width_ticks = np.linspace(min_width, max_width, scale_granularity)
y_positions = np.linspace(0, 1, scale_granularity)
### Draw the width bar
for num_electrons, width, y_pos in zip(
num_electrons_ticks, width_ticks, y_positions
):
self.fig.add_shape(
type="line",
line_width=width,
x0=0.65,
x1=0.75,
y0=y_pos,
y1=y_pos,
xref="x1",
yref="y1",
)
self.fig.add_annotation(
x=0.9,
y=y_pos,
text=f"{num_electrons:.1e}",
showarrow=False,
xref="x1",
yref="y1",
)
# Add title of the width bar
self.fig.add_annotation(
x=0.83,
y=-0.08,
text="#Packets",
showarrow=False,
xref="x1",
yref="y1",
)
def _draw_transition_color_scale(self):
"""
Displays the transition wavelength colorbar
"""
# Add a dummy Scatter trace to display colorbar
tickvals = self._wavelength_spacer(
self.min_wavelength, self.max_wavelength, 10
)
ticktext = [f"{val:.1e}" for val in tickvals]
self.fig.add_trace(
go.Scatter(
x=[None],
y=[None],
mode="markers",
marker=dict(
colorscale=self.cmapname,
showscale=True,
cmin=self._wavelength_color_transform(self.min_wavelength),
cmax=self._wavelength_color_transform(self.max_wavelength),
colorbar=dict(
title=dict(
text=f"Wavelength ({ANGSTROM_SYMBOL})<br> ",
font_size=12,
),
thickness=5,
tickvals=self._wavelength_color_transform(tickvals),
ticktext=ticktext,
outlinewidth=0,
),
),
hoverinfo="none",
),
row=1,
col=2,
)
[docs]
def display(self):
"""
Function to draw the plot and the reference scales (calls other draw methods independently)
"""
### Create figure and set metadata
self.fig = go.FigureWidget(
make_subplots(
rows=1,
cols=2,
column_width=[0.3, 0.7],
specs=[[{}, {}]],
horizontal_spacing=0.14,
)
)
# Update fig layout
self.fig.update_layout(
title=(
f"Energy Level Diagram for {self.atomic_name} {int_to_roman(self.ion_number + 1)} "
f"(Shell: {self.shell if self.shell is not None else 'All'})"
),
title_x=0.5,
plot_bgcolor="white",
autosize=False,
width=1000,
height=700,
margin=dict(),
showlegend=False,
)
# Remove ticklabels in the reference bars subplot
self.fig.update_yaxes(
showticklabels=False, fixedrange=True, row=1, col=1
)
self.fig.update_xaxes(
showticklabels=False, fixedrange=True, row=1, col=1
)
### Create energy level platforms and width reference scale
self._draw_energy_levels()
self._draw_population_width_scale()
# Remove ticklabels from x-axis
self.fig.update_xaxes(
showticklabels=False, fixedrange=True, row=1, col=2
)
# Update y-ticks to reflect actual energy values
self.fig.update_yaxes(
title=dict(text="Energy (eV)", standoff=5),
range=[0, None],
tickmode="array",
tickvals=self.level_data.y_coord,
ticktext=[f"{energy:.2e}" for energy in self.level_data.energy],
fixedrange=True,
row=1,
col=2,
)
# Add separator between width scales
self.fig.add_shape(
type="line",
line=dict(color="grey", dash="dash"),
line_width=0.5,
x0=0.55,
x1=0.55,
y0=0,
y1=1,
xref="x1",
yref="y1",
)
### Create transition lines and corresponding width and color scales
if len(self.excite_lines) > 0:
self._draw_transitions(is_excitation=True)
if len(self.deexcite_lines) > 0:
self._draw_transitions(is_excitation=False)
if len(self.excite_lines) + len(self.deexcite_lines) > 0:
self._draw_transition_width_scale()
self._draw_transition_color_scale()
return self.fig
[docs]
class GrotrianWidget:
"""
A wrapper class for the Grotrian Diagram, containing the Grotrian Plot and the IpyWidgets
Parameters
----------
plot : tardis.visualization.widgets.grotrian.GrotrianPlot
GrotrianPlot object
num_shells : int
Number of shells in the sim.simulation_state.v_inner
"""
def __init__(self, plot, num_shells, **kwargs):
self.plot = plot
self.num_shells = num_shells
species_list = self._get_species()
self.ion_selector = ipw.Dropdown(
options=species_list,
index=0,
description="Ion",
)
self.plot.set_ion(*species_string_to_tuple(self.ion_selector.value))
self.ion_selector.observe(
self._ion_change_handler,
names="value",
)
self.ion_selector.observe(
self._wavelength_resetter,
names="value",
)
shell_list = ["All"] + [str(i) for i in range(1, num_shells + 1)]
self.shell_selector = ipw.Dropdown(
options=shell_list,
index=0,
description="Shell",
)
self.shell_selector.observe(
lambda change: self._change_handler(
"shell", None if change["new"] == "All" else int(change["new"])
),
names="value",
)
self.shell_selector.observe(
self._wavelength_resetter,
names="value",
)
self.max_level_selector = ipw.BoundedIntText(
value=plot.max_levels,
min=1,
max=40,
step=1,
description="Max Levels",
)
self.max_level_selector.observe(
lambda change: self._change_handler("max_levels", change["new"]),
names="value",
)
self.max_level_selector.observe(
self._wavelength_resetter,
names="value",
)
self.y_scale_selector = ipw.ToggleButtons(
options=GrotrianPlot.Y_SCALE_OPTION.keys(),
index=1,
description="Y-Scale",
layout=ipw.Layout(width="auto"),
style={"button_width": "100px"},
)
self.y_scale_selector.observe(
lambda change: self._change_handler("y_scale", change["new"]),
names="value",
)
self.wavelength_range_selector = ipw.FloatRangeSlider(
value=[self.plot.min_wavelength, self.plot.max_wavelength],
min=self.plot.min_wavelength,
max=self.plot.max_wavelength,
step=0.1,
description="Wavelength",
layout=ipw.Layout(width="605px"),
readout_format=".1e",
)
self.wavelength_range_selector.observe(
self._wavelength_change_handler,
names="value",
)
def _get_species(self):
"""
Computes the ions list for the ion dropdown of the plot
"""
line_interaction_analysis = self.plot._line_interaction_analysis
selected_species_group = line_interaction_analysis[
self.plot.filter_mode
].last_line_in.groupby(["atomic_number", "ion_number"])
if selected_species_group.groups:
selected_species_symbols = [
species_tuple_to_string(item)
for item in selected_species_group.groups.keys()
]
return selected_species_symbols
def _change_handler(self, attribute, value):
"""
Generic function to update the configurable attributes of GrotrianPlot object
Parameters
----------
attribute : str
The name of the attribute of the GrotrianPlot object
value :
The new value of the attribute
"""
index = self.fig.children.index(self.plot.fig)
setattr(self.plot, attribute, value) # Set the value of the attribute
# Set the updated plot in the figure
children_list = list(self.fig.children)
children_list[index] = self.plot.display()
self.fig.children = tuple(children_list)
def _ion_change_handler(self, change):
"""
Function to update ion of GrotrianPlot object
Parameters
----------
change : dict
Change information of the event
"""
atomic_number, ion_number = species_string_to_tuple(change["new"])
index = self.fig.children.index(self.plot.fig)
self.plot.set_ion(atomic_number, ion_number)
# Set the updated plot in the figure
children_list = list(self.fig.children)
children_list[index] = self.plot.display()
self.fig.children = tuple(children_list)
# self._wavelength_resetter()
def _wavelength_change_handler(self, change):
"""
Function to update the wavelength range of GrotrianPlot object
Parameters
----------
change : dict
Change information of the event
"""
min_wavelength, max_wavelength = change["new"]
index = self.fig.children.index(self.plot.fig)
setattr(self.plot, "min_wavelength", min_wavelength)
setattr(self.plot, "max_wavelength", max_wavelength + 1)
# Set the updated plot in the figure
children_list = list(self.fig.children)
children_list[index] = self.plot.display()
self.fig.children = tuple(children_list)
def _wavelength_resetter(self, change):
"""
Resets the range of the wavelength slider whenever the ion, level or shell changes
"""
min_wavelength = self.plot.min_wavelength
max_wavelength = self.plot.max_wavelength
if min_wavelength is None or max_wavelength is None:
self.wavelength_range_selector.layout.visibility = "hidden"
return
elif min_wavelength == max_wavelength:
self.wavelength_range_selector.layout.visibility = "visible"
self.wavelength_range_selector.disabled = True
else:
self.wavelength_range_selector.layout.visibility = "visible"
self.wavelength_range_selector.disabled = False
self.wavelength_range_selector.min = 0.0
self.wavelength_range_selector.max = max_wavelength
self.wavelength_range_selector.min = min_wavelength
self.wavelength_range_selector.value = [
self.wavelength_range_selector.min,
self.wavelength_range_selector.max,
]