Source code for tardis.transport.montecarlo.modes.classic.solver

import logging

from astropy import units as u
from numba import cuda, set_num_threads

import tardis.transport.montecarlo.configuration.constants as constants
from tardis import constants as const
from tardis.io.hdf_writer_mixin import HDFWriterMixin
from tardis.io.logger import montecarlo_tracking as mc_tracker
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo.configuration.base import (
    MonteCarloConfiguration,
    configuration_initialize,
)
from tardis.transport.montecarlo.estimators.mc_rad_field_solver import (
    MCRadiationFieldPropertiesSolver,
)
from tardis.transport.montecarlo.modes.classic.montecarlo_transport import (
    montecarlo_transport,
)
from tardis.transport.montecarlo.montecarlo_transport_state import (
    MonteCarloTransportState,
)
from tardis.transport.montecarlo.packets.trackers.tracker_full_util import (
    generate_tracker_full_list,
    tracker_full_df2tracker_last_interaction_df,
    trackers_full_to_df,
)
from tardis.transport.montecarlo.packets.trackers.tracker_last_interaction_util import (
    generate_tracker_last_interaction_list,
    trackers_last_interaction_to_df,
)
from tardis.transport.montecarlo.progress_bars import (
    refresh_packet_pbar,
    reset_packet_pbar,
    update_iterations_pbar,
)
from tardis.util.base import (
    quantity_linspace,
)

logger = logging.getLogger(__name__)


# TODO: refactor this into more parts
class MCTransportSolverClassic(HDFWriterMixin):
    """
    Monte Carlo transport solver for the classic (no-continuum) mode.

    The solver builds a ``MonteCarloTransportState`` for an iteration
    (:meth:`initialize_transport_state`) and runs the classic Monte Carlo
    transport kernel on it (:meth:`run` / :meth:`run_classic`). The resulting
    estimators and packet trackers are stored on that transport state, which
    is modified in place.
    """

    # Presumably consumed by HDFWriterMixin when serializing this object.
    hdf_properties = ["transport_state"]
    hdf_name = "transport"

    def __init__(
        self,
        radfield_prop_solver,
        spectrum_frequency_grid,
        virtual_spectrum_spawn_range,
        enable_full_relativity,
        line_interaction_type,
        spectrum_method,
        packet_source,
        enable_virtual_packet_logging=False,
        enable_rpacket_tracking=False,
        nthreads=1,
        debug_packets=False,
        logger_buffer=1,
        use_gpu=False,
        montecarlo_configuration=None,
    ):
        """
        Parameters
        ----------
        radfield_prop_solver : MCRadiationFieldPropertiesSolver
            Stored on the instance; not used by the methods of this class.
        spectrum_frequency_grid : astropy.units.Quantity
            Spectrum frequency grid. Its ``.value`` is passed to the transport
            kernel in :meth:`run_classic`.
        virtual_spectrum_spawn_range
            Stored on the instance. NOTE(review): presumably read by
            ``configuration_initialize`` — confirm.
        enable_full_relativity : bool
            Stored on the instance. NOTE(review): presumably copied into the
            Monte Carlo configuration by ``configuration_initialize``.
        line_interaction_type : str
            Passed to ``opacity_state.to_numba``.
        spectrum_method : str
            Stored on the instance.
        packet_source
            Object whose ``create_packets(no_of_packets, seed_offset=...)``
            creates the packets for each iteration.
        enable_virtual_packet_logging : bool, optional
            Stored as ``enable_vpacket_tracking``.
        enable_rpacket_tracking : bool, optional
            If True, full r-packet trackers are used; otherwise only
            last-interaction trackers are used.
        nthreads : int, optional
            Number of numba threads set in :meth:`run_classic`.
        debug_packets : bool, optional
            Assigned to the module-level ``montecarlo_tracking.DEBUG_MODE``.
        logger_buffer : int, optional
            Assigned to the module-level ``montecarlo_tracking.BUFFER``.
        use_gpu : bool, optional
            Stored on the instance.
        montecarlo_configuration : MonteCarloConfiguration, optional
            Configuration shared with the transport kernel. If None, a new
            default ``MonteCarloConfiguration()`` is created.
        """
        self.radfield_prop_solver = radfield_prop_solver
        self.spectrum_frequency_grid = spectrum_frequency_grid
        self.virtual_spectrum_spawn_range = virtual_spectrum_spawn_range
        self.enable_full_relativity = enable_full_relativity
        self.line_interaction_type = line_interaction_type
        self.spectrum_method = spectrum_method
        self.use_gpu = use_gpu

        self.enable_vpacket_tracking = enable_virtual_packet_logging
        self.enable_rpacket_tracking = enable_rpacket_tracking

        # Every other method reads attributes of the configuration, so a
        # ``None`` here would only fail later with an AttributeError.
        if montecarlo_configuration is None:
            montecarlo_configuration = MonteCarloConfiguration()
        self.montecarlo_configuration = montecarlo_configuration

        # Source used to inject (create) the packets for each iteration.
        self.packet_source = packet_source

        # Placeholder for storing RPacketTracker instances; not referenced
        # elsewhere in this class, kept for external callers.
        self.rpacket_tracker = None

        # Number of numba threads used by run_classic.
        self.nthreads = nthreads

        # Configure the packet-tracking logger. These are module-level
        # settings, so they affect every solver instance.
        mc_tracker.DEBUG_MODE = debug_packets
        mc_tracker.BUFFER = logger_buffer

    def initialize_transport_state(
        self,
        simulation_state,
        opacity_state,
        macro_atom_state,
        plasma,
        no_of_packets,
        no_of_virtual_packets=0,
        iteration=0,
    ):
        """
        Build the Monte Carlo transport state for one iteration.

        As side effects, this sets the module-level
        ``montecarlo_globals.CONTINUUM_PROCESSES_ENABLED`` to False and
        (re)initializes ``self.montecarlo_configuration`` via
        ``configuration_initialize``.

        Parameters
        ----------
        simulation_state : SimulationState
            Provides ``geometry`` (and its inner/outer boundary indices) and
            ``time_explosion``.
        opacity_state : OpacityState
            Converted with ``to_numba`` and restricted to the cells between
            the geometry's inner and outer velocity boundary indices.
        macro_atom_state
            Passed through to ``opacity_state.to_numba``.
        plasma : BasePlasma
            Only used to size the bound-free arrays
            (``n_levels_bf_species_by_n_cells_tuple``).
        no_of_packets : int
            Number of real packets to create.
        no_of_virtual_packets : int, optional
            Passed to ``configuration_initialize``.
        iteration : int, optional
            Passed as ``seed_offset`` to ``packet_source.create_packets``.

        Returns
        -------
        MonteCarloTransportState
        """
        if plasma.continuum_interaction_species.empty:
            n_levels_bf_species_by_n_cells_tuple = (0, 0)
        elif plasma.gamma is not None:
            n_levels_bf_species_by_n_cells_tuple = plasma.gamma.shape
        else:
            n_levels_bf_species_by_n_cells_tuple = plasma.phi_lucy.shape

        packet_collection = self.packet_source.create_packets(
            no_of_packets, seed_offset=iteration
        )

        # Classic mode: continuum processes disabled.
        montecarlo_globals.CONTINUUM_PROCESSES_ENABLED = False

        geometry = simulation_state.geometry
        geometry_state = geometry.to_numba()
        opacity_state_numba = opacity_state.to_numba(
            macro_atom_state,
            self.line_interaction_type,
        )
        opacity_state_numba = opacity_state_numba[
            geometry.v_inner_boundary_index : geometry.v_outer_boundary_index
        ]

        transport_state = MonteCarloTransportState(
            packet_collection,
            geometry_state=geometry_state,
            opacity_state=opacity_state_numba,
            time_explosion=simulation_state.time_explosion,
            n_levels_bf_species_by_n_cells_tuple=n_levels_bf_species_by_n_cells_tuple,
        )

        configuration_initialize(
            self.montecarlo_configuration, self, no_of_virtual_packets
        )
        # Read the flag only after configuration_initialize has refreshed the
        # configuration from this solver. Reading it beforehand could yield a
        # stale/default value on the first iteration.
        # NOTE(review): assumes configuration_initialize sets
        # ENABLE_FULL_RELATIVITY from ``self.enable_full_relativity``.
        transport_state.enable_full_relativity = (
            self.montecarlo_configuration.ENABLE_FULL_RELATIVITY
        )
        return transport_state

    def run(
        self,
        transport_state,
        show_progress_bars=True,
    ):
        """
        Run the montecarlo calculation.

        Delegates to :meth:`run_classic`.

        Parameters
        ----------
        transport_state : tardis.transport.montecarlo.montecarlo_transport_state.MonteCarloTransportState
            Transport state containing all the data needed for the Monte Carlo simulation
        show_progress_bars : bool
            Show progress bars

        Returns
        -------
        v_packets_energy_hist : ndarray
            Histogram of energy from virtual packets
        """
        return self.run_classic(transport_state, show_progress_bars)

    def run_classic(
        self,
        transport_state,
        show_progress_bars=True,
    ):
        """
        Run the montecarlo calculation using classic mode (no continuum).

        The given transport state is modified in place. The method stores on
        it the bulk and line estimators, the virtual-packet tracker (only when
        v-packet tracking is enabled and the number of v-packets is positive),
        the tracker DataFrames (``tracker_full_df``,
        ``tracker_last_interaction_df``) and ``virt_logging``.

        Parameters
        ----------
        transport_state : tardis.transport.montecarlo.montecarlo_transport_state.MonteCarloTransportState
            Transport state containing all the data needed for the Monte Carlo simulation
        show_progress_bars : bool
            Show progress bars

        Returns
        -------
        v_packets_energy_hist : ndarray
            Histogram of energy from virtual packets
        """
        set_num_threads(self.nthreads)
        self.transport_state = transport_state

        config = self.montecarlo_configuration
        number_of_vpackets = config.NUMBER_OF_VPACKETS
        number_of_rpackets = len(transport_state.packet_collection.initial_nus)

        # Full trackers when r-packet tracking is enabled; otherwise only the
        # last-interaction trackers, populated directly by the kernel.
        if self.enable_rpacket_tracking:
            trackers_list = generate_tracker_full_list(
                number_of_rpackets,
                config.INITIAL_TRACKING_ARRAY_LENGTH,
            )
        else:
            trackers_list = generate_tracker_last_interaction_list(
                number_of_rpackets
            )

        # Reset packet progress bar for this iteration.
        if show_progress_bars:
            reset_packet_pbar(number_of_rpackets)

        # Classic mode: the kernel returns 4 values (no continuum estimators).
        (
            v_packets_energy_hist,
            vpacket_tracker,
            estimators_bulk,
            estimators_line,
        ) = montecarlo_transport(
            transport_state.packet_collection,
            transport_state.geometry_state,
            transport_state.time_explosion.cgs.value,
            transport_state.opacity_state,
            config,
            self.spectrum_frequency_grid.value,
            trackers_list,
            number_of_vpackets,
            show_progress_bars=show_progress_bars,
        )

        transport_state.estimators_bulk = estimators_bulk
        transport_state.estimators_line = estimators_line

        if config.ENABLE_VPACKET_TRACKING and number_of_vpackets > 0:
            transport_state.vpacket_tracker = vpacket_tracker

        update_iterations_pbar(1)
        refresh_packet_pbar()

        # Convert the trackers to DataFrames. With full tracking, the
        # last-interaction DataFrame is derived from the full one.
        if self.enable_rpacket_tracking:
            transport_state.tracker_full_df = trackers_full_to_df(trackers_list)
            transport_state.tracker_last_interaction_df = (
                tracker_full_df2tracker_last_interaction_df(
                    transport_state.tracker_full_df
                )
            )
        else:
            transport_state.tracker_full_df = None
            transport_state.tracker_last_interaction_df = (
                trackers_last_interaction_to_df(trackers_list)
            )

        transport_state.virt_logging = config.ENABLE_VPACKET_TRACKING
        return v_packets_energy_hist

    @classmethod
    def from_config(
        cls, config, packet_source, enable_virtual_packet_logging=False
    ):
        """
        Create a new solver instance from a Configuration object.

        As a side effect, this sets the module-level
        ``constants.SIGMA_THOMSON``. It is effectively zero (1e-200) when
        electron scattering is disabled, and the Thomson cross-section in
        cm^2 otherwise.

        Parameters
        ----------
        config : tardis.io.config_reader.Configuration
        packet_source
            Packet source passed through to the constructor.
        enable_virtual_packet_logging : bool, optional
            Enable virtual-packet logging even if
            ``config.spectrum.virtual.virtual_packet_logging`` is off.

        Returns
        -------
        MCTransportSolverClassic
            An instance of ``cls``.

        Raises
        ------
        ValueError
            If ``config.spectrum.integrated.compute`` is not 'GPU', 'CPU' or
            'Automatic' (case-insensitive), or if 'GPU' is requested but no
            CUDA GPU is available.
        """
        if config.plasma.disable_electron_scattering:
            logger.warning(
                "Disabling electron scattering - this is not physical. "
                "Likely bug in formal integral - "
                "will not give same results."
            )
            constants.SIGMA_THOMSON = 1e-200
        else:
            logger.debug("Electron scattering switched on")
            constants.SIGMA_THOMSON = const.sigma_T.to("cm^2").value

        # num + 1 points, running from the frequency of spectrum.stop to that
        # of spectrum.start.
        spectrum_frequency_grid = quantity_linspace(
            config.spectrum.stop.to("Hz", u.spectral()),
            config.spectrum.start.to("Hz", u.spectral()),
            num=config.spectrum.num + 1,
        )

        compute = config.spectrum.integrated.compute
        running_mode = compute.upper()
        if running_mode == "GPU":
            if not cuda.is_available():
                raise ValueError(
                    "The GPU option was selected for the formal_integral, "
                    "but no CUDA GPU is available."
                )
            use_gpu = True
        elif running_mode == "AUTOMATIC":
            use_gpu = bool(cuda.is_available())
        elif running_mode == "CPU":
            use_gpu = False
        else:
            raise ValueError(
                f"An invalid option for compute was passed ({compute!r}). "
                "The three valid values are 'GPU', 'CPU', and 'Automatic'."
            )

        montecarlo_configuration = MonteCarloConfiguration()
        montecarlo_configuration.DISABLE_LINE_SCATTERING = (
            config.plasma.disable_line_scattering
        )
        montecarlo_configuration.DISABLE_ELECTRON_SCATTERING = (
            config.plasma.disable_electron_scattering
        )
        montecarlo_configuration.INITIAL_TRACKING_ARRAY_LENGTH = (
            config.montecarlo.tracking.initial_array_length
        )

        radfield_prop_solver = MCRadiationFieldPropertiesSolver(
            config.plasma.w_epsilon
        )

        return cls(
            radfield_prop_solver=radfield_prop_solver,
            spectrum_frequency_grid=spectrum_frequency_grid,
            virtual_spectrum_spawn_range=config.montecarlo.virtual_spectrum_spawn_range,
            enable_full_relativity=config.montecarlo.enable_full_relativity,
            line_interaction_type=config.plasma.line_interaction_type,
            spectrum_method=config.spectrum.method,
            packet_source=packet_source,
            debug_packets=config.montecarlo.debug_packets,
            logger_buffer=config.montecarlo.logger_buffer,
            enable_virtual_packet_logging=bool(
                config.spectrum.virtual.virtual_packet_logging
                or enable_virtual_packet_logging
            ),
            enable_rpacket_tracking=config.montecarlo.tracking.track_rpacket,
            nthreads=config.montecarlo.nthreads,
            use_gpu=use_gpu,
            montecarlo_configuration=montecarlo_configuration,
        )