import os
import re
import logging
import tempfile
import fileinput
import networkx as nx
from tardis.plasma.exceptions import PlasmaMissingModule, NotInitializedModule
from tardis.plasma.properties.base import *
from tardis.io.util import PlasmaWriterMixin
logger = logging.getLogger(__name__)
[docs]
class BasePlasma(PlasmaWriterMixin):
outputs_dict = {}
hdf_name = "plasma"
def __init__(self, plasma_properties, property_kwargs=None, **kwargs):
self.outputs_dict = {}
self.input_properties = []
self.plasma_properties = self._init_properties(
plasma_properties, property_kwargs, **kwargs
)
self._build_graph()
self.update(**kwargs)
def __getattr__(self, item):
if item in self.outputs_dict:
return self.get_value(item)
else:
super(BasePlasma, self).__getattribute__(item)
def __setattr__(self, key, value):
if key != "module_dict" and key in self.outputs_dict:
raise AttributeError(
"Plasma inputs can only be updated using " "the 'update' method"
)
else:
super(BasePlasma, self).__setattr__(key, value)
def __dir__(self):
attrs = [item for item in self.__dict__ if not item.startswith("_")]
attrs += [
item for item in self.__class__.__dict__ if not item.startswith("_")
]
attrs += self.outputs_dict.keys()
return attrs
@property
def plasma_properties_dict(self):
return {item.name: item for item in self.plasma_properties}
[docs]
def get_value(self, item):
return getattr(self.outputs_dict[item], item)
def _build_graph(self):
"""
Builds the directed Graph using network X
:param plasma_modules:
:return:
"""
self.graph = nx.DiGraph()
# Adding all nodes
self.graph.add_nodes_from(
[
(plasma_property.name, {})
for plasma_property in self.plasma_properties
]
)
# Flagging all input modules
self.input_properties = [
item
for item in self.plasma_properties
if not hasattr(item, "inputs")
]
for plasma_property in self.plasma_properties:
# Skipping any module that is an input module
if plasma_property in self.input_properties:
continue
for input in plasma_property.inputs:
if input not in self.outputs_dict:
raise PlasmaMissingModule(
f"Module {plasma_property.name} requires input "
f"{input} which has not been added"
f" to this plasma"
)
try:
position = self.outputs_dict[input].outputs.index(input)
label = self.outputs_dict[input].latex_name[position]
label = "$" + label + "$"
label = label.replace("\\", "\\\\")
except:
label = input.replace("_", "-")
self.graph.add_edge(
self.outputs_dict[input].name,
plasma_property.name,
label=label,
)
def _init_properties(
self, plasma_properties, property_kwargs=None, **kwargs
):
"""
Builds a dictionary with the plasma module names as keys
Parameters
----------
plasma_modules : list
list of Plasma properties
property_kwargs : dict
dict of plasma module : kwargs pairs. kwargs should be a dict
of arguments that will be passed to the __init__ method of
the respective plasma module.
kwargs : dictionary
input values for input properties. For example,
t_rad=[5000, 6000,], j_blues=[..]
"""
if property_kwargs is None:
property_kwargs = {}
plasma_property_objects = []
self.previous_iteration_properties = []
self.outputs_dict = {}
for plasma_property in plasma_properties:
if issubclass(plasma_property, PreviousIterationProperty):
current_property_object = plasma_property(
**property_kwargs.get(plasma_property, {})
)
current_property_object.set_initial_value(kwargs)
self.previous_iteration_properties.append(
current_property_object
)
elif issubclass(plasma_property, Input):
if not set(kwargs.keys()).issuperset(plasma_property.outputs):
missing_input_values = set(plasma_property.outputs) - set(
kwargs.keys()
)
raise NotInitializedModule(
f"Input {missing_input_values} required for "
f"plasma but not given when "
f"instantiating the "
f"plasma"
)
current_property_object = plasma_property(
**property_kwargs.get(plasma_property, {})
)
else:
current_property_object = plasma_property(
self, **property_kwargs.get(plasma_property, {})
)
for output in plasma_property.outputs:
self.outputs_dict[output] = current_property_object
plasma_property_objects.append(current_property_object)
return plasma_property_objects
[docs]
def store_previous_properties(self):
for property in self.previous_iteration_properties:
p = property.outputs[0]
self.outputs_dict[p].set_value(
self.get_value(re.sub(r"^previous_", "", p))
)
[docs]
def update(self, **kwargs):
for key in kwargs:
if key not in self.outputs_dict:
raise PlasmaMissingModule(
f"Trying to update property {key}" f" that is unavailable"
)
self.outputs_dict[key].set_value(kwargs[key])
for module_name in self._resolve_update_list(kwargs.keys()):
self.plasma_properties_dict[module_name].update()
[docs]
def freeze(self, *args):
"""
Freeze plama properties.
This method freezes plasma properties to prevent them from being
updated: the values of a frozen property are fixed in the plasma
calculation. This is useful for example for setting up test cases.
Parameters
----------
args : iterable of str
Names of plasma properties to freeze.
Examples
--------
>>> plasma.freeze('t_electrons')
"""
for key in args:
if key not in self.outputs_dict:
raise PlasmaMissingModule(
"Trying to freeze property {0}"
" that is unavailable".format(key)
)
self.outputs_dict[key].frozen = True
[docs]
def thaw(self, *args):
"""
Thaw plama properties.
This method thaws (unfreezes) plasma properties allowing them to be
updated again.
Parameters
----------
args : iterable of str
Names of plasma properties to unfreeze.
Examples
--------
>>> plasma.thaw('t_electrons')
"""
for key in args:
if key not in self.outputs_dict:
raise PlasmaMissingModule(
"Trying to thaw property {0}"
" that is unavailable".format(key)
)
self.outputs_dict[key].frozen = False
def _update_module_type_str(self):
for node in self.graph:
self.outputs_dict[node]._update_type_str()
def _resolve_update_list(self, changed_properties):
"""
Returns a list of all plasma models which are affected by the
changed_modules due to there dependency in the
the plasma_graph.
Parameters
----------
changed_modules : list
all modules changed in the plasma
Returns
-------
: list
all affected modules.
"""
descendants_ob = []
for plasma_property in changed_properties:
node_name = self.outputs_dict[plasma_property].name
descendants_ob += nx.descendants(self.graph, node_name)
descendants_ob = list(set(descendants_ob))
sort_order = list(nx.topological_sort(self.graph))
descendants_ob.sort(key=lambda val: sort_order.index(val))
logger.debug(
f"Updating modules in the following order:"
f'{"->".join(descendants_ob)}'
)
return descendants_ob
[docs]
def write_to_dot(self, fname, args=None, latex_label=True):
"""
This method takes the NetworkX Graph generated from the _build_graph
method, converts it into a DOT code, and saves it to a file
Parameters
----------
fname: str
the name of the file the graph will be saved to
args: list
a list of optional settings for displaying the
graph written in DOT format
latex_label: boolean
enables/disables writing LaTeX equations and
edge labels into the file.
"""
try:
import pygraphviz
except:
logger.warn(
"pygraphviz missing. Plasma graph will not be " "generated."
)
return
print_graph = self.graph.copy()
print_graph = self.remove_hidden_properties(print_graph)
for node in print_graph:
if latex_label == True:
if hasattr(self.plasma_properties_dict[node], "latex_formula"):
print_graph.nodes[str(node)][
"label"
] = f"\\\\textrm{{{node}: }}"
node_list = self.plasma_properties_dict[node]
formulae = node_list.latex_formula
for output in range(0, len(formulae)):
formula = formulae[output]
label = formula.replace("\\", "\\\\")
print_graph.nodes[str(node)]["label"] += label
else:
print_graph.nodes[str(node)][
"label"
] = f"\\\\textrm{{{node}}}"
else:
print_graph.nodes[str(node)]["label"] = node
for edge in print_graph.edges:
label = print_graph.edges[edge]["label"]
print_graph.edges[edge]["label"] = " "
if latex_label == True:
print_graph.edges[edge]["texlbl"] = label
nx.drawing.nx_agraph.write_dot(print_graph, fname)
for line in fileinput.FileInput(fname, inplace=1):
if latex_label == True:
print(
line.replace(
r'node [label="\N"]',
r'node [texmode="math"]',
),
end="",
)
else:
print(
line.replace(
r'node [label="\N"];',
"",
),
end="",
)
if args is not None:
with open(fname, "r") as file:
lines = file.readlines()
for newline in args:
lines.insert(1, f"\t{newline};\n")
with open(fname, "w") as f:
lines = "".join(lines)
f.write(lines)
[docs]
def write_to_tex(self, fname_graph, scale=0.5, args=None, latex_label=True):
"""
This method takes the NetworkX Graph generated from the _build_graph
method, converts it into a LaTeX friendly format,
and saves it to a file
Parameters
----------
fname_graph: str
the name of the file the graph will be saved to
args: list
a list of optional settings for displaying the
graph written in DOT format
scale: float
a scaling factor to expand/contract the generated
graph
latex_label: boolean
enables/disables writing LaTeX equations and
edge labels into the file.
"""
try:
import dot2tex
except:
logger.warn(
"dot2tex missing. Plasma graph will not be " "generated."
)
return
temp_fname = tempfile.NamedTemporaryFile().name
self.write_to_dot(temp_fname, args=args, latex_label=latex_label)
with open(temp_fname, "r") as file:
dot_string = file.read().replace("\\\\", "\\")
texcode = dot2tex.dot2tex(
dot_string, format="tikz", crop=True, valignmode="dot"
)
with open(fname_graph, "w") as file:
file.write(texcode)
for line in fileinput.input(fname_graph, inplace=1):
print(
line.replace(
r"\documentclass{article}",
r"\documentclass[class=minimal,border=20pt]{standalone}",
),
end="",
)
for line in fileinput.input(fname_graph, inplace=1):
print(line.replace(r"\enlargethispage{100cm}", ""), end="")
for line in fileinput.input(fname_graph, inplace=1):
print(
line.replace(
r"\begin{tikzpicture}[>=latex',line join=bevel,]",
r"\begin{tikzpicture}"
r"[>=latex',line join=bevel,"
rf"scale={scale}]",
),
end="",
)
[docs]
def remove_hidden_properties(self, print_graph):
for item in self.plasma_properties_dict.values():
module = self.plasma_properties_dict[item.name].__class__
if issubclass(module, HiddenPlasmaProperty):
output = module.outputs[0]
for value in self.plasma_properties_dict.keys():
if output in getattr(
self.plasma_properties_dict[value], "inputs", []
):
for input in self.plasma_properties_dict[
item.name
].inputs:
try:
position = self.outputs_dict[
input
].outputs.index(input)
label = self.outputs_dict[input].latex_name[
position
]
label = "$" + label + "$"
label = label.replace("\\", "\\\\")
except:
label = input.replace("_", "-")
self.graph.add_edge(
self.outputs_dict[input].name,
value,
label=label,
)
print_graph.remove_node(str(item.name))
return print_graph