"""Collection of functions and tooling intended
for general usage.
"""
from __future__ import annotations
import datetime
import os
import shutil
import signal
import subprocess
import uuid
from contextlib import contextmanager
from pathlib import Path
from socket import gethostname
from typing import Any, Generator, NamedTuple
import astropy.units as u
import numpy as np
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.wcs import WCS
from flint.convol import BeamShape
from flint.exceptions import TimeLimitException
from flint.logging import logger
# TODO: This Captain is aware that there is a common fits getheader between
# a couple of functions that interact with tasks. Perhaps a common FITS properties
# struct should be considered. The astropy.io.fits.Header might be
# appropriate to pass around between dask / prefect delayed functions. Something
# that only opens the FITS file once and places things into common field names.
[docs]
def flatten_items(items: list[Any]) -> list[Any]:
    """Flatten an arbitrarily nested collection of lists and tuples.

    Nesting is unrolled depth-first, so the relative order of the leaf items
    is preserved. Only ``list`` and ``tuple`` instances are descended into;
    every other object (including strings) is kept as a single item.

    Args:
        items (list[Any]): Items to flatten. Can be arbitrarily nested

    Returns:
        list[Any]: Flattened items
    """
    flattened: list[Any] = []
    # Each entry is an iterator over one level of nesting. The top of the stack
    # is always the innermost level currently being walked.
    pending = [iter(items)]
    while pending:
        for element in pending[-1]:
            if isinstance(element, (list, tuple)):
                # Descend into the nested collection; the outer iterator keeps
                # its position and resumes once this level is exhausted.
                pending.append(iter(element))
                break
            flattened.append(element)
        else:
            # The current level has no items left
            pending.pop()
    return flattened
[docs]
def _signal_timelimit_handler(*_handler_args):
    """Signal handler that turns a ``SIGALRM`` into a ``TimeLimitException``.

    Whatever positional arguments the caller supplies (``signal`` passes the
    signal number and the current stack frame) are ignored.

    Raises:
        TimeLimitException: Always.
    """
    raise TimeLimitException()
@contextmanager
def timelimit_on_context(
    timelimit_seconds: int | float,
) -> Generator[None, None, None]:
    """Creates a context manager that will raise ``flint.exceptions.TimeLimitException``
    should the control not leave the ``with`` context within a specified amount of time.

    The timer is always disarmed on exit, whether the body finished normally or
    raised, and the ``SIGALRM`` handler that was installed beforehand is
    restored. This stops a stale alarm from firing once the context has been
    left.

    Notes:
        This function **can not** be used if the function calling it is not executing
        in the main thread, such as the case with ``dask``. The underlying ``signal``
        module relies on being in the main thread, otherwise an ``Exception`` is raised.

        Fractional seconds are honoured. A limit of ``0`` disables the timer,
        so no limit is applied.

    Args:
        timelimit_seconds (Union[int,float]): The maximum time allowed for the with context to be escaped

    Raises:
        TimeLimitException: Raised should the maximum timelimit be violated.

    Yields:
        Generator[None, None, None]: A generating function that returns nothing
    """
    logger.info(f"Setting a timelimit of {timelimit_seconds=}")
    previous_handler = signal.signal(signal.SIGALRM, _signal_timelimit_handler)
    try:
        # ``setitimer`` accepts float seconds, whereas ``signal.alarm`` would
        # truncate e.g. 0.5 to 0, which silently disables the limit altogether.
        # Arming happens inside the ``try`` so that cleanup always runs.
        signal.setitimer(signal.ITIMER_REAL, float(timelimit_seconds))
        yield
    except TimeLimitException:
        logger.info(f"Timeout limit of {timelimit_seconds=} reached")
        # Re-raise the original exception, keeping its traceback
        raise
    finally:
        # Disarm the timer even if the body raised some other exception;
        # otherwise the alarm could fire later, outside of this context.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # ``signal.signal`` returns None when the previous handler was not
        # installed from Python, and None cannot be re-installed.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
@contextmanager
def hold_then_move_into(
    move_directory: Path,
    hold_directory: Path | None = None,
    delete_hold_on_exit: bool = True,
    overwrite_if_exists: bool = False,
    append_uuid: bool = False,
) -> Generator[Path, None, None]:
    """Create a temporary working directory whose contents are moved into
    ``move_directory`` when the context manager exits.

    If ``hold_directory`` and ``move_directory`` are the same, or ``hold_directory``
    is None, then ``move_directory`` is yielded directly and no output files are
    moved or deleted. ``move_directory`` will be created if it does not exist.

    If ``append_uuid`` is ``True`` and a ``hold_directory`` is given, a folder
    whose name is a random UUID hex string is appended to ``hold_directory``
    and created. Consider using this if ``delete_hold_on_exit`` is ``True``,
    especially when running in a multi-process context where the
    ``hold_directory`` is based on an environment variable (e.g. on SLURM).
    This avoids one context deleting another's work.

    Notes:
        Items are only moved back when the ``with`` body completes without
        raising. Should the body raise, ``hold_directory`` is left in place
        with its contents.

    Args:
        move_directory (Path): Final directory location to move items into
        hold_directory (Optional[Path], optional): Location of directory to temporarily base work from. If None, `move_directory` is returned and no copying/deleting is performed on exit. Defaults to None.
        delete_hold_on_exit (bool, optional): Whether `hold_directory` is deleted on exit of the context. Defaults to True.
        overwrite_if_exists (bool, optional): If an item already exists in the move directory, delete it and replace it with the new one. Defaults to False.
        append_uuid (bool, optional): Add a folder whose name is a `uuid` to the returned Path. Defaults to False.

    Raises:
        FileExistsError: If `hold_directory` or `move_directory` exists but is not a directory.
        shutil.Error: If an item already exists in `move_directory` and `overwrite_if_exists` is False.

    Yields:
        Path: Path to the folder that work should be written into
    """
    # TODO: accept extra files and folders to copy into `hold_directory` that are
    # also placed back on exit
    hold_directory = Path(hold_directory) if hold_directory else None
    move_directory = Path(move_directory)

    logger.info("Hold context manager")
    logger.info(f"{hold_directory=}")
    logger.info(f"{move_directory=}")

    if append_uuid and hold_directory is not None:
        hold_directory = hold_directory / uuid.uuid4().hex
        logger.info(f"Updated {hold_directory=}")

    if hold_directory is None or hold_directory == move_directory:
        move_directory.mkdir(parents=True, exist_ok=True)
        yield move_directory
        return

    for directory in (hold_directory, move_directory):
        # `exist_ok` avoids a check-then-create race between concurrent
        # processes. It still raises FileExistsError if the path is not a directory.
        directory.mkdir(parents=True, exist_ok=True)

    yield hold_directory

    # Materialise the listing first so the directory is not iterated while
    # items are being moved out of it
    for file_or_folder in sorted(hold_directory.iterdir()):
        logger.info(f"Moving {file_or_folder=} to {move_directory=}")
        out_files = move_directory / file_or_folder.name
        if overwrite_if_exists and out_files.exists():
            logger.warning(f"{out_files=} already exists. Deleting.")
            remove_files_folders(out_files)
        shutil.move(str(file_or_folder), move_directory)

    if delete_hold_on_exit:
        remove_files_folders(hold_directory)
@contextmanager
def temporarily_move_into(
    subject: Path, temporary_directory: Path | None = None
) -> Generator[Path, None, None]:
    """Temporarily copy a file or folder into `temporary_directory` for the
    duration of the context manager. When the ``with`` body exits cleanly, the
    original `subject` is removed and replaced by the (possibly modified)
    copy held in `temporary_directory`.

    `temporary_directory` is created (including missing parents) if it does
    not exist. A ``FileExistsError`` is raised if an item with the same name
    as `subject` is already present inside it.

    On exit `temporary_directory` is removed only if this context manager
    created it; for a nested path only the lowest directory is removed. A
    pre-existing `temporary_directory` is left in place, so unrelated content
    (e.g. a shared scratch area) is never deleted.

    If `temporary_directory` is None the `subject` path is returned and there
    is no copying and deleting performed.

    Notes:
        Should the ``with`` body raise, `subject` is left untouched and the
        temporary copy is neither moved back nor removed.

    Args:
        subject (Path): The file or folder to temporarily move
        temporary_directory (Optional[Path], optional): The temporary directory to work with. If None the subject path is returned. Defaults to None.

    Raises:
        FileExistsError: If `temporary_directory` exists but is not a directory, or if it already contains an item named like `subject`.

    Yields:
        Path: The path to the temporary object
    """
    subject = Path(subject)
    if not temporary_directory:
        yield subject
        return

    temporary_directory = Path(temporary_directory)

    # Remember whether this call owns the directory, so that only a directory
    # created here is deleted on exit
    created_temporary_directory = not temporary_directory.exists()
    # Raises FileExistsError should the path exist but not be a directory
    temporary_directory.mkdir(parents=True, exist_ok=True)

    output_item = temporary_directory / subject.name
    if output_item.exists():
        raise FileExistsError(f"{output_item=} already exists! ")

    logger.info(f"Moving {subject=} to {output_item=}")
    if subject.is_dir():
        logger.info(f"{subject=} is a directory, recursively copying")
        copy_directory(
            input_directory=subject, output_directory=output_item.absolute()
        )
    else:
        shutil.copy(subject, output_item)

    yield output_item

    logger.info(f"Moving {output_item} back to {subject=}")
    remove_files_folders(subject)
    shutil.move(str(output_item), subject)

    if created_temporary_directory:
        logger.info(f"Removing {temporary_directory=}")
        shutil.rmtree(temporary_directory)
    else:
        logger.info(
            f"{temporary_directory=} existed before this context, not removing"
        )
[docs]
def parse_environment_variables(
    variable: str | None, default: str | None = None
) -> str | None:
    """Expand ``$VARIABLE`` placeholders in a ``/``-delimited string.

    ``variable`` is split on ``/`` and each segment is handled in turn:

    1. A segment that does not start with ``$`` is kept literally.
    2. ``$FLINT_UUID`` is replaced by a freshly generated hex-formatted UUID.
    3. Any other ``$NAME`` segment has all of its leading ``$`` characters
       stripped and is replaced by the value of the ``NAME`` environment
       variable.
    4. If that environment variable is unset, ``default`` is returned for the
       whole expression.

    Only whole segments are expanded, so ``prefix$HOME`` is left untouched. The
    segments are rejoined with ``/`` and a trailing slash is preserved.

    Args:
        variable: e.g. "TEST1/$SLURM_TMPDIR" or "$HOME/subdir" or "literal/path"
        default: returned if any "$VAR" lookup fails

    Returns:
        Expanded path string, `default` on lookup failure, or None if `variable` is None.
    """
    if variable is None:
        return None

    expanded: list[str] = []
    for segment in variable.split("/"):
        if segment.startswith("$"):
            resolved: str | None
            if segment == "$FLINT_UUID":
                # Directive rather than an environment variable
                resolved = uuid.uuid4().hex
            else:
                resolved = os.environ.get(segment.lstrip("$"))
            if resolved is None:
                # A single unresolved placeholder invalidates the whole expression
                return default
            segment = resolved
        expanded.append(segment)

    joined = "/".join(expanded)
    # Keep a trailing slash from the input
    if variable.endswith("/") and not joined.endswith("/"):
        joined += "/"
    return joined
[docs]
class SlurmInfo(NamedTuple):
    """Key attributes describing the slurm job a process is running under."""

    hostname: str | None = None
    """The hostname of the slurm job"""
    job_id: str | None = None
    """The job ID of the slurm job"""
    task_id: str | None = None
    """The task ID of the slurm job"""
    time: str | None = None
    """The time the job information was gathered"""
[docs]
def get_slurm_info() -> SlurmInfo:
    """Gather the key slurm attributes of the current job.

    The job and array task IDs are read from the ``SLURM_JOB_ID`` and
    ``SLURM_ARRAY_TASK_ID`` environment variables, and are None when unset.
    The gathering time is the stringified local ``datetime.datetime.now()``.

    Returns:
        SlurmInfo: Collection of slurm items from the job environment
    """
    return SlurmInfo(
        hostname=gethostname(),
        job_id=parse_environment_variables("$SLURM_JOB_ID"),
        task_id=parse_environment_variables("$SLURM_ARRAY_TASK_ID"),
        time=str(datetime.datetime.now()),
    )
[docs]
def get_job_info(mode: str = "slurm") -> SlurmInfo:
    """Gather job information from the job scheduler named by ``mode``.

    Args:
        mode (str, optional): Name of the scheduler to query, compared case-insensitively. Only "slurm" is supported. Defaults to "slurm".

    Raises:
        ValueError: Raised if the mode is not supported

    Returns:
        SlurmInfo: Information about the current job
    """
    # TODO: Add other modes? Return a default?
    modes = ("slurm",)
    if mode.lower() != "slurm":
        raise ValueError(f"{mode=} not supported. Supported {modes=} ")
    return get_slurm_info()
[docs]
def log_job_environment() -> SlurmInfo:
    """Gather the slurm job information and log its key attributes.

    Only slurm is currently supported.

    Returns:
        SlurmInfo: Collection of slurm items from the job environment
    """
    # TODO: Expand this to allow potentially other job queue systems
    slurm_info = get_slurm_info()
    messages = (
        f"Running on {slurm_info.hostname=}",
        f"Slurm job id is {slurm_info.job_id}",
        f"Slurm task id is {slurm_info.task_id}",
    )
    for message in messages:
        logger.info(message)
    return slurm_info
[docs]
def get_beam_shape(fits_path: Path) -> BeamShape | None:
    """Build a beam shape from the restoring beam keywords of a FITS image.

    ``BMAJ`` and ``BMIN`` are scaled by 3600 (degrees to arcseconds), while
    ``BPA`` is passed through unchanged as degrees.

    Args:
        fits_path (Path): FITS image to extract the beam information from

    Returns:
        Optional[BeamShape]: Shape of the beam stored in the FITS image. None is returned if any of BMAJ, BMIN or BPA is missing from the header.
    """
    header = fits.getheader(filename=fits_path)
    if any(key not in header for key in ("BMAJ", "BMIN", "BPA")):
        return None

    degree_to_arcsec = 3600
    return BeamShape(
        bmaj_arcsec=header["BMAJ"] * degree_to_arcsec,
        bmin_arcsec=header["BMIN"] * degree_to_arcsec,
        bpa_deg=header["BPA"],
    )
[docs]
def get_pixels_per_beam(fits_path: Path) -> float | None:
    """Given an image with beam information, return the number of pixels
    per beam. The beam is taken from the FITS header. This is evaluated
    for pixels at the reference pixel position.

    ``BMAJ`` and ``BMIN`` are the full widths at half maximum of an elliptical
    Gaussian beam, whose solid angle is ``pi * BMAJ * BMIN / (4 ln 2)``
    (about ``1.1331 * BMAJ * BMIN``). The pixel area is
    ``|CDELT1| * |CDELT2|``.

    Args:
        fits_path (Path): FITS image to consider

    Returns:
        Optional[float]: Number of pixels per beam. If beam is not in header then None is returned.
    """
    beam_shape = get_beam_shape(fits_path=fits_path)
    if beam_shape is None:
        return None

    header = fits.getheader(filename=fits_path)
    # CDELT values are in degrees; convert to arcseconds to match the beam
    pixel_ra = np.abs(header["CDELT1"] * 3600)
    pixel_dec = np.abs(header["CDELT2"] * 3600)

    # Gaussian beam solid angle from the FWHM axes. The previous
    # ``pi * bmaj * bmin`` treated the full widths as semi-axes and
    # overestimated the area by a factor of 4 ln 2 (~2.77).
    beam_area = (
        np.pi * beam_shape.bmaj_arcsec * beam_shape.bmin_arcsec / (4 * np.log(2))
    )
    pixel_area = pixel_ra * pixel_dec

    no_pixels = beam_area / pixel_area
    return no_pixels
[docs]
def get_packaged_resource_path(package: str, filename: str) -> Path:
    """Load in the path of a packaged resource.

    The `package` argument is given as though the module were being
    specified in an import statement, e.g. `flint.data.aoflagger`.

    The ``importlib_resources`` backport is used when it is installed;
    otherwise the standard library ``importlib.resources`` is used.

    Args:
        package (str): The module path to the resources
        filename (str): Filename of the datafile to load

    Returns:
        Path: The absolute path to the packaged resource file
    """
    logger.info(f"Loading {package=} for {filename=}")

    try:
        import importlib_resources
    except ImportError:
        # A missing backport raises ModuleNotFoundError (an ImportError
        # subclass). The previous ``except ImportWarning`` never matched it,
        # so the stdlib fallback was unreachable.
        from importlib import resources as importlib_resources

    p = importlib_resources.files(package)
    logger.info(f"{p=}")

    full_path = Path(p) / filename  # type: ignore

    logger.debug(f"Resolved {full_path=}")

    return full_path
[docs]
def estimate_skycoord_centre(
    sky_positions: SkyCoord, final_frame: str = "fk5"
) -> SkyCoord:
    """Estimate the centre of a set of sky positions held in a `SkyCoord`.

    The positions are converted to cartesian (X, Y, Z) vectors and averaged.
    The mean vector is then turned back into a sky position and transformed
    into ``final_frame``.

    Args:
        sky_positions (SkyCoord): Set of input positions to consider
        final_frame (str, optional): The frame of the returned `SkyCoord` object, set using `.transform_to`. Defaults to fk5.

    Returns:
        SkyCoord: Centre position
    """
    cartesian_xyz = sky_positions.cartesian.xyz
    mean_x, mean_y, mean_z = np.mean(cartesian_xyz, axis=1)
    centre = SkyCoord(mean_x, mean_y, mean_z, representation_type="cartesian")
    return centre.transform_to(final_frame)
[docs]
def estimate_image_centre(image_path: Path) -> SkyCoord:
    """Estimate the sky position of the central pixel of a FITS image.

    The centre is computed from the lengths of the first two FITS axes
    (``NAXIS1`` and ``NAXIS2``) of the primary HDU. These are assumed to be
    the celestial axes, as is typical for radio images ordered
    (RA, Dec, frequency, Stokes).

    Args:
        image_path (Path): FITS image whose centre is estimated

    Returns:
        SkyCoord: Sky position of the centre of the image
    """
    image_header = fits.getheader(image_path)
    wcs = WCS(image_header)

    # FITS axis 1 is x (the fastest-varying axis) and axis 2 is y. The numpy
    # data shape is in the reverse order, e.g. (stokes, frequency, dec, ra),
    # so indexing ``shape[0]``/``shape[1]`` would pick the wrong axes.
    # ``pixel_to_world`` takes 0-based pixel coordinates, so the geometric
    # centre of an N-pixel axis is (N - 1) / 2.
    centre_x = (image_header["NAXIS1"] - 1) / 2.0
    centre_y = (image_header["NAXIS2"] - 1) / 2.0

    # `celestial` drops any non-sky (frequency/stokes) axes from the WCS
    centre_sky = wcs.celestial.pixel_to_world(centre_x, centre_y)

    return centre_sky
[docs]
def zip_folder(
    in_path: Path, out_zip: Path | None = None, archive_format: str = "tar"
) -> Path:
    """Archive a directory and remove the original.

    Args:
        in_path (Path): The path that will be archived and then removed.
        out_zip (Path, optional): Base name of the output file. An `archive_format`-specific extension will be added by `shutil.make_archive`. Defaults to None, in which case `in_path` is used.
        archive_format (str, optional): The format of the archive. See `shutil.make_archive`. Defaults to "tar".

    Returns:
        Path: The path of the created archive, including the extension added by `shutil.make_archive`. If `in_path` does not exist nothing is archived and the base name (`out_zip`, or `in_path`) is returned.
    """
    in_path = Path(in_path)
    out_zip = in_path if out_zip is None else Path(out_zip)

    if not in_path.exists():
        logger.warning(f"{in_path=} does not exist... Not archiving. ")
        return out_zip

    logger.info(f"Zipping {in_path}.")
    # `make_archive` returns the real archive filename (base name plus the
    # format's extension). Returning the bare base name would point at a path
    # that does not exist, since by default it is `in_path`, removed below.
    archive_path = shutil.make_archive(
        str(out_zip), format=archive_format, base_dir=str(in_path)
    )
    # Only reached if the archive was created successfully
    remove_files_folders(in_path)

    return Path(archive_path)
[docs]
def rsync_copy_directory(target_path: Path, out_path: Path) -> Path:
    """Rsync the contents of a directory into another location.

    This is an attempt to verify a copy was completed successfully. The rsync
    output is logged at debug level, and a non-zero rsync exit status is
    raised as an error.

    Args:
        target_path (Path): The target directory to copy
        out_path (Path): The location to copy the directory to

    Raises:
        subprocess.CalledProcessError: Raised if rsync exits with a non-zero status.

    Returns:
        Path: The output path of the new directory.
    """
    # Build the argument list directly (no shell, no str.split) so that
    # paths containing spaces are passed through intact
    rsync_cmd = [
        "rsync",
        "-avh",
        "--progress",
        "--stats",
        f"{target_path!s}/",
        f"{out_path!s}/",
    ]

    logger.info(f"Rsync copying {target_path} to {out_path}.")
    logger.debug(f"Will run {' '.join(rsync_cmd)}")

    # The context manager closes stdout and waits for the process, so no
    # zombie is left behind and the return code becomes available
    with subprocess.Popen(rsync_cmd, stdout=subprocess.PIPE) as rsync_run:
        if rsync_run.stdout is not None:
            for line in rsync_run.stdout:
                # File names are not guaranteed to be valid UTF-8
                logger.debug(line.decode(errors="replace").rstrip())

    if rsync_run.returncode != 0:
        raise subprocess.CalledProcessError(rsync_run.returncode, rsync_cmd)

    return out_path
[docs]
def copy_directory(
    input_directory: Path,
    output_directory: Path,
    verify: bool = False,
    overwrite: bool = False,
) -> Path:
    """Recursively copy a directory to a new location.

    Args:
        input_directory (Path): The existing source directory to copy
        output_directory (Path): Destination path of the copy. If it already exists and `overwrite` is False, `shutil.copytree` raises.
        verify (bool, optional): After copying, run `rsync` between the two directories as a check that the copy worked. Defaults to False.
        overwrite (bool, optional): Remove the target directory if it exists. Defaults to False.

    Returns:
        Path: Location of output directory
    """
    input_directory, output_directory = Path(input_directory), Path(output_directory)

    # `is_dir` is False for missing paths, covering the "does not exist" case
    assert input_directory.is_dir(), (
        f"Currently only supports copying directories, {input_directory=} is a file or does not exist. "
    )

    logger.info(f"Copying {input_directory} to {output_directory}.")

    if overwrite and output_directory.exists():
        logger.warning(f"{output_directory} already exists. Removing it. ")
        remove_files_folders(output_directory)

    shutil.copytree(input_directory, output_directory)

    if verify:
        rsync_copy_directory(input_directory, output_directory)

    return output_directory
[docs]
def remove_files_folders(*paths_to_remove: Path) -> list[Path]:
    """Remove a set of paths from the file system.

    Real directories are removed recursively. Files and symbolic links
    (including links that point at directories, or at nothing) are unlinked.
    A link's target is never touched. Paths that do not exist are skipped.

    Args:
        paths_to_remove (Path): Set of Paths that will be removed

    Returns:
        List[Path]: Set of Paths that were removed
    """
    files_removed: list[Path] = []

    for path in paths_to_remove:
        file = Path(path)
        # `exists()` follows symlinks and is False for a dangling link, so
        # also check `is_symlink()` to clean those up
        if not (file.is_symlink() or file.exists()):
            logger.debug(f"{file} does not exist. Skipping, ")
            continue

        if file.is_dir() and not file.is_symlink():
            logger.info(f"Removing folder {file!s}")
            shutil.rmtree(file)
        else:
            # `shutil.rmtree` refuses to operate on a symlink to a directory,
            # so links (and regular files) are unlinked instead
            logger.info(f"Removing file {file}.")
            file.unlink()

        files_removed.append(file)

    return files_removed
[docs]
def create_directory(directory: Path, parents: bool = True) -> Path:
    """Create a directory, tolerating the case where it already exists.

    Any exception raised while creating it (for instance due to concurrent
    processes in a multi-process environment) is logged as an error and then
    discarded rather than propagated.

    Args:
        directory (Path): Path to directory to create
        parents (bool, optional): Create parent directories if necessary. Defaults to True.

    Returns:
        Path: The directory created
    """
    target = Path(directory)
    logger.info(f"Creating {target!s}")

    try:
        target.mkdir(parents=parents, exist_ok=True)
    except Exception as error:
        # Deliberately best-effort: report the failure but carry on
        logger.error(f"Failed to create {target!s} {error}.")

    return target