Source code for tardis.io.util

# Utility functions for the IO part of TARDIS

from __future__ import annotations

import collections.abc as collections_abc
import hashlib
import logging
import shutil
from collections import OrderedDict
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import yaml
from astropy import units as u
from astropy.utils.data import download_file

from tardis import __path__ as tardis_path
from tardis import __version__
from tardis import constants as const

if TYPE_CHECKING:
    from typing import Any

logger = logging.getLogger(__name__)


[docs] def get_internal_data_path(fname: str) -> str: """ Get internal data path of TARDIS. Parameters ---------- fname : str The filename to join with the internal data path. Returns ------- str Internal data path of TARDIS joined with the filename. """ return str(Path(tardis_path[0]) / "data" / fname)
[docs] def quantity_from_str(text: str) -> u.Quantity: """ Convert a string to `astropy.units.Quantity`. Parameters ---------- text : str The string to convert to `astropy.units.Quantity`. Expected format is "value unit", e.g., "1.0 cm" or "5 log_lsun". Returns ------- astropy.units.Quantity The converted quantity with appropriate units. Notes ----- Special handling for "log_lsun" unit which is converted to solar luminosity in CGS units. """ value_str, unit_str = text.split(None, 1) value = float(value_str) if unit_str.strip() == "log_lsun": value = 10 ** (value + np.log10(const.L_sun.cgs.value)) unit_str = "erg/s" unit = u.Unit(unit_str) if unit == u.L_sun: return value * const.L_sun return u.Quantity(value, unit_str)
[docs] class MockRegexPattern: """ A mock class to be used in place of a compiled regular expression when a type check is needed instead of a regex match. Notes ----- This is usually a lot slower than regex matching. Parameters ---------- target_type : type The target type for conversion testing. """ def __init__(self, target_type: type) -> None: """ Initialize the MockRegexPattern. Parameters ---------- target_type : type The target type for conversion testing. """ self.type = target_type
[docs] def match(self, text: str) -> bool: """ Test if text can be converted to the target type. Parameters ---------- text : str A string to be passed to `target_type` for conversion. Returns ------- bool Returns `True` if `text` can be converted to `target_type`, otherwise returns `False`. """ try: self.type(text) except ValueError: return False return True
[docs] class YAMLLoader(yaml.Loader): """ A custom YAML loader containing all the constructors required to properly parse the tardis configuration. """
[docs] def construct_quantity(self, node: yaml.ScalarNode) -> u.Quantity: """ A constructor for converting quantity-like YAML nodes to `astropy.units.Quantity` objects. Parameters ---------- node : yaml.Node The YAML node to be constructed. Returns ------- astropy.units.Quantity The constructed quantity object. """ data = self.construct_scalar(node) return quantity_from_str(data)
[docs] def mapping_constructor(self, node: yaml.MappingNode) -> OrderedDict: """ Construct an OrderedDict from a YAML mapping node. Parameters ---------- node : yaml.Node The YAML mapping node to construct. Returns ------- OrderedDict The constructed ordered dictionary. """ return OrderedDict(self.construct_pairs(node))
YAMLLoader.add_constructor("!quantity", YAMLLoader.construct_quantity) YAMLLoader.add_implicit_resolver("!quantity", MockRegexPattern(quantity_from_str), None) YAMLLoader.add_implicit_resolver( "tag:yaml.org,2002:float", MockRegexPattern(float), None ) YAMLLoader.add_constructor( yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, YAMLLoader.mapping_constructor, )
[docs] def yaml_load_file(filename: str, loader: type = yaml.Loader) -> Any: """ Load a YAML file using the specified loader. Parameters ---------- filename : str Path to the YAML file to load. loader : type, optional YAML loader class to use, by default yaml.Loader. Returns ------- Any The loaded YAML content. """ with open(filename) as stream: return yaml.load(stream, Loader=loader)
[docs] def traverse_configs(base: Any, other: Any, func: Any, *args: Any) -> None: """ Recursively traverse a base dict or list along with another one calling `func` for leafs of both objects. Parameters ---------- base : Any The object on which the traversing is done. other : Any The object which is traversed along with `base`. func : Any A function called for each leaf of `base` and the corresponding leaf of `other`. Signature: `func(item1, item2, *args)`. *args : Any Arguments passed into `func`. """ if isinstance(base, collections_abc.Mapping): for k in base: traverse_configs(base[k], other[k], func, *args) elif ( isinstance(base, collections_abc.Iterable) and not isinstance(base, str) and not hasattr(base, "shape") ): for val1, val2 in zip(base, other): traverse_configs(val1, val2, func, *args) else: func(base, other, *args)
[docs] def assert_equality(item1: Any, item2: Any) -> None: """ Assert that two items are equal, handling special cases for units and arrays. Parameters ---------- item1 : Any First item to compare. item2 : Any Second item to compare. Raises ------ AssertionError If the items are not equal. """ assert type(item1) is type(item2) try: if hasattr(item1, "unit"): assert item1.unit == item2.unit assert np.allclose(item1, item2, atol=0.0) except (ValueError, TypeError): assert item1 == item2
[docs] def check_equality(item1: Any, item2: Any) -> bool: """ Check if two items are equal using traverse_configs and assert_equality. Parameters ---------- item1 : Any First item to compare. item2 : Any Second item to compare. Returns ------- bool True if items are equal, False otherwise. """ try: traverse_configs(item1, item2, assert_equality) except AssertionError: return False else: return True
[docs] @lru_cache(maxsize=None) def download_from_url( url: str, dst: str, checksum: str, src: tuple[str, ...] | None = None, retries: int = 3, ) -> None: """ Download files from a given URL. Parameters ---------- url : str URL to download from. dst : str Destination folder for the downloaded file. checksum : str Expected MD5 checksum of the file. src : tuple of str, optional List of URLs to use as mirrors. retries : int, optional Number of retry attempts, by default 3. Raises ------ RuntimeError If maximum number of retries is reached and checksum still doesn't match. """ cached_file_path = download_file(url, sources=src, pkgname="tardis") with open(cached_file_path, "rb") as f: new_checksum = hashlib.md5(f.read()).hexdigest() if checksum == new_checksum: shutil.copy(cached_file_path, dst) elif checksum != new_checksum and retries > 0: retries -= 1 logger.warning( "Incorrect checksum, retrying... (%d attempts remaining)", retries + 1 ) download_from_url(url, dst, checksum, src, retries) else: logger.error("Maximum number of retries reached. Aborting")