Source code for tardis.transport.montecarlo.modes.classic.montecarlo_transport

"""Classic Monte Carlo transport - line-only mode without continuum processes."""

import numpy as np
from numba import njit, objmode, prange
from numba.np.ufunc.parallel import get_num_threads, get_thread_id
from numba.typed import List

from tardis.model.geometry.radial1d import NumbaRadial1DGeometry
from tardis.opacities.opacity_state_numba import OpacityStateNumba
from tardis.transport.montecarlo import njit_dict
from tardis.transport.montecarlo.configuration.base import (
    MonteCarloConfiguration,
)
from tardis.transport.montecarlo.estimators.estimators_bulk import (
    create_estimators_bulk_list,
    init_estimators_bulk,
)
from tardis.transport.montecarlo.estimators.estimators_line import (
    create_estimators_line_list,
    init_estimators_line,
)
from tardis.transport.montecarlo.modes.classic.packet_propagation import (
    packet_propagation,
)
from tardis.transport.montecarlo.packets.packet_collections import (
    PacketCollection,
    VPacketCollection,
    consolidate_vpacket_tracker,
)
from tardis.transport.montecarlo.packets.radiative_packet import (
    PacketStatus,
    RPacket,
)
from tardis.transport.montecarlo.progress_bars import update_packets_pbar


@njit(**njit_dict)
def montecarlo_transport(
    packet_collection: PacketCollection,
    geometry_state_numba: NumbaRadial1DGeometry,
    time_explosion: float,
    opacity_state_numba: OpacityStateNumba,
    montecarlo_configuration: MonteCarloConfiguration,
    spectrum_frequency_grid: np.ndarray,
    trackers: List,
    number_of_vpackets: int,
    show_progress_bars: bool,
) -> tuple[
    np.ndarray,
    VPacketCollection,
    type,
    type,
]:
    """
    Main loop of the Monte Carlo radiative transfer routine for classic mode.

    Classic mode implements line-only transport without continuum processes.

    Parameters
    ----------
    packet_collection : PacketCollection
        Collection containing initial packet properties (positions, directions,
        frequencies, energies, and seeds). ``output_nus`` and
        ``output_energies`` are written in place.
    geometry_state_numba : NumbaRadial1DGeometry
        Numba-compiled simulation geometry containing shell boundaries and
        velocity information
    time_explosion : float
        Time since explosion in seconds, used for relativistic calculations
    opacity_state_numba : OpacityStateNumba
        Numba-compiled opacity state containing line opacities and atomic data
        required for interactions
    montecarlo_configuration : MonteCarloConfiguration
        Configuration object containing Monte Carlo simulation parameters and
        flags for various physics modules
    spectrum_frequency_grid : np.ndarray
        Frequency grid array for virtual packet spectrum calculation. The bin
        width is taken from the first two points, so the grid is assumed to be
        uniformly spaced.
    trackers : List
        List of packet trackers for detailed packet interaction logging, one
        per packet (indexed by packet index)
    number_of_vpackets : int
        Number of virtual packets to spawn per real packet interaction
    show_progress_bars : bool
        Flag to enable/disable progress bar updates during simulation

    Returns
    -------
    tuple[np.ndarray, VPacketCollection, type, type]
        A tuple containing:

        - v_packets_energy_hist : Energy histogram of virtual packets binned by
          frequency
        - vpacket_tracker : Consolidated virtual packet collection (a dummy
          collection when vpacket tracking is disabled)
        - estimators_bulk : Updated bulk radiation field estimator object
        - estimators_line : Updated line radiation field estimator object

    Notes
    -----
    All per-thread accumulators (bulk estimators, line estimators, and the
    virtual-packet energy histogram) are indexed by ``get_thread_id()`` and
    reduced after the parallel loop. Accumulating into a single shared array
    from inside ``prange`` would race whenever two threads update the same
    element concurrently.
    """
    no_of_packets = len(packet_collection.initial_nus)

    # Histogram binning parameters. delta_nu comes from the first two grid
    # points only, so the binning below assumes a uniformly spaced grid.
    n_bins = len(spectrum_frequency_grid)
    nu_grid_start = spectrum_frequency_grid[0]
    nu_grid_end = spectrum_frequency_grid[-1]
    delta_nu = spectrum_frequency_grid[1] - spectrum_frequency_grid[0]

    # Pre-allocate one vpacket collection per real packet for later storage
    vpacket_collections = List()
    for i in range(no_of_packets):
        vpacket_collections.append(
            VPacketCollection(
                i,
                spectrum_frequency_grid,
                montecarlo_configuration.VPACKET_SPAWN_START_FREQUENCY,
                montecarlo_configuration.VPACKET_SPAWN_END_FREQUENCY,
                number_of_vpackets,
                montecarlo_configuration.TEMPORARY_V_PACKET_BINS,
            )
        )

    # Only the main thread drives the progress bar. All thread-local storage
    # below is indexed by get_thread_id(), which is assumed to return values
    # in [0, n_threads).
    main_thread_id = get_thread_id()
    n_threads = get_num_threads()

    # Global estimators that the thread-local copies are reduced into
    n_lines_by_n_cells_tuple = opacity_state_numba.tau_sobolev.shape
    n_cells = len(geometry_state_numba.r_inner)
    estimators_bulk = init_estimators_bulk(n_cells)
    estimators_line = init_estimators_line(n_lines_by_n_cells_tuple)

    # Thread-local accumulators, one per thread, to avoid write races
    estimators_bulk_list_thread = create_estimators_bulk_list(
        n_cells, n_threads
    )
    estimators_line_list_thread = create_estimators_line_list(
        n_lines_by_n_cells_tuple, n_threads
    )
    v_packets_energy_hist_thread = np.zeros((n_threads, n_bins))

    for i in prange(no_of_packets):
        thread_id = get_thread_id()
        if show_progress_bars and thread_id == main_thread_id:
            with objmode:
                # While the main thread handles one packet, the other threads
                # are presumed to handle one each, so advance by n_threads.
                update_amount = 1 * n_threads
                update_packets_pbar(
                    update_amount,
                    no_of_packets,
                )

        r_packet = RPacket(
            packet_collection.initial_radii[i],
            packet_collection.initial_mus[i],
            packet_collection.initial_nus[i],
            packet_collection.initial_energies[i],
            packet_collection.packet_seeds[i],
            i,
        )

        # Seed the random number generator per packet for reproducibility
        np.random.seed(r_packet.seed)

        estimators_bulk_thread = estimators_bulk_list_thread[thread_id]
        estimators_line_thread = estimators_line_list_thread[thread_id]

        # Each packet owns its own vpacket collection and tracker
        vpacket_collection = vpacket_collections[i]
        tracker = trackers[i]

        packet_propagation(
            r_packet,
            geometry_state_numba,
            time_explosion,
            opacity_state_numba,
            estimators_bulk_thread,
            estimators_line_thread,
            vpacket_collection,
            tracker,
            montecarlo_configuration,
        )

        packet_collection.output_nus[i] = r_packet.nu

        # Reabsorbed packets are flagged by a negative output energy
        if r_packet.status == PacketStatus.REABSORBED:
            packet_collection.output_energies[i] = -r_packet.energy
        elif r_packet.status == PacketStatus.EMITTED:
            packet_collection.output_energies[i] = r_packet.energy

        # Finalize the tracker (e.g. trim arrays to actual size)
        tracker.finalize()

        # Finalize the vpacket collection to trim arrays to actual size
        vpacket_collection.finalize_arrays()

        # Bin this packet's virtual-packet energies into this thread's row.
        # The negated range check also skips NaN frequencies, which would
        # otherwise be cast to an arbitrary (out-of-bounds) index.
        for j in range(len(vpacket_collection.nus)):
            nu = vpacket_collection.nus[j]
            if not (nu >= nu_grid_start and nu <= nu_grid_end):
                continue
            idx = int(np.floor((nu - nu_grid_start) / delta_nu))
            v_packets_energy_hist_thread[thread_id, idx] += (
                vpacket_collection.energies[j]
            )

    # Reduce the thread-local histograms serially, after the parallel loop
    v_packets_energy_hist = np.zeros_like(spectrum_frequency_grid)
    for t in range(n_threads):
        v_packets_energy_hist += v_packets_energy_hist_thread[t]

    for estimator_thread in estimators_bulk_list_thread:
        estimators_bulk.increment(estimator_thread)
    for estimator_thread in estimators_line_list_thread:
        estimators_line.increment(estimator_thread)

    if montecarlo_configuration.ENABLE_VPACKET_TRACKING:
        vpacket_tracker = consolidate_vpacket_tracker(
            vpacket_collections,
            spectrum_frequency_grid,
            montecarlo_configuration.VPACKET_SPAWN_START_FREQUENCY,
            montecarlo_configuration.VPACKET_SPAWN_END_FREQUENCY,
        )
    else:
        # Placeholder collection so the return type is the same either way
        vpacket_tracker = VPacketCollection(
            -1,
            spectrum_frequency_grid,
            montecarlo_configuration.VPACKET_SPAWN_START_FREQUENCY,
            montecarlo_configuration.VPACKET_SPAWN_END_FREQUENCY,
            -1,
            1,
        )

    return (
        v_packets_energy_hist,
        vpacket_tracker,
        estimators_bulk,
        estimators_line,
    )