
Profiling the effect of input parameters on TARDIS runtime

[1]:
import tardis
from tardis import run_tardis

from tardis.io.config_reader import Configuration
from tardis.io.atom_data.base import AtomData
from tardis.simulation import Simulation
from tardis.io.atom_data import download_atom_data
download_atom_data('kurucz_cd23_chianti_H_He')

import time
import numpy as np
from astropy import units as u
import csv
import statistics
import matplotlib.pyplot as plt

Functions for building the dictionaries and interacting with files

[2]:
def build_dictionary(fp):
    """
    Uses the fp generated from open_file to return a dictionary. Then closes
    the file at the end of execution.

    'testing_param' : dict
        'iteration' : time

    Parameters
    ----------
    fp : file

    Returns
    -------
    data_dict : dictionary
    """

    data_dict = {}
    reader = csv.reader(fp)
    for line in reader:
        testing_param = int(line[0])

        iteration = int(line[1])

        time_of_run = float(line[2])

        if testing_param not in data_dict:
            data_dict[testing_param] = {}
        if iteration not in data_dict[testing_param]:
            data_dict[testing_param][iteration] = time_of_run
    fp.close()
    return data_dict

def add_to_dictionary(data_dict, info):
    """
    Adds the data to the dictionary

    Parameters
    ----------
    data_dict : dictionary
    info : list
        It is set up as [testing_param, iteration, run_time]
    """

    testing_param = info[0]
    iteration = info[1]
    time_of_run = info[2]
    if testing_param not in data_dict:
        data_dict[testing_param] = {}
    data_dict[testing_param][iteration] = time_of_run


def check_existance(data_dict, info):
    """
    Checks to see if a data point already exists in data_dict.
    Returns true if it exists, false if it does not.

    Parameters
    ----------
    data_dict : dictionary
    info : list
        [testing_param, iteration]

    Returns
    -------
    Boolean
    """

    testing_param = info[0]
    iteration = info[1]
    #Check whether this testing parameter has been recorded yet
    if testing_param not in data_dict:
        return False
    #The parameter exists; check whether this iteration does too
    if iteration not in data_dict[testing_param]:
        return False
    return True
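A quick illustration of the layout these helpers maintain (the parameter values and timings below are made up purely for demonstration):

example_dict = {}
add_to_dictionary(example_dict, [1, 1, 23.4])   #1 thread, iteration 1, made-up 23.4 s runtime
add_to_dictionary(example_dict, [1, 2, 24.1])
add_to_dictionary(example_dict, [2, 1, 13.7])

print(example_dict)                             #{1: {1: 23.4, 2: 24.1}, 2: {1: 13.7}}
print(check_existance(example_dict, [1, 2]))    #True
print(check_existance(example_dict, [3, 1]))    #False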

Timing of a specific parameter in TARDIS

Number of Threads

[3]:
def generate_data(thread_count,
                  iteration_count,
                  thread_dictionary,
                  cpu_timing_dictionary,
                  tardis_file,
                  thread_file,
                  cpu_file,):
    """
    Times TARDIS runs of the given config file for every thread count from 1 to thread_count,
    repeating each count iteration_count times. The results are written to two dictionaries
    and appended to their respective files.

    This function can be modified to write to as many or as few dictionaries and files as necessary.
    To test a different parameter, the key used to index into tardis_config must be adjusted
    manually so that the correct testing parameter is varied (see the sketch after this cell).

    Parameters
    ----------
    thread_count : int
        Max number of threads to be tested
    iteration_count : int
        Number of timed runs per thread count
    thread_dictionary : dict
    cpu_timing_dictionary : dict
    tardis_file : str
        TARDIS config file
    thread_file : str
        .csv file name
    cpu_file : str
        .csv file name
    """

    tardis_config = Configuration.from_yaml(tardis_file)

    for threads in range(1, thread_count+1):
        for i in range(1, iteration_count+1):
            #Continue if this data point already exists. If it exists in one it exists in both
            if check_existance(thread_dictionary, [threads, i]):
                continue

            print(threads, " : ", i)
            tardis_config["montecarlo"]["nthreads"] = threads

            sim = Simulation.from_config(tardis_config)
            start = time.monotonic()
            sim.run_convergence()
            sim.run_final()
            end = time.monotonic()
            run_time = end-start

            add_to_dictionary(thread_dictionary, [threads, i, run_time])
            add_to_dictionary(cpu_timing_dictionary, [threads, i, (run_time * threads)])

            with open(thread_file, "a") as data_file:
                data_file.write("{}, {}, {}\n".format(threads, i, run_time))
            with open(cpu_file, "a") as data_file:
                data_file.write("{}, {}, {}\n".format(threads, i, run_time * threads))

    print("Completed")
[4]:
TARDIS_FILE = 'tardis_example.yml'
STORAGE_FILE = 'threads_dict_128_threads_1_node.csv'
CPU_STORAGE_FILE = 'cpu_threads_dict_128_threads_1_node.csv'

#Run the simulation once up front so the numba functions are compiled and JIT overhead does not skew the timings
run_tardis(TARDIS_FILE, log_level="warning")

#Open the data files and build the dictionaries
with open(STORAGE_FILE, "r") as threads_fp:
    threads_dict = build_dictionary(threads_fp)

with open(CPU_STORAGE_FILE, "r") as cpu_fp:
    cpu_threads_dict = build_dictionary(cpu_fp)

#Thread counts tested; used as x-axis values for plotting
threads_numbers = list(range(1, 129))
cpu_threads_numbers = list(range(1, 129))

generate_data(128, 30, threads_dict, cpu_threads_dict, TARDIS_FILE, STORAGE_FILE, CPU_STORAGE_FILE)
Completed
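On a first run the .csv files may not exist yet, in which case the open(..., "r") calls above raise FileNotFoundError. One way to handle that (a sketch, not part of the original notebook) is to create empty files before building the dictionaries:

from pathlib import Path

#Ensure the storage files exist before they are opened for reading;
#build_dictionary simply returns an empty dictionary for an empty file
for fname in (STORAGE_FILE, CPU_STORAGE_FILE):
    Path(fname).touch(exist_ok=True)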
[5]:
def create_stat_dictionaries(main_dict):
    """
    Builds two dictionaries keyed by testing parameter: one holding the mean of the timings and the other their standard deviation

    Parameters
    ----------
    main_dict : dict
        Data dictionary, format of
        testing_param : dictionary
            iteration : time

    Returns
    -------
    main_dict_mean : dict
    main_dict_stdev : dict
    """

    main_dict_list = {}
    for thread, iteration_dict in main_dict.items():
        main_dict_list[thread] = list(iteration_dict.values())

    main_dict_mean = {}
    main_dict_stdev = {}
    for thread, timing_list in main_dict_list.items():
        main_dict_mean[thread] = statistics.mean(timing_list)
        main_dict_stdev[thread] = statistics.stdev(timing_list)

    return main_dict_mean, main_dict_stdev

threads_dict_mean, threads_dict_stdev = create_stat_dictionaries(threads_dict)
cpu_threads_dict_mean, cpu_threads_dict_stdev = create_stat_dictionaries(cpu_threads_dict)
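As a quick check of the output shape, the helper can be run on a small made-up timing dictionary (the values are illustrative only):

toy_dict = {1: {1: 23.4, 2: 24.1, 3: 23.8},
            2: {1: 13.7, 2: 14.0, 3: 13.9}}
toy_mean, toy_stdev = create_stat_dictionaries(toy_dict)
print(toy_mean)    #one mean runtime per thread count, e.g. {1: 23.76..., 2: 13.86...}
print(toy_stdev)   #one standard deviation per thread count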
[6]:
def graph_data(testing_params, mean_dict, stdev_dict, x_axis_title, y_axis_title, graph_title, text_labels=False):
    """
    Uses pyplot to graph the data given the mean and standard deviation of data.

    Parameters
    ----------
    testing_params : list
        This is a list of the testing parameters, which is used for the axis construction.
    mean_dict : dict
    stdev_dict : dict
    x_axis_title : str
    y_axis_title : str
    graph_title : str
    text_labels : bool
        Set to True when the testing parameters are text rather than numbers. In that case
        each tick mark on the x-axis is labelled with the corresponding entry of testing_params.
    """

    lists = sorted(mean_dict.items())
    lists_stdev = sorted(stdev_dict.items())
    x, y = zip(*lists)
    x_stdev, y_stdev = zip(*lists_stdev)

    fig = plt.figure(figsize=(20,10))
    ax = fig.add_subplot(1,1,1)
    ax.set_xlabel(x_axis_title)
    ax.set_ylabel(y_axis_title)
    ax.set_title(graph_title)
    ax.grid(True)

    if text_labels:
        xtick_spacer = np.arange(0, len(testing_params), 1)
        xticks = testing_params
        ax.set_xticks(xtick_spacer)
        ax.set_xticklabels(xticks, rotation='horizontal', fontsize=18)
    else:
        #If there are many ticks, only label every other one for readability
        interval = testing_params[1] - testing_params[0]
        if len(testing_params) > 80:
            xticks = np.arange(testing_params[0], testing_params[-1]+interval, interval*2)
        else:
            xticks = np.arange(testing_params[0], testing_params[-1]+interval, interval)
        ax.set_xticks(xticks)

    ax.plot(x, y, color='tab:blue', label="{} vs {}".format(x_axis_title, y_axis_title))
    plt.errorbar(x, y, yerr=y_stdev, label='stdev')
    leg = plt.legend(loc='best', ncol=1, mode='expand', fancybox=False)
    leg.get_frame().set_alpha(0.5)

graph_data(threads_numbers, threads_dict_mean, threads_dict_stdev, "threads", "time (sec)", "threads vs time")
graph_data(cpu_threads_numbers, cpu_threads_dict_mean, cpu_threads_dict_stdev, "threads", "time (sec)", "threads vs cpu time")
[Figure: threads vs time]
[Figure: threads vs cpu time]
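When the testing parameter is categorical rather than numeric (for example a handful of named configurations), text_labels=True puts one tick per label. A minimal sketch with made-up labels and timing statistics:

#Made-up example: three hypothetical configurations with fake mean/stdev timings
config_labels = ["baseline", "more_packets", "more_shells"]
fake_mean = {0: 120.0, 1: 310.0, 2: 180.0}
fake_stdev = {0: 4.0, 1: 9.0, 2: 6.0}

graph_data(config_labels, fake_mean, fake_stdev,
           "configuration", "time (sec)", "configuration vs time", text_labels=True)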