Using Custom Callbacks When Running TARDIS
The function run_tardis allows users to provide a set of callbacks to the simulation. These callbacks are functions that run at the end of each iteration and can do a variety of things, such as printing information about the simulation, storing data in a table, or even changing simulation parameters between iterations. This tutorial shows three examples of callbacks and how they can be used in TARDIS. One important thing to note is that the first argument of every callback must be the Simulation object being run.
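To make this calling convention concrete, here is a minimal, self-contained sketch of an iteration loop that runs callbacks at the end of every iteration. ToySimulation and its methods are hypothetical stand-ins, not TARDIS classes; the sketch only assumes, as described above, that each callback receives the simulation object as its first argument, followed by any extra arguments.

```python
class ToySimulation:
    """Hypothetical stand-in for a simulation object (not the TARDIS Simulation class)."""

    def __init__(self, iterations):
        self.iterations = iterations
        self.iterations_executed = 0
        self.callbacks = []  # list of (callback, extra_args) pairs

    def add_callback(self, callback, *extra_args):
        # Store the callback together with any extra positional arguments
        self.callbacks.append((callback, extra_args))

    def run(self):
        for _ in range(self.iterations):
            # ... one iteration of the actual computation would happen here ...
            self.iterations_executed += 1
            # At the end of the iteration, call every callback with the
            # simulation object first, followed by its extra arguments
            for callback, extra_args in self.callbacks:
                callback(self, *extra_args)


def report(sim):
    print(f"Finished iteration {sim.iterations_executed} of {sim.iterations}")


sim = ToySimulation(iterations=3)
sim.add_callback(report)
sim.run()
```

Running this prints one line per iteration, with the iteration counter already incremented when the callback sees it, which matches the numbering in the TARDIS output further below.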
Our first callback example will compute the (volume-weighted) average radiative temperature in the supernova ejecta (outside of the photosphere) and will print its value:
[1]:
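def average_temp(sim):
    # Radiative temperature and volume of each shell in the ejecta
    t_rads = sim.simulation_state.t_radiative
    volumes = sim.simulation_state.volume
    # Volume-weighted average temperature over all shells
    avg = sum(t_rads*volumes) / sum(volumes)
    print(f"Average temperature for iteration {sim.iterations_executed}: {avg}")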
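The quantity computed by average_temp is the volume-weighted mean \(\bar{T} = \sum_i T_i V_i / \sum_i V_i\) over the shells. One way to check a callback like this without running TARDIS is to call it on a mock object that provides only the attributes it reads. The SimpleNamespace mock and the shell values below are made up for illustration, and they are plain NumPy arrays rather than the unit-carrying quantities TARDIS uses, so no unit is printed. The copy of average_temp also returns its result so it can be checked.

```python
from types import SimpleNamespace

import numpy as np


def average_temp(sim):
    t_rads = sim.simulation_state.t_radiative
    volumes = sim.simulation_state.volume
    avg = sum(t_rads * volumes) / sum(volumes)
    print(f"Average temperature for iteration {sim.iterations_executed}: {avg}")
    return avg  # returned only so the value can be inspected in this sketch


# Hypothetical mock: three shells with made-up temperatures and volumes
mock_sim = SimpleNamespace(
    iterations_executed=1,
    simulation_state=SimpleNamespace(
        t_radiative=np.array([10000.0, 11000.0, 12000.0]),
        volume=np.array([1.0, 2.0, 1.0]),
    ),
)

avg = average_temp(mock_sim)  # prints: Average temperature for iteration 1: 11000.0

# Cross-check against NumPy's built-in weighted average
print(np.average(mock_sim.simulation_state.t_radiative,
                 weights=mock_sim.simulation_state.volume))
```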
Now we give the callback to run_tardis. run_tardis offers the keyword argument simulation_callbacks, which takes a list of lists: each inner list contains a callback followed by any extra arguments you wish to pass to it. For this example our function requires no extra arguments and we only have a single callback, so we give run_tardis a nested list whose only element is [average_temp]:
[2]:
# We filter out warnings in this notebook
import warnings
warnings.filterwarnings('ignore')
from tardis import run_tardis
from tardis.io.atom_data import download_atom_data
# We download the atomic data needed to run the simulation
download_atom_data('kurucz_cd23_chianti_H_He')
# We run the simulation with our callback
sim = run_tardis('tardis_example.yml',
                 simulation_callbacks=[[average_temp]])
Atomic Data kurucz_cd23_chianti_H_He already exists in /home/runner/Downloads/tardis-data/kurucz_cd23_chianti_H_He.h5. Will not download - override with force_download=True.
Average temperature for iteration 1: 10002.345335990905 K
Average temperature for iteration 2: 10636.306724631839 K
Average temperature for iteration 3: 10781.451660994117 K
Average temperature for iteration 4: 10847.470538493748 K
Average temperature for iteration 5: 10852.893010414327 K
Average temperature for iteration 6: 10855.104475524438 K
Average temperature for iteration 7: 10872.983731448345 K
Average temperature for iteration 8: 10900.252453404333 K
Average temperature for iteration 9: 10928.447457469569 K
Average temperature for iteration 10: 10887.139946956373 K
Average temperature for iteration 11: 10808.256295066209 K
Average temperature for iteration 12: 10916.035502125074 K
Average temperature for iteration 13: 10890.679856236407 K
Average temperature for iteration 14: 10839.310459658938 K
Average temperature for iteration 15: 10835.970184787593 K
Average temperature for iteration 16: 10898.815372772322 K
Average temperature for iteration 17: 10993.425063230141 K
Average temperature for iteration 18: 10923.229603253854 K
Average temperature for iteration 19: 10915.386068997337 K
Average temperature for iteration 20: 10915.386068997337 K
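As in the run_tardis call above, each inner list of simulation_callbacks is read as a callback followed by its extra arguments. The sketch below assumes that entries are unpacked along the lines of callback, *args = entry; that is an assumption about TARDIS internals, but it shows why a flat list such as [average_temp] cannot work: a function object cannot be unpacked like a list. unpack_callbacks is a hypothetical helper, and average_temp here is only a placeholder.

```python
def average_temp(sim):
    """Placeholder callback; only its identity matters in this sketch."""


def unpack_callbacks(simulation_callbacks):
    """Hypothetical helper: split each entry into (callback, extra_args)."""
    unpacked = []
    for entry in simulation_callbacks:
        # First item is the callback, any remaining items are its extra arguments
        callback, *extra_args = entry
        unpacked.append((callback, tuple(extra_args)))
    return unpacked


print(unpack_callbacks([[average_temp]]))  # one callback, no extra arguments

try:
    unpack_callbacks([average_temp])  # flat list: the function itself is not iterable
except TypeError as err:
    print(f"TypeError: {err}")
```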
Running Callbacks with Extra Arguments
The callbacks provided to run_tardis can also take extra arguments. As an example, we'll make a callback that appends the number of Monte Carlo packets emitted by the supernova in each iteration to a list, so that we can plot the number of emitted packets per iteration. We will record this for every iteration except the last one, because the last iteration uses more packets than the others. The callback takes the list we want to append to as an argument. We'll pass both this new callback and our original average_temp callback to run_tardis as an example of using multiple callbacks at once:
[3]:
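def append_num_emitted_to_list(sim, lst):
    # Skip the final iteration, which uses more packets than the others
    if sim.iterations_executed < sim.iterations:
        # Count the packets that escaped the ejecta in this iteration
        num_emitted_packets = len(sim.transport.transport_state.emitted_packet_nu)
        lst.append(num_emitted_packets)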
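Two details make this callback work. First, the list passed as lst is the same object on every call, so values appended inside the callback are visible to the caller afterwards. Second, the condition sim.iterations_executed < sim.iterations is false only on the final iteration, so a 20-iteration run records 19 values. The sketch below checks both points with a hypothetical mock in place of a TARDIS Simulation; the packet counts are made up.

```python
from types import SimpleNamespace


def append_num_emitted_to_list(sim, lst):
    if sim.iterations_executed < sim.iterations:
        num_emitted_packets = len(sim.transport.transport_state.emitted_packet_nu)
        lst.append(num_emitted_packets)


def make_mock_sim(iterations, iterations_executed, n_emitted):
    """Hypothetical mock exposing only the attributes the callback reads."""
    transport_state = SimpleNamespace(emitted_packet_nu=[0.0] * n_emitted)
    return SimpleNamespace(
        iterations=iterations,
        iterations_executed=iterations_executed,
        transport=SimpleNamespace(transport_state=transport_state),
    )


counts = []
for i in range(1, 21):  # iterations_executed takes the values 1, 2, ..., 20
    append_num_emitted_to_list(make_mock_sim(20, i, n_emitted=100 + i), counts)

print(len(counts))  # 19: the final iteration (i == 20) is skipped
```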
To add our new callback, we just create another entry in our list of callbacks. Since append_num_emitted_to_list takes an extra argument, we provide that argument in the inner list containing the callback:
[4]:
# Initialize a list to store the number of emitted packets
num_emitted_list = []
# Make our list of callbacks
callbacks = [[average_temp],
             [append_num_emitted_to_list, num_emitted_list]]
# Run the simulation with both of our callbacks
sim = run_tardis('tardis_example.yml',
                 simulation_callbacks=callbacks)
Average temperature for iteration 1: 10002.345335990905 K
Average temperature for iteration 2: 10636.306724631839 K
Average temperature for iteration 3: 10781.451660994117 K
Average temperature for iteration 4: 10847.470538493748 K
Average temperature for iteration 5: 10852.893010414327 K
Average temperature for iteration 6: 10855.104475524438 K
Average temperature for iteration 7: 10872.983731448345 K
Average temperature for iteration 8: 10900.252453404333 K
Average temperature for iteration 9: 10928.447457469569 K
Average temperature for iteration 10: 10887.139946956373 K
Average temperature for iteration 11: 10808.256295066209 K
Average temperature for iteration 12: 10916.035502125074 K
Average temperature for iteration 13: 10890.679856236407 K
Average temperature for iteration 14: 10839.310459658938 K
Average temperature for iteration 15: 10835.970184787593 K
Average temperature for iteration 16: 10898.815372772322 K
Average temperature for iteration 17: 10993.425063230141 K
Average temperature for iteration 18: 10923.229603253854 K
Average temperature for iteration 19: 10915.386068997337 K
Average temperature for iteration 20: 10915.386068997337 K
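The inner list [append_num_emitted_to_list, num_emitted_list] binds the extra argument once, when the callbacks list is built. If you prefer to bind it yourself, Python's functools.partial produces a callable that needs only the simulation object and can sit alone in an inner list. This is a general Python technique rather than something shown in this tutorial, and it assumes run_tardis calls each callback with the simulation object as its first positional argument. record_iteration is a hypothetical stand-in callback.

```python
from functools import partial
from types import SimpleNamespace


def record_iteration(sim, lst):
    """Hypothetical stand-in callback with one extra argument."""
    lst.append(sim.iterations_executed)


seen = []

# Both entries describe the same call: record_iteration(sim, seen)
callbacks_explicit = [[record_iteration, seen]]
callbacks_partial = [[partial(record_iteration, lst=seen)]]

# Invoke each entry once, the way a callback list would be processed
mock_sim = SimpleNamespace(iterations_executed=1)
for entry in callbacks_explicit + callbacks_partial:
    callback, *extra_args = entry
    callback(mock_sim, *extra_args)

print(seen)  # [1, 1]: both forms produced the same call
```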
Now we can look at how many packets are emitted after each iteration:
[5]:
import matplotlib.pyplot as plt
# Generate a list of each iteration number for the x-axis
iterations = list(range(1, len(num_emitted_list)+1))
# Plot the number of emitted packets
plt.plot(iterations, num_emitted_list)
plt.xlabel("Iteration")
plt.ylabel("Number of emitted packets")
[Figure: number of emitted packets (y-axis) vs. iteration (x-axis)]
Using Callbacks to Add New Functionality
Callbacks can also add new functionality to the code. For example, we introduce one final callback, inc_packets, that increases the number of packets used in the following iteration by a number \(N\), which is passed to the callback as an argument (in our example we use \(N=1000\)):
[6]:
def inc_packets(sim, N):
    # Increase the number of packets used in the next iteration by N
    sim.no_of_packets += N
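Because the callback runs at the end of every iteration, the increments accumulate: with a starting packet count \(P_0\), iteration \(k\) runs \(P_0 + (k-1)N\) packets (the last iteration aside, since it uses its own, larger packet count as noted earlier). The mock below tracks only the packet count to illustrate this; the starting value of 40000 is arbitrary and not taken from tardis_example.yml.

```python
from types import SimpleNamespace


def inc_packets(sim, N):
    sim.no_of_packets += N


# Hypothetical mock holding only the packet count (starting value is arbitrary)
mock_sim = SimpleNamespace(no_of_packets=40000)

packets_per_iteration = []
for k in range(1, 6):
    packets_per_iteration.append(mock_sim.no_of_packets)  # packets used in iteration k
    inc_packets(mock_sim, 1000)  # callback runs at the end of iteration k

print(packets_per_iteration)  # [40000, 41000, 42000, 43000, 44000]
```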
[7]:
# Initialize a new list to store the number of emitted packets
num_emitted_list_new = []
# Make our new list of callbacks
callbacks = [[average_temp],
             [append_num_emitted_to_list, num_emitted_list_new],
             [inc_packets, 1000]]
# Run the simulation with all three of our callbacks
sim = run_tardis('tardis_example.yml',
                 simulation_callbacks=callbacks)
Average temperature for iteration 1: 10002.345335990905 K
Average temperature for iteration 2: 10613.908754846558 K
Average temperature for iteration 3: 10850.348012607232 K
Average temperature for iteration 4: 10887.571829540853 K
Average temperature for iteration 5: 10830.30318618678 K
Average temperature for iteration 6: 10858.234351239738 K
Average temperature for iteration 7: 10792.481079535482 K
Average temperature for iteration 8: 10838.693927807315 K
Average temperature for iteration 9: 10947.325467890487 K
Average temperature for iteration 10: 10904.884829767207 K
Average temperature for iteration 11: 10909.70977691611 K
Average temperature for iteration 12: 10936.034767331808 K
Average temperature for iteration 13: 10914.993081485425 K
Average temperature for iteration 14: 10926.52320767812 K
Average temperature for iteration 15: 10919.420818412638 K
Average temperature for iteration 16: 10898.993427208812 K
Average temperature for iteration 17: 10868.841791849643 K
Average temperature for iteration 18: 10868.404619085652 K
Average temperature for iteration 19: 10935.815023903202 K
Average temperature for iteration 20: 10935.815023903202 K
Now, let’s see how this affected our plot for packets emitted in each iteration:
[8]:
plt.plot(iterations, num_emitted_list_new)
plt.xlabel("Iteration")
plt.ylabel("Number of emitted packets")
[Figure: number of emitted packets (y-axis) vs. iteration (x-axis), with inc_packets active]
As expected, the number of emitted packets keeps increasing from one iteration to the next, since 1000 more packets are run in each successive iteration.