import logging
import numpy as np
import pandas as pd
from astropy import units as u
from tardis import constants as const
from tardis.io.model.parse_atom_data import parse_atom_data
from tardis.model import SimulationState
from tardis.opacities.macro_atom.macroatom_solver import MacroAtomSolver
from tardis.opacities.opacity_solver import OpacitySolver
from tardis.plasma.assembly import PlasmaSolverFactory
from tardis.plasma.radiation_field import DilutePlanckianRadiationField
from tardis.simulation.convergence import ConvergenceSolver
from tardis.spectrum.base import SpectrumSolver
from tardis.spectrum.formal_integral import FormalIntegrator
from tardis.spectrum.luminosity import (
calculate_filtered_luminosity,
)
from tardis.transport.montecarlo.base import MonteCarloTransportSolver
from tardis.util.base import is_notebook
from tardis.workflows.workflow_logging import WorkflowLogging
# logging support
logger = logging.getLogger(__name__)
[docs]
class SimpleTARDISWorkflow(WorkflowLogging):
show_progress_bars = is_notebook()
enable_virtual_packet_logging = False
log_level = None
specific_log_level = None
def __init__(self, configuration, csvy=False):
"""A simple TARDIS workflow that runs a simulation to convergence
Parameters
----------
configuration : Configuration
Configuration object for the simulation
csvy : bool, optional
Set true if the configuration uses CSVY, by default False
"""
super().__init__(configuration, self.log_level, self.specific_log_level)
atom_data = parse_atom_data(configuration)
# set up states and solvers
if csvy:
self.simulation_state = SimulationState.from_csvy(
configuration, atom_data=atom_data
)
else:
self.simulation_state = SimulationState.from_config(
configuration,
atom_data=atom_data,
)
plasma_solver_factory = PlasmaSolverFactory(
atom_data,
configuration,
)
plasma_solver_factory.prepare_factory(
self.simulation_state.abundance.index,
"tardis.plasma.properties.property_collections",
configuration,
)
self.plasma_solver = plasma_solver_factory.assemble(
self.simulation_state.elemental_number_density,
self.simulation_state.radiation_field_state,
self.simulation_state.time_explosion,
self.simulation_state._electron_densities,
)
line_interaction_type = configuration.plasma.line_interaction_type
self.opacity_solver = OpacitySolver(
line_interaction_type,
configuration.plasma.disable_line_scattering,
)
if line_interaction_type == "scatter":
self.macro_atom_solver = None
else:
self.macro_atom_solver = MacroAtomSolver()
self.transport_state = None
self.transport_solver = MonteCarloTransportSolver.from_config(
configuration,
packet_source=self.simulation_state.packet_source,
enable_virtual_packet_logging=self.enable_virtual_packet_logging,
)
# Luminosity filter frequencies
self.luminosity_nu_start = (
configuration.supernova.luminosity_wavelength_end.to(
u.Hz, u.spectral()
)
)
if u.isclose(
configuration.supernova.luminosity_wavelength_start, 0 * u.angstrom
):
self.luminosity_nu_end = np.inf * u.Hz
else:
self.luminosity_nu_end = (
const.c / configuration.supernova.luminosity_wavelength_start
).to(u.Hz)
# montecarlo settings
self.total_iterations = int(configuration.montecarlo.iterations)
self.real_packet_count = int(configuration.montecarlo.no_of_packets)
final_iteration_packet_count = (
configuration.montecarlo.last_no_of_packets
)
if (
final_iteration_packet_count is None
or final_iteration_packet_count < 0
):
final_iteration_packet_count = self.real_packet_count
self.final_iteration_packet_count = int(final_iteration_packet_count)
self.virtual_packet_count = int(
configuration.montecarlo.no_of_virtual_packets
)
# spectrum settings
self.integrated_spectrum_settings = configuration.spectrum.integrated
self.spectrum_solver = SpectrumSolver.from_config(configuration)
# Convergence settings
self.consecutive_converges_count = 0
self.converged = False
self.completed_iterations = 0
self.luminosity_requested = (
configuration.supernova.luminosity_requested.cgs
)
# Convergence solvers
self.convergence_strategy = (
configuration.montecarlo.convergence_strategy
)
self.convergence_solvers = {}
self.convergence_solvers["t_radiative"] = ConvergenceSolver(
self.convergence_strategy.t_rad
)
self.convergence_solvers["dilution_factor"] = ConvergenceSolver(
self.convergence_strategy.w
)
self.convergence_solvers["t_inner"] = ConvergenceSolver(
self.convergence_strategy.t_inner
)
[docs]
def get_convergence_estimates(self):
"""Compute convergence estimates from the transport state
Returns
-------
dict
Convergence estimates
EstimatedRadiationFieldProperties
Dilute radiation file and j_blues dataclass
"""
estimated_radfield_properties = (
self.transport_solver.radfield_prop_solver.solve(
self.transport_state.radfield_mc_estimators,
self.transport_state.time_explosion,
self.transport_state.time_of_simulation,
self.transport_state.geometry_state.volume,
self.transport_state.opacity_state.line_list_nu,
)
)
estimated_t_radiative = estimated_radfield_properties.dilute_blackbody_radiationfield_state.temperature
estimated_dilution_factor = estimated_radfield_properties.dilute_blackbody_radiationfield_state.dilution_factor
emitted_luminosity = calculate_filtered_luminosity(
self.transport_state.emitted_packet_nu,
self.transport_state.emitted_packet_luminosity,
self.luminosity_nu_start,
self.luminosity_nu_end,
)
luminosity_ratios = (
(emitted_luminosity / self.luminosity_requested).to(1).value
)
estimated_t_inner = (
self.simulation_state.t_inner
* luminosity_ratios
** self.convergence_strategy.t_inner_update_exponent
)
return {
"t_radiative": estimated_t_radiative,
"dilution_factor": estimated_dilution_factor,
"t_inner": estimated_t_inner,
}, estimated_radfield_properties
[docs]
def check_convergence(
self,
estimated_values,
):
"""Check convergence status for a dict of estimated values
Parameters
----------
estimated_values : dict
Estimates to check convergence
Returns
-------
bool
If convergence has occurred
"""
convergence_statuses = []
for key, solver in self.convergence_solvers.items():
current_value = getattr(self.simulation_state, key)
estimated_value = estimated_values[key]
no_of_shells = (
self.simulation_state.no_of_shells if key != "t_inner" else 1
)
convergence_statuses.append(
solver.get_convergence_status(
current_value, estimated_value, no_of_shells
)
)
if np.all(convergence_statuses):
hold_iterations = self.convergence_strategy.hold_iterations
self.consecutive_converges_count += 1
logger.info(
f"Iteration converged {self.consecutive_converges_count:d}/{(hold_iterations + 1):d} consecutive "
f"times."
)
return self.consecutive_converges_count >= hold_iterations + 1
self.consecutive_converges_count = 0
return False
[docs]
def solve_simulation_state(
self,
estimated_values,
):
"""Update the simulation state with new inputs computed from previous
iteration estimates.
Parameters
----------
estimated_values : dict
Estimated from the previous iterations
"""
next_values = {}
for key, solver in self.convergence_solvers.items():
if (
key == "t_inner"
and (self.completed_iterations + 1)
% self.convergence_strategy.lock_t_inner_cycles
!= 0
):
next_values[key] = getattr(self.simulation_state, key)
else:
next_values[key] = solver.converge(
getattr(self.simulation_state, key), estimated_values[key]
)
self.simulation_state.t_radiative = next_values["t_radiative"]
self.simulation_state.dilution_factor = next_values["dilution_factor"]
self.simulation_state.blackbody_packet_source.temperature = next_values[
"t_inner"
]
return next_values
[docs]
def solve_plasma(self, estimated_radfield_properties):
"""Update the plasma solution with the new radiation field estimates
Parameters
----------
estimated_radfield_properties : EstimatedRadiationFieldProperties
The radiation field properties to use for updating the plasma
Raises
------
ValueError
If the plasma solver radiative rates type is unknown
"""
radiation_field = DilutePlanckianRadiationField(
temperature=self.simulation_state.t_radiative,
dilution_factor=self.simulation_state.dilution_factor,
)
update_properties = dict(
dilute_planckian_radiation_field=radiation_field
)
# A check to see if the plasma is set with JBluesDetailed, in which
# case it needs some extra kwargs.
if (
self.plasma_solver.plasma_solver_settings.RADIATIVE_RATES_TYPE
== "blackbody"
):
planckian_radiation_field = (
radiation_field.to_planckian_radiation_field()
)
j_blues = planckian_radiation_field.calculate_mean_intensity(
self.plasma_solver.atomic_data.lines.nu.values
)
update_properties["j_blues"] = pd.DataFrame(
j_blues, index=self.plasma_solver.atomic_data.lines.index
)
elif (
self.plasma_solver.plasma_solver_settings.RADIATIVE_RATES_TYPE
== "dilute-blackbody"
):
j_blues = radiation_field.calculate_mean_intensity(
self.plasma_solver.atomic_data.lines.nu.values
)
update_properties["j_blues"] = pd.DataFrame(
j_blues, index=self.plasma_solver.atomic_data.lines.index
)
elif (
self.plasma_solver.plasma_solver_settings.RADIATIVE_RATES_TYPE
== "detailed"
):
update_properties["j_blues"] = pd.DataFrame(
estimated_radfield_properties.j_blues,
index=self.plasma_solver.atomic_data.lines.index,
)
else:
raise ValueError(
f"radiative_rates_type type unknown - {self.plasma.plasma_solver_settings.RADIATIVE_RATES_TYPE}"
)
self.plasma_solver.update(**update_properties)
[docs]
def solve_opacity(self):
"""Solves the opacity state and any associated objects
Returns
-------
dict
opacity_state : tardis.opacities.opacity_state.OpacityState
State of the line opacities
macro_atom_state : tardis.opacities.macro_atom.macro_atom_state.MacroAtomState or None
State of the macro atom
"""
opacity_state = self.opacity_solver.solve(self.plasma_solver)
if self.macro_atom_solver is None:
macro_atom_state = None
else:
macro_atom_state = self.macro_atom_solver.solve(
self.plasma_solver.j_blues,
self.plasma_solver.atomic_data,
opacity_state.tau_sobolev,
self.plasma_solver.stimulated_emission_factor,
opacity_state.beta_sobolev,
)
return {
"opacity_state": opacity_state,
"macro_atom_state": macro_atom_state,
}
[docs]
def solve_montecarlo(
self, opacity_states, no_of_real_packets, no_of_virtual_packets=0
):
"""Solve the MonteCarlo process
Parameters
----------
opacity_states : dict
Opacity and (optionally) Macro Atom states.
no_of_real_packets : int
Number of real packets to simulate
no_of_virtual_packets : int, optional
Number of virtual packets to simulate per interaction, by default 0
Returns
-------
MonteCarloTransportState
The new transport state after simulation
ndarray
Array of unnormalized virtual packet energies in each frequency bin
"""
opacity_state = opacity_states["opacity_state"]
macro_atom_state = opacity_states["macro_atom_state"]
self.transport_state = self.transport_solver.initialize_transport_state(
self.simulation_state,
opacity_state,
macro_atom_state,
self.plasma_solver,
no_of_real_packets,
no_of_virtual_packets=no_of_virtual_packets,
iteration=self.completed_iterations,
)
virtual_packet_energies = self.transport_solver.run(
self.transport_state,
iteration=self.completed_iterations,
total_iterations=self.total_iterations,
show_progress_bars=self.show_progress_bars,
)
output_energy = self.transport_state.packet_collection.output_energies
if np.sum(output_energy < 0) == len(output_energy):
logger.critical("No r-packet escaped through the outer boundary.")
return virtual_packet_energies
[docs]
def initialize_spectrum_solver(
self,
opacity_states,
virtual_packet_energies=None,
):
"""Set up the spectrum solver
Parameters
----------
virtual_packet_energies : ndarray, optional
Array of virtual packet energies binned by frequency, by default None
"""
# Set up spectrum solver
self.spectrum_solver.transport_state = self.transport_state
if virtual_packet_energies is not None:
self.spectrum_solver._montecarlo_virtual_luminosity.value[:] = (
virtual_packet_energies
)
if self.integrated_spectrum_settings is not None:
# Set up spectrum solver integrator
self.spectrum_solver.integrator_settings = (
self.integrated_spectrum_settings
)
self.spectrum_solver._integrator = FormalIntegrator(
self.simulation_state,
self.plasma_solver,
self.transport_solver,
opacity_states["opacity_state"],
opacity_states["macro_atom_state"],
)
[docs]
def run(self):
"""Run the TARDIS simulation until convergence is reached"""
self.converged = False
while self.completed_iterations < self.total_iterations - 1:
logger.info(
f"\n\tStarting iteration {(self.completed_iterations + 1):d} of {self.total_iterations:d}"
)
opacity_states = self.solve_opacity()
virtual_packet_energies = self.solve_montecarlo(
opacity_states, self.real_packet_count
)
(
estimated_values,
estimated_radfield_properties,
) = self.get_convergence_estimates()
self.solve_simulation_state(estimated_values)
self.solve_plasma(estimated_radfield_properties)
self.converged = self.check_convergence(estimated_values)
self.completed_iterations += 1
if self.converged and self.convergence_strategy.stop_if_converged:
break
if self.converged:
logger.info("\n\tStarting final iteration")
else:
logger.error(
"\n\tITERATIONS HAVE NOT CONVERGED, starting final iteration"
)
virtual_packet_energies = self.solve_montecarlo(
opacity_states,
self.final_iteration_packet_count,
self.virtual_packet_count,
)
self.initialize_spectrum_solver(
opacity_states,
virtual_packet_energies,
)