"""Convergence Plots to see the convergence of the simulation in real time."""
from collections import defaultdict
import matplotlib.cm as cm
import matplotlib.colors as clr
import numpy as np
import plotly.graph_objects as go
from IPython.display import display
import matplotlib as mpl
import ipywidgets as widgets
from contextlib import suppress
from traitlets import TraitError
from astropy import units as u
[docs]
def transition_colors(length, name="jet"):
"""
Create colorscale for convergence plots, returns a list of colors.
Parameters
----------
length : int
The length of the colorscale.
name : string, default: 'jet', optional
Name of the colorscale.
Returns
-------
colors: list
"""
cmap = mpl.pyplot.get_cmap(name, length)
colors = []
for i in range(cmap.N):
rgb = cmap(i)[:3]
colors.append(mpl.colors.rgb2hex(rgb))
return colors
[docs]
class ConvergencePlots(object):
"""
Create and update convergence plots for visualizing convergence of the simulation.
Parameters
----------
iterations : int
iteration number
**kwargs : dict, optional
Additional keyword arguments. These arguments are defined in the Other Parameters section.
Other Parameters
----------------
plasma_plot_config : dict, optional
Dictionary used to override default plot properties of plasma plots.
t_inner_luminosities_config : dict, optional
Dictionary used to override default plot properties of the inner boundary temperature and luminosity plots.
plasma_cmap : str, default: 'jet', optional
String defining the cmap used in plasma plots.
t_inner_luminosities_colors : str or list, optional
String defining cmap for luminosity and inner boundary temperature plot.
The list can be a list of colors in rgb, hex or css-names format as well.
export_convergence_plots : bool, default: False, optional
If True, plots are displayed again using the `notebook_connected` renderer. This helps
to display the plots in the documentation or in platforms like nbviewer.
Notes
-----
When overriding plot's configuration using the `plasma_plot_config` and the
`t_inner_luminosities_config` dictionaries, data related properties are
applied equally accross all traces.
The dictionary should have a structure like that of `plotly.graph_objs.FigureWidget.to_dict()`,
for more information please see https://plotly.com/python/figure-structure/
"""
def __init__(self, iterations, **kwargs):
self.iterable_data = {}
self.value_data = defaultdict(list)
self.iterations = iterations
self.current_iteration = 1
self.luminosities = ["Emitted", "Absorbed", "Requested"]
self.plasma_plot = None
self.t_inner_luminosities_plot = None
if "plasma_plot_config" in kwargs:
self.plasma_plot_config = kwargs["plasma_plot_config"]
if "t_inner_luminosities_config" in kwargs:
self.t_inner_luminosities_config = kwargs[
"t_inner_luminosities_config"
]
if "plasma_cmap" in kwargs:
self.plasma_colorscale = transition_colors(
name=kwargs["plasma_cmap"], length=self.iterations
)
else:
# default color scale is jet
self.plasma_colorscale = transition_colors(length=self.iterations)
if "t_inner_luminosities_colors" in kwargs:
# use cmap if string
if type(kwargs["t_inner_luminosities_colors"]) == str:
self.t_inner_luminosities_colors = transition_colors(
length=5,
name=kwargs["t_inner_luminosities_colors"],
)
else:
self.t_inner_luminosities_colors = kwargs[
"t_inner_luminosities_colors"
]
else:
# using default plotly colors
self.t_inner_luminosities_colors = [None] * 5
[docs]
def fetch_data(self, name=None, value=None, item_type=None):
"""
Fetch data from the Simulation class.
Parameters
----------
name : string
name of the data
value : string or array
string or an array of quantities
item_type : string
either iterable or value
"""
# trace data for plasma plots is added in iterable data dictionary
if item_type == "iterable":
self.iterable_data[name] = value
# trace data for luminosity plots and inner boundary temperature plot is stored in value_data dictionary
if item_type == "value":
self.value_data[name].append(value)
[docs]
def create_plasma_plot(self):
"""Create an empty plasma plot."""
fig = go.FigureWidget().set_subplots(rows=1, cols=2, shared_xaxes=True)
# empty traces to build figure
fig.add_scatter(row=1, col=1)
fig.add_scatter(row=1, col=2)
# 2 y axes and 2 x axes correspond to the 2 subplots in the plasma plot
fig = fig.update_layout(
xaxis={
"tickformat": "g",
"title": r"$\text{Velocity}~[\text{km}~\text{s}^{-1}]$",
},
xaxis2={
"tickformat": "g",
"title": r"$\text{Velocity}~[\text{km}~\text{s}^{-1}]$",
"matches": "x",
},
yaxis={
"tickformat": "g",
"title": r"$T_{\text{rad}}\ [\text{K}]$",
"nticks": 15,
},
yaxis2={
"tickformat": "g",
"title": r"$W$",
"nticks": 15,
},
height=450,
legend_title_text="Iterations",
legend_traceorder="reversed",
margin=dict(
l=10, r=135, b=25, t=25, pad=0
), # reduce whitespace surrounding the plot and increase right indentation to align with the t_inner and luminosity plot
)
# allow overriding default layout
if hasattr(self, "plasma_plot_config"):
self.override_plot_parameters(fig, self.plasma_plot_config)
self.plasma_plot = fig
[docs]
def create_t_inner_luminosities_plot(self):
"""Create an empty t_inner and luminosity plot."""
fig = go.FigureWidget().set_subplots(
rows=3,
cols=1,
shared_xaxes=True,
vertical_spacing=0.08,
row_heights=[0.25, 0.5, 0.25],
)
# add inner boundary temperature vs iterations plot
fig.add_scatter(
name="Inner<br>Boundary<br>Temperature",
row=1,
col=1,
hovertext="text",
marker_color=self.t_inner_luminosities_colors[0],
mode="lines",
)
# add luminosity vs iterations plot
# has three traces for emitted, requested and absorbed luminosities
for luminosity, line_color in zip(
self.luminosities, self.t_inner_luminosities_colors[1:4]
):
fig.add_scatter(
name=luminosity + "<br>Luminosity",
mode="lines",
row=2,
col=1,
marker_color=line_color,
)
# add residual luminosity vs iterations plot
fig.add_scatter(
name="Residual<br>Luminosity",
row=3,
col=1,
marker_color=self.t_inner_luminosities_colors[4],
mode="lines",
)
# 3 y axes and 3 x axes correspond to the 3 subplots in the t_inner and luminosity convergence plot
fig = fig.update_layout(
xaxis=dict(range=[0, self.iterations + 1], dtick=2),
xaxis2=dict(
matches="x",
range=[0, self.iterations + 1],
dtick=2,
),
xaxis3=dict(
title=r"$\mbox{Iteration Number}$",
dtick=2,
),
yaxis=dict(
title=r"$T_{\text{inner}}\ [\text{K}]$",
automargin=True,
tickformat="g",
exponentformat="e",
nticks=4,
),
yaxis2=dict(
exponentformat="e",
title=r"$\text{Luminosity}~[\text{erg s}^{-1}]$",
title_font_size=13,
automargin=True,
nticks=7,
),
yaxis3=dict(
title=r"$~~\text{Residual}\\\text{Luminosity[%]}$",
title_font_size=12,
automargin=True,
nticks=4,
),
height=630,
hoverlabel_align="right",
margin=dict(
b=25, t=25, pad=0
), # reduces whitespace surrounding the plot
)
# allow overriding default layout
if hasattr(self, "t_inner_luminosities_config"):
self.override_plot_parameters(fig, self.t_inner_luminosities_config)
self.t_inner_luminosities_plot = fig
[docs]
def override_plot_parameters(self, fig, parameters):
"""
Override default plot properties.
Any property inside the data dictionary is however, applied equally across all traces.
This means trace-specific data properties can't be changed using this function.
Parameters
----------
fig : go.FigureWidget
FigureWidget object to be updated
parameters : dict
Dictionary used to update the default plot style.
"""
# because fig.data is a tuple of traces, a property in the data dictionary is applied to all traces
# the fig is a nested dictionary, any property n levels deep is not changed until the value is a not dictionary
# fig["property_1"]["property_2"]...["property_n"] = "value"
for key, value in parameters.items():
if key == "data":
# all traces will have same data property
for trace in list(fig.data):
self.override_plot_parameters(trace, value)
else:
if type(value) == dict:
self.override_plot_parameters(fig[key], value)
else:
fig[key] = value
[docs]
def build(self, display_plot=True):
"""
Create empty convergence plots and display them.
Parameters
----------
display_plot : bool, default: True, optional
Displays empty plots.
"""
self.create_plasma_plot()
self.create_t_inner_luminosities_plot()
if display_plot:
display(
widgets.VBox(
[self.plasma_plot, self.t_inner_luminosities_plot],
)
)
[docs]
def update_plasma_plots(self):
"""Update plasma convergence plots every iteration."""
# convert velocity to km/s
velocity_km_s = (
self.iterable_data["velocity"].to(u.km / u.s).value.tolist()
)
# add luminosity data in hover data in plasma plots
customdata = len(velocity_km_s) * [
"<br>"
+ "Emitted Luminosity: "
+ f'{self.value_data["Emitted"][-1]:.4g}'
+ "<br>"
+ "Requested Luminosity: "
+ f'{self.value_data["Requested"][-1]:.4g}'
+ "<br>"
+ "Absorbed Luminosity: "
+ f'{self.value_data["Absorbed"][-1]:.4g}'
]
# add a radiation temperature vs shell velocity trace to the plasma plot
self.plasma_plot.add_scatter(
x=velocity_km_s,
y=np.append(
self.iterable_data["t_rad"], self.iterable_data["t_rad"][-1:]
),
line_color=self.plasma_colorscale[self.current_iteration - 1],
line_shape="hv",
row=1,
col=1,
name=self.current_iteration,
legendgroup=f"group-{self.current_iteration}",
showlegend=False,
customdata=customdata,
hovertemplate="<b>Y</b>: %{y:.3f} at <b>X</b> = %{x:,.0f}%{customdata}",
)
# add a dilution factor vs shell velocity trace to the plasma plot
self.plasma_plot.add_scatter(
x=velocity_km_s,
y=np.append(self.iterable_data["w"], self.iterable_data["w"][-1:]),
line_color=self.plasma_colorscale[self.current_iteration - 1],
line_shape="hv",
row=1,
col=2,
legendgroup=f"group-{self.current_iteration}",
name=self.current_iteration,
customdata=customdata,
hovertemplate="<b>Y</b>: %{y:.3f} at <b>X</b> = %{x:,.0f}%{customdata}",
)
[docs]
def update_t_inner_luminosities_plot(self):
"""Update the t_inner and luminosity convergence plots every iteration."""
x = list(range(1, self.iterations + 1))
with self.t_inner_luminosities_plot.batch_update():
# traces are updated according to the order they were added
# the first trace is of the inner boundary temperature plot
self.t_inner_luminosities_plot.data[0].x = x
self.t_inner_luminosities_plot.data[0].y = self.value_data[
"t_inner"
]
self.t_inner_luminosities_plot.data[
0
].hovertemplate = "<b>%{y:.3f}</b> at X = %{x:,.0f}<extra>Inner Boundary Temperature</extra>" # trace name in extra tag to avoid new lines in hoverdata
# the next three for emitted, absorbed and requested luminosities
for index, luminosity in zip(range(1, 4), self.luminosities):
self.t_inner_luminosities_plot.data[index].x = x
self.t_inner_luminosities_plot.data[index].y = self.value_data[
luminosity
]
self.t_inner_luminosities_plot.data[index].hovertemplate = (
"<b>%{y:.4g}</b>" + "<br>at X = %{x}<br>"
)
# last is for the residual luminosity
y = [
((emitted - requested) * 100) / requested
for emitted, requested in zip(
self.value_data["Emitted"], self.value_data["Requested"]
)
]
self.t_inner_luminosities_plot.data[4].x = x
self.t_inner_luminosities_plot.data[4].y = y
self.t_inner_luminosities_plot.data[
4
].hovertemplate = "<b>%{y:.2f}%</b> at X = %{x:,.0f}"
[docs]
def update(self, export_convergence_plots=False, last=False):
"""
Update the convergence plots every iteration.
Parameters
----------
export_convergence_plots : bool, default: False, optional
Displays the convergence plots again using plotly's `notebook_connected` renderer.
This helps to display the plots in notebooks when shared on platforms like nbviewer.
Please see https://plotly.com/python/renderers/ for more information.
last : bool, default: False, optional
True if it's last iteration.
"""
if self.iterable_data != {}:
# build only at first iteration
if self.current_iteration == 1:
self.build()
self.update_plasma_plots()
self.update_t_inner_luminosities_plot()
# data property for plasma plots needs to be
# updated after the last iteration because new traces have been added
if hasattr(self, "plasma_plot_config") and last:
if "data" in self.plasma_plot_config:
self.override_plot_parameters(
self.plasma_plot, self.plasma_plot_config
)
self.current_iteration += 1
# the display function expects a Widget, while
# fig.show() returns None, which causes the TraitError.
if export_convergence_plots and (self.plasma_plot is not None):
with suppress(TraitError):
display(
widgets.VBox(
[
self.plasma_plot.show(
renderer="notebook_connected"
),
self.t_inner_luminosities_plot.show(
renderer="notebook_connected"
),
]
)
)