"""This is a workflow to subtract a continuum model and image the channel-wise data
Unlike the continuum imaging and self-calibnration pipeline this flow currently
expects that all measurement sets are in the flint format, which means other than
the naming scheme that they have been been preprocessed to place them in the IAU
frame and have had their fields table updated. That is to say that they have
already been preprocessed and fixed.
"""
from __future__ import annotations
from pathlib import Path
from time import sleep
from typing import Any
import dask
import numpy as np
from configargparse import ArgumentParser
from fitscube.combine_fits import combine_fits
from prefect import flow, tags, task, unmapped
from flint.coadd.linmos import LinmosResult
from flint.configuration import get_options_from_strategy, load_and_copy_strategy
from flint.exceptions import FrequencyMismatchError
from flint.imager.wsclean import WSCleanResult
from flint.logging import logger
from flint.ms import (
MS,
consistent_ms_frequencies,
find_mss,
get_freqs_from_ms,
get_times_from_ms,
)
from flint.naming import get_sbid_from_path
from flint.options import (
AddModelSubtractFieldOptions,
FitsCubeOptions,
SubtractFieldOptions,
add_options_to_parser,
create_options_from_parser,
)
from flint.predict.crystalball import CrystalBallOptions
from flint.prefect.clusters import get_dask_runner
from flint.prefect.common.imaging import (
task_common_beam_convolve_linmos,
task_wsclean_imager,
)
from flint.prefect.common.ms import task_subtract_model_from_ms
from flint.prefect.common.predict import task_addmodel_to_ms, task_crystalball_to_ms
# These improve the stability of the distributed dask cluster, particularly around
# the usage of crystalball prediction
dask.config.set({"distributed.comm.retry.count": 20})
dask.config.set({"distributed.comm.timeouts.connect": 30})
dask.config.set({"distributed.worker.memory.terminate": False})
[docs]
def _check_and_verify_options(
subtract_field_options: SubtractFieldOptions | None = None,
addmodel_subtract_field_options: AddModelSubtractFieldOptions | None = None,
crystalball_subtract_field_options: CrystalBallOptions | None = None,
) -> None:
"""Verify that the options supplied to run the subtract field options make sense"""
if subtract_field_options:
assert (
subtract_field_options.wsclean_container.exists()
and subtract_field_options.wsclean_container.is_file()
), (
f"{subtract_field_options.wsclean_container=} does not exist or is not a file"
)
assert (
subtract_field_options.yandasoft_container.exists()
and subtract_field_options.yandasoft_container.is_file()
), (
f"{subtract_field_options.yandasoft_container=} does not exist or is not a file"
)
if subtract_field_options and subtract_field_options.use_addmodel:
assert addmodel_subtract_field_options, (
"Addmodel usage invoked by no AddModelOptions supplied"
)
assert addmodel_subtract_field_options.calibrate_container is not None, (
"Calibrate container path is needede for addmodel"
)
assert (
addmodel_subtract_field_options.calibrate_container.exists()
and addmodel_subtract_field_options.calibrate_container.is_file()
), (
f"Calibrate container {addmodel_subtract_field_options.calibrate_container} is not a file"
)
assert addmodel_subtract_field_options.addmodel_cluster_config is not None, (
f"{addmodel_subtract_field_options.addmodel_cluster_config=}, which should not happen"
)
if subtract_field_options:
assert (
sum(
[
subtract_field_options.use_addmodel,
subtract_field_options.use_crystalball,
]
)
<= 1
), "Both addmodel and crystallball selected. This surely can not be right. "
[docs]
def find_mss_to_image(
mss_parent_path: Path,
expected_ms_count: int | None = None,
data_column: str = "CORRECTED_DATA",
model_column: str = "MODEL_DATA",
) -> tuple[MS, ...]:
"""Search for MSs to image. See ``flint.ms.find_mss`` for further details.
Args:
mss_parent_path (Path): Path to search for MSs in
expected_ms_count (Optional[int], optional): Expected number of measurement sets to find. Defaults to None.
data_column (str, optional): The nominated data column that should eb set. Defaults to "CORRECTED_DATA".
model_column (str, optional): The nominated model data column that should be set. Defaults to "MODEL_DATA".
Returns:
Tuple[MS, ...]: Collect of MSs
"""
science_mss = find_mss(
mss_parent_path=mss_parent_path,
expected_ms_count=expected_ms_count,
data_column=data_column,
model_column=model_column,
)
logger.info(f"Found {science_mss=}")
return science_mss
[docs]
def find_and_setup_mss(
science_path_or_mss: Path | tuple[MS, ...],
expected_ms_count: int,
data_column: str,
) -> tuple[MS, ...]:
"""Search for MSs in a directory and, if necessary, perform checks around
their consistency. If the input data appear to be collection of MSs already
assume they have already been set and checked for consistency.
Args:
science_path_or_mss (Union[Path, List[MS, ...]]): Path to search or existing MSs
expected_ms_count (int): Expected number of MSs to find
data_column (str): The data column to nominate if creating MSs after searching
Raises:
FrequencyMismatchError: Raised when frequency information is not consistent
Returns:
Tuple[MS, ...]: Collection of MSs
"""
if isinstance(science_path_or_mss, (list, tuple)):
logger.info("Already loaded MSs")
return tuple(sms for sms in science_path_or_mss)
# Find the MSs
# - optionally untar?
science_mss = find_mss_to_image(
mss_parent_path=science_path_or_mss,
expected_ms_count=expected_ms_count,
data_column=data_column,
)
# 2 - ensure matchfing frequencies over channels
consistent_frequencies_across_mss = consistent_ms_frequencies(mss=science_mss)
if not consistent_frequencies_across_mss:
logger.critical("Mismatch in frequencies among provided MSs")
raise FrequencyMismatchError("There is a mismatch in frequencies")
return science_mss
@task
[docs]
def task_map_all_wsclean(in_mss: list[MS], *args, **kwargs) -> list[WSCleanResult]:
"""A single task that internally runs the wsclean imager task while iterates over
the input list of measurement sets when run. It performs the same operation as the prefect
task enabled `map` operator, but unlike the `.map` method all outputs are tracked in a single
prefect task. This is used to lower the load on the prefect server.
Args:
in_mss (list[MS]): List of measurement sets to image
Returns:
list[WSCleanResult]: The list of output wsclean results
"""
wsclean_results = []
for ms in in_mss:
logger.info(f"Imaging {ms.path}")
wsclean_results.append(task_wsclean_imager.fn(in_ms=ms, **kwargs))
return wsclean_results
@task
[docs]
def task_combine_all_linmos_images(
linmos_commands: list[LinmosResult],
remove_original_images: bool = False,
combine_weights: bool = False,
time_domain: bool = False,
update_fits_cube_options: dict[str, Any] | None = None,
) -> Path:
"""Use the fits-cube package to take all input images and create a single output cube.
Args:
linmos_commands (list[LinmosResult]): The output linmos commands to concatenated into a single cube.
remove_original_images (bool, optional): Remove the original images after the cube has been formed. Defaults to False.
combine_weights (bool, optional): Whether to concatenated the images or the weights that are described by the input `linmos_commands`. Defaults to False.
time_domain (bool, optional): Whether images are to be formed on the spectral or time axis. Defaults to False.
bounding_box (bool, optional): Whether to trim the output cube to include only valid pixels (see fitscube docs). Defaults to False.
invalidate_zeros (bool, optional): Where to mark pixels that are exactly zero as invalid (replace with a NaN). Defaults to False.
Returns:
Path: The output cube path
"""
output_cube_path = Path("test.fits")
fits_cube_options = FitsCubeOptions()
if update_fits_cube_options is not None:
fits_cube_options = fits_cube_options.with_options(**update_fits_cube_options)
if combine_weights:
logger.info("Combining weight fits files")
images_to_combine = [
linmos_command.weight_fits for linmos_command in linmos_commands
]
output_suffix = "weight"
else:
logger.info("Combining image fits files")
images_to_combine = [
linmos_command.image_fits for linmos_command in linmos_commands
]
output_suffix = "linmos"
logger.info(f"Combining {len(images_to_combine)} FITS files together")
output_suffix = f"time.{output_suffix}" if time_domain else f"freq.{output_suffix}"
from flint.naming import create_image_cube_name, create_name_from_common_fields
assert len(images_to_combine) > 0, "No images to combine"
base_cube_path = create_name_from_common_fields(in_paths=tuple(images_to_combine))
output_cube_path = create_image_cube_name(
image_prefix=base_cube_path, mode="contsub", suffix=output_suffix
)
_ = combine_fits(
file_list=images_to_combine,
out_cube=output_cube_path,
max_workers=fits_cube_options.max_workers,
time_domain_mode=time_domain,
bounding_box=fits_cube_options.bounding_box,
invalidate_zeros=fits_cube_options.invalidate_zeros,
)
if remove_original_images:
logger.info(f"Removing original {len(images_to_combine)} images")
for image in images_to_combine:
logger.info(f"Removing {image=}")
assert isinstance(image, Path) and image.exists(), (
f"{image=} does not exist, but it should"
)
image.unlink()
return Path(output_cube_path)
@flow
[docs]
def flow_addmodel_to_mss(
science_path_or_mss: Path | tuple[MS, ...],
addmodel_subtract_field_options: AddModelSubtractFieldOptions,
expected_ms: int,
data_column: str,
) -> tuple[MS, ...]:
"""Separate flow to perform the potentially expensive model prediction
into MSs"""
_check_and_verify_options(
addmodel_subtract_field_options=addmodel_subtract_field_options
)
# Get the MSs that will have their model added to
science_mss = find_and_setup_mss(
science_path_or_mss=science_path_or_mss,
expected_ms_count=expected_ms,
data_column=data_column,
)
science_mss = task_addmodel_to_ms.map(
ms=science_mss,
addmodel_subtract_options=unmapped(addmodel_subtract_field_options),
)
return science_mss
@flow
[docs]
def flow_subtract_cube(
science_path: Path,
subtract_field_options: SubtractFieldOptions,
addmodel_subtract_field_options: AddModelSubtractFieldOptions,
crystalball_subtract_field_options: CrystalBallOptions,
) -> None:
strategy = load_and_copy_strategy(
output_split_science_path=science_path,
imaging_strategy=subtract_field_options.imaging_strategy,
)
_check_and_verify_options(
subtract_field_options=subtract_field_options,
addmodel_subtract_field_options=addmodel_subtract_field_options,
crystalball_subtract_field_options=crystalball_subtract_field_options,
)
# Find the MSs
# - optionally untar?
science_mss = find_and_setup_mss(
science_path_or_mss=science_path,
expected_ms_count=subtract_field_options.expected_ms,
data_column=subtract_field_options.data_column,
)
# 2.5 - Continuum subtract if requested
logger.info("Collecting frequencies from representative measurement set")
freqs_mhz = get_freqs_from_ms(ms=science_mss[0]) / 1e6
logger.info(
f"Considering {len(freqs_mhz)} frequencies from {science_mss[0]}, range covering {np.min(freqs_mhz)}-{np.max(freqs_mhz)} MHz"
)
if len(freqs_mhz) > 20 and subtract_field_options.stagger_delay_seconds is None:
logger.critical(
f"{len(freqs_mhz)} channels and no stagger delay set! Consider setting a stagger delay"
)
logger.info("Collecting times from representative measurement set")
times = get_times_from_ms(
ms=science_mss[0], sort=True, unique=True, return_raw=True
)
logger.info(
f"Considering {len(times)} times from {science_mss[0]}, covering {np.min(times)}-{np.max(times)}"
)
if subtract_field_options.use_addmodel:
logger.info(
"Will attempt to predict and subtract continuum model using addmodel"
)
assert addmodel_subtract_field_options.addmodel_cluster_config is not None, (
f"{addmodel_subtract_field_options.addmodel_cluster_config=}, which should not happen"
)
addmodel_dask_runner = get_dask_runner(
cluster=addmodel_subtract_field_options.addmodel_cluster_config
)
science_mss = flow_addmodel_to_mss.with_options(
task_runner=addmodel_dask_runner, name="Predict -- Addmodel"
)(
science_path_or_mss=science_mss,
addmodel_subtract_field_options=addmodel_subtract_field_options,
expected_ms=subtract_field_options.expected_ms,
data_column=subtract_field_options.data_column,
)
if subtract_field_options.use_crystalball:
with tags("crystalball"):
logger.info("Attempting to peer into the crystalball, me'hearty")
science_mss = task_crystalball_to_ms.map(
ms=science_mss,
crystalball_options=unmapped(crystalball_subtract_field_options),
)
with tags("subtract"):
if subtract_field_options.attempt_subtract:
science_mss = task_subtract_model_from_ms.map(
ms=science_mss,
output_column=subtract_field_options.subtract_data_column,
update_tracked_column=True,
chunk_size=1000,
)
elif (
subtract_field_options.subtract_data_column is not None
and not subtract_field_options.subtract_only
):
from flint.prefect.common.ms import task_update_ms_tracked_column
science_mss = task_update_ms_tracked_column.map(
ms=science_mss, column_name=subtract_field_options.subtract_data_column
)
# Fellow Captain Zic request, arrr
if subtract_field_options.subtract_only:
logger.info("The '--subtract-only' option has been specified. No imaging.")
return
if subtract_field_options.channelwise_image:
with tags("channel-cube"):
channel_parset_list = []
for channel, freq_mhz in enumerate(freqs_mhz):
if channel >= subtract_field_options.max_intervals:
logger.critical("Breaking for sanity")
break
logger.info(f"Imaging {channel=} {freq_mhz=}")
channel_range = (channel, channel + 1)
update_wsclean_options = get_options_from_strategy(
strategy=strategy,
mode="wsclean",
operation="subtractcube",
)
channel_wsclean_cmds = task_map_all_wsclean.submit(
in_mss=science_mss,
wsclean_container=subtract_field_options.wsclean_container,
channel_range=channel_range,
update_wsclean_options=update_wsclean_options,
)
channel_parset = task_common_beam_convolve_linmos.submit(
wsclean_results=channel_wsclean_cmds,
linmos_suffix_str=None,
field_options=subtract_field_options,
convol_mode="image",
convol_filter="image.",
convol_suffix_str="optimal.conv",
trim_linmos_fits=False, # This is necessary to ensure all images have same pixel-coordinates
remove_original_images=True,
cleanup_linmos=True,
)
channel_parset_list.append(channel_parset)
if subtract_field_options.stagger_delay_seconds:
sleep(subtract_field_options.stagger_delay_seconds)
# 4 - cube concatenated each linmos field together to single file
update_fits_cube_options = get_options_from_strategy(
strategy=strategy,
mode="fitscube",
operation="subtractcube",
)
task_combine_all_linmos_images.submit(
linmos_commands=channel_parset_list,
remove_original_images=subtract_field_options.fitscube_remove_original_images,
update_fits_cube_options=update_fits_cube_options,
)
task_combine_all_linmos_images.submit(
linmos_commands=channel_parset_list,
remove_original_images=subtract_field_options.fitscube_remove_original_images,
combine_weights=True,
update_fits_cube_options=update_fits_cube_options,
)
if subtract_field_options.timestep_image:
with tags("time-cube"):
scan_parset_list = []
for scan, time in enumerate(times):
if scan >= subtract_field_options.max_intervals:
logger.critical("Breaking for sanity")
break
logger.info(f"Imaging {scan=} {time=}")
scan_range = (scan, scan + 1)
update_wsclean_options = get_options_from_strategy(
strategy=strategy,
mode="wsclean",
operation="subtractcube",
)
scan_wsclean_cmds = task_map_all_wsclean.submit(
in_mss=science_mss,
wsclean_container=subtract_field_options.wsclean_container,
scan_range=scan_range,
update_wsclean_options=update_wsclean_options,
)
scan_parset = task_common_beam_convolve_linmos.submit(
wsclean_results=scan_wsclean_cmds,
linmos_suffix_str=None,
field_options=subtract_field_options,
convol_mode="image",
convol_filter="image.",
convol_suffix_str="optimal.conv",
trim_linmos_fits=False, # This is necessary to ensure all images have same pixel-coordinates
remove_original_images=True,
cleanup_linmos=True,
)
scan_parset_list.append(scan_parset)
if subtract_field_options.stagger_delay_seconds:
sleep(subtract_field_options.stagger_delay_seconds)
# 4 - cube concatenated each linmos field together to single file
update_fits_cube_options = get_options_from_strategy(
strategy=strategy,
mode="fitscube",
operation="subtractcube",
)
task_combine_all_linmos_images.submit(
linmos_commands=scan_parset_list,
remove_original_images=subtract_field_options.fitscube_remove_original_images,
time_domain=True,
update_fits_cube_options=update_fits_cube_options,
)
task_combine_all_linmos_images.submit(
linmos_commands=scan_parset_list,
remove_original_images=subtract_field_options.fitscube_remove_original_images,
combine_weights=True,
time_domain=True,
update_fits_cube_options=update_fits_cube_options,
)
return
[docs]
def setup_run_subtract_flow(
science_path: Path,
subtract_field_options: SubtractFieldOptions,
addmodel_subtract_field_options: AddModelSubtractFieldOptions,
crystalball_subtract_field_options: CrystalBallOptions,
cluster_config: Path,
) -> None:
logger.info(f"Processing {science_path=}")
science_sbid = get_sbid_from_path(path=science_path)
dask_runner = get_dask_runner(cluster=cluster_config)
flow_subtract_cube.with_options(
task_runner=dask_runner, name=f"Subtract Cube Pipeline -- {science_sbid}"
)(
science_path=science_path,
subtract_field_options=subtract_field_options,
addmodel_subtract_field_options=addmodel_subtract_field_options,
crystalball_subtract_field_options=crystalball_subtract_field_options,
)
[docs]
def get_parser() -> ArgumentParser:
parser = ArgumentParser(description=__doc__)
parser.add_argument(
"--cli-config", is_config_file=True, help="Path to configuration file"
)
parser.add_argument(
"science_path",
type=Path,
help="Path to the directory containing the beam-wise measurement sets",
)
parser.add_argument(
"--cluster-config",
type=str,
default="petrichor",
help="Path to a cluster configuration file, or a known cluster name. ",
)
parser = add_options_to_parser(parser=parser, options_class=SubtractFieldOptions)
parser = add_options_to_parser(
parser=parser, options_class=AddModelSubtractFieldOptions
)
parser = add_options_to_parser(parser=parser, options_class=CrystalBallOptions)
return parser
[docs]
def cli() -> None:
parser = get_parser()
args = parser.parse_args()
subtract_field_options = create_options_from_parser(
parser_namespace=args, options_class=SubtractFieldOptions
)
addmodel_subtract_field_options = create_options_from_parser(
parser_namespace=args, options_class=AddModelSubtractFieldOptions
)
crystalball_options = create_options_from_parser(
parser_namespace=args, options_class=CrystalBallOptions
)
if addmodel_subtract_field_options.addmodel_cluster_config is None:
addmodel_subtract_field_options.with_options(
addmodel_cluster_config=args.cluster_config
)
setup_run_subtract_flow(
science_path=args.science_path,
subtract_field_options=subtract_field_options,
addmodel_subtract_field_options=addmodel_subtract_field_options,
crystalball_subtract_field_options=crystalball_options,
cluster_config=args.cluster_config,
)
if __name__ == "__main__":
cli()