Source code for flint.prefect.flows.polarisation_pipeline

from __future__ import annotations

from pathlib import Path

from configargparse import ArgumentParser
from prefect import flow, tags, unmapped
from prefect.futures import PrefectFuture

from flint.coadd.linmos import LinmosOptions, LinmosResult
from flint.configuration import (
    POLARISATION_MAPPING,
    get_options_from_strategy,
    load_and_copy_strategy,
)
from flint.exceptions import MSError
from flint.imager.wsclean import (
    ImageSet,
    WSCleanResult,
)
from flint.logging import logger
from flint.ms import find_mss
from flint.naming import (
    CASDANameComponents,
    ProcessedNameComponents,
    add_timestamp_to_path,
    extract_components_from_name,
    get_sbid_from_path,
)
from flint.options import (
    PolFieldOptions,
    add_options_to_parser,
    create_options_from_parser,
    dump_field_options_to_yaml,
)
from flint.prefect.clusters import get_dask_runner
from flint.prefect.common.imaging import (
    task_combine_images_to_cube,
    task_convolve_images,
    task_get_channel_images_from_paths,
    task_get_common_beam_from_image_set,
    task_linmos_images,
    task_merge_image_sets,
    task_preprocess_askap_ms,
    task_split_and_get_image_set,
    task_wsclean_imager,
)
from flint.prefect.common.utils import (
    task_create_field_summary,
    task_create_object,
    task_getattr,
    task_rename_linear_to_stokes,
)


@flow(name="Flint Polarisation Pipeline")
def process_science_fields_pol(
    flint_ms_directory: Path,
    pol_field_options: PolFieldOptions,
) -> None:
    """Image, convolve, cube and mosaic a field in each requested polarisation.

    The flow proceeds through the following stages:

    1. Load (and copy) the imaging strategy. Return if there is none.
    2. Find the MeasurementSets in ``flint_ms_directory`` and confirm they
       are either all Flint-processed or all CASDA-provided. CASDA data are
       passed through ``task_preprocess_askap_ms`` first.
    3. Submit the field summary task and dump the options to a timestamped
       ``pol_field_options.yaml`` inside ``flint_ms_directory``.
    4. Return if no wsclean container has been provided.
    5. For each polarisation in the strategy, image every MeasurementSet
       with wsclean.
    6. Derive a common beam from all image sets, then for each Stokes
       parameter convolve the images to that beam and combine the channel
       images into one cube per MeasurementSet.
    7. Linmos the cubes of each Stokes parameter and block until every
       linmos task has finished.

    Args:
        flint_ms_directory: Directory that is searched for MeasurementSets
            and that receives the dumped options file.
        pol_field_options: Options steering the flow (containers, beam and
            primary beam cut-offs, holography file, expected MS count, ...).

    Raises:
        MSError: If no MeasurementSet is recognised as Flint-processed or
            CASDA-provided, if the two origins are mixed, if only some of
            the MeasurementSets are recognised as CASDA-provided, or if
            CASDA data are supplied without a CASA container.
        ValueError: If the strategy requests a polarisation that is not a
            key of ``POLARISATION_MAPPING``.
    """
    strategy = load_and_copy_strategy(
        output_split_science_path=flint_ms_directory,
        imaging_strategy=pol_field_options.imaging_strategy,
    )

    logger.info(f"{pol_field_options=}")

    if strategy is None:
        logger.info("No strategy provided. Returning.")
        return

    # The same data column is used to locate the MSs and to preprocess them
    data_column = strategy["defaults"].get("data_column", "DATA")

    # Get some placeholder names
    science_mss = list(
        find_mss(
            mss_parent_path=flint_ms_directory,
            expected_ms_count=pol_field_options.expected_ms,
            data_column=data_column,
        )
    )

    # Check if MSs have been processed by Flint or have been provided by CASDA
    from_flint_list = [
        isinstance(extract_components_from_name(ms.path), ProcessedNameComponents)
        for ms in science_mss
    ]
    from_casda_list = [
        isinstance(extract_components_from_name(ms.path), CASDANameComponents)
        for ms in science_mss
    ]
    if not any(from_flint_list) and not any(from_casda_list):
        raise MSError("No valid MeasurementSets found! Data must be calibrated first.")
    if any(from_flint_list) and any(from_casda_list):
        raise MSError("Cannot mix Flint-processed and CASDA-provided MeasurementSets!")

    if any(from_casda_list):
        # A Flint/CASDA mix has been rejected above, so a failure here means
        # some names were recognised as neither. Raise explicitly rather than
        # assert, so the check survives running python with -O.
        if not all(from_casda_list):
            raise MSError(
                "Some MeasurementSets are from CASDA, but others are not "
                "recognised as CASDA-provided"
            )
        logger.info("Data are from CASDA, need to apply FixMS")
        if pol_field_options.casa_container is None:
            msg = "We need to apply FixMS to CASDA-provided data, but no CASA container provided"
            raise MSError(msg)
        # One preprocessing future per input MS, in the same order
        science_mss = [
            task_preprocess_askap_ms.submit(
                ms=ms,
                data_column=data_column,
                skip_rotation=False,
                fix_stokes_factor=True,
                apply_ms_transform=True,
                casa_container=pol_field_options.casa_container,
                rename=True,
            )
            for ms in science_mss
        ]

    field_summary = task_create_field_summary.submit(
        mss=science_mss,
        holography_path=pol_field_options.holofile,
    )
    dump_field_options_to_yaml(
        output_path=add_timestamp_to_path(
            input_path=flint_ms_directory / "pol_field_options.yaml"
        ),
        field_options=pol_field_options,
    )

    logger.info(f"Found the following calibrated measurement sets: {science_mss}")

    if pol_field_options.wsclean_container is None:
        logger.info("No wsclean container provided. Returning. ")
        return

    polarisations: dict[str, dict] = strategy.get("polarisation", {"total": {}})

    # Fail fast: reject unknown polarisations before any imaging is submitted
    for polarisation in polarisations:
        if polarisation not in POLARISATION_MAPPING:
            raise ValueError(f"Unknown polarisation {polarisation}")

    image_sets_dict: dict[str, list[PrefectFuture[ImageSet]]] = {}
    image_sets_list: list[PrefectFuture[ImageSet]] = []
    for polarisation in polarisations:
        # The wsclean options depend only on the polarisation, not on the MS
        update_wsclean_options = get_options_from_strategy(
            strategy=strategy,
            operation="polarisation",
            mode="wsclean",
            polarisation=polarisation,
        )
        _image_sets: list[PrefectFuture[ImageSet]] = []
        with tags(f"polarisation-{polarisation}"):
            for science_ms in science_mss:
                wsclean_result: PrefectFuture[WSCleanResult] = (
                    task_wsclean_imager.submit(
                        in_ms=science_ms,
                        wsclean_container=pol_field_options.wsclean_container,
                        make_cube_from_subbands=False,  # We will do this later
                        update_wsclean_options=unmapped(update_wsclean_options),
                    )
                )
                _image_set: PrefectFuture[ImageSet] = task_getattr.submit(
                    wsclean_result, "image_set"
                )
                _image_sets.append(_image_set)
                image_sets_list.append(_image_set)
        image_sets_dict[polarisation] = _image_sets

    # The common beam is derived from every image set of every polarisation
    merged_image_set = task_merge_image_sets.submit(image_sets=image_sets_list)
    common_beam_shape = task_get_common_beam_from_image_set.submit(
        image_set=merged_image_set,
        cutoff=pol_field_options.beam_cutoff,
        fixed_beam_shape=pol_field_options.fixed_beam_shape,
    )

    stokes_beam_cubes: dict[str, list[PrefectFuture[Path]]] = {}
    for polarisation, image_set_list in image_sets_dict.items():
        with tags(f"polarisation-{polarisation}"):
            # Get the individual Stokes parameters in case of joint imaging
            stokes_list = list(POLARISATION_MAPPING[polarisation])
            for stokes in stokes_list:
                with tags(f"stokes-{stokes}"):
                    beam_cubes: list[PrefectFuture[Path]] = []
                    for image_set in image_set_list:
                        stokes_image_list = task_split_and_get_image_set.submit(
                            image_set=image_set,
                            get=stokes,
                            by="pol",
                            mode="image",
                        )
                        convolved_image_list = task_convolve_images.submit(
                            image_paths=stokes_image_list,
                            beam_shape=common_beam_shape,
                            cutoff=pol_field_options.beam_cutoff,
                        )
                        # TODO: Consider accelerating this by doing a linmos per-channel, then combining
                        channel_image_list = task_get_channel_images_from_paths.submit(
                            paths=convolved_image_list
                        )
                        prefix = task_getattr.submit(image_set, "prefix")
                        if polarisation == "linear":
                            # Get single Stokes prefix - the original prefix is the linear prefix
                            # i.e. `.qu.` -> `.q.` or `.u.` depending on the stokes
                            prefix = task_rename_linear_to_stokes.submit(
                                linear_name=prefix,
                                stokes=stokes,
                            )
                        cube_path = task_combine_images_to_cube.submit(
                            images=channel_image_list,
                            prefix=prefix,
                            mode="image",
                            remove_original_images=True,
                        )
                        beam_cubes.append(cube_path)
                    stokes_beam_cubes[stokes] = beam_cubes

    linmos_result_list: list[PrefectFuture[LinmosResult]] = []
    # We run linmos now to ensure we have Stokes I images for leakage correction
    # If we have not imaged Stokes I, we cannot do leakage correction
    force_remove_leakage: bool | None = None
    if "i" not in stokes_beam_cubes:
        force_remove_leakage = False
    linmos_options = task_create_object(
        object=LinmosOptions,
        holofile=pol_field_options.holofile,
        cutoff=pol_field_options.pb_cutoff,
        stokesi_images=stokes_beam_cubes.get("i"),
        force_remove_leakage=force_remove_leakage,
        trim_linmos_fits=pol_field_options.trim_linmos_fits,
    )
    for stokes, beam_cubes in stokes_beam_cubes.items():
        with tags(f"stokes-{stokes}"):
            linmos_result = task_linmos_images.submit(
                image_list=beam_cubes,
                container=pol_field_options.yandasoft_container,
                linmos_options=linmos_options,
                field_summary=field_summary,
            )
            linmos_result_list.append(linmos_result)

    # wait for all linmos results to be completed
    for linmos_result in linmos_result_list:
        linmos_result.result()
def setup_run_process_science_field(
    cluster_config: str | Path,
    flint_ms_directory: Path,
    pol_field_options: PolFieldOptions,
) -> None:
    """Configure the polarisation flow for one field and execute it.

    The SBID is derived from ``flint_ms_directory``. When an archive copy
    path is set in the options, it is extended with an SBID sub-directory.
    The flow is then renamed to include the SBID, given a Dask task runner
    built from ``cluster_config``, and called.

    Args:
        cluster_config: Passed to ``get_dask_runner`` as ``cluster``.
        flint_ms_directory: Directory handed to the flow; also the path the
            SBID is extracted from.
        pol_field_options: Options handed to the flow, possibly with an
            updated ``sbid_copy_path``.
    """
    science_sbid = get_sbid_from_path(path=flint_ms_directory)

    copy_root = pol_field_options.sbid_copy_path
    if copy_root:
        # Keep archived products of different SBIDs in separate directories
        updated_sbid_copy_path = copy_root / f"{science_sbid}"
        logger.info(f"Updating archive copy path to {updated_sbid_copy_path=}")
        pol_field_options = pol_field_options.with_options(
            sbid_copy_path=updated_sbid_copy_path
        )

    flow_overrides = {
        "name": f"Flint Polarisation Pipeline - {science_sbid}",
        "task_runner": get_dask_runner(cluster=cluster_config),
    }
    configured_flow = process_science_fields_pol.with_options(**flow_overrides)
    configured_flow(
        flint_ms_directory=flint_ms_directory,
        pol_field_options=pol_field_options,
    )
def get_parser() -> ArgumentParser:
    """Create the command line parser of the polarisation pipeline.

    Three arguments are registered directly (``--cli-config``, the
    positional ``flint_ms_directory`` and ``--cluster-config``); the
    remaining ones are generated from ``PolFieldOptions`` by
    ``add_options_to_parser``.

    Returns:
        The parser returned by ``add_options_to_parser``.
    """
    # (flags, keyword arguments) in the order they are registered
    argument_table = (
        (
            ("--cli-config",),
            {"is_config_file": True, "help": "Path to configuration file"},
        ),
        (
            ("flint_ms_directory",),
            {
                "type": Path,
                "help": "Path to directories containing the beam-wise flint-calibrated MeasurementSets.",
            },
        ),
        (
            ("--cluster-config",),
            {
                "type": str,
                "default": "petrichor",
                "help": "Path to a cluster configuration file, or a known cluster name. ",
            },
        ),
    )

    base_parser = ArgumentParser(description=__doc__)
    for flags, settings in argument_table:
        base_parser.add_argument(*flags, **settings)

    return add_options_to_parser(
        parser=base_parser,
        options_class=PolFieldOptions,
        description="Polarisation processing options",
    )
def cli() -> None:
    """Command line entry point: parse the arguments and run the pipeline.

    Sets the module logger to ``INFO``, builds ``PolFieldOptions`` from the
    parsed arguments and hands everything to
    ``setup_run_process_science_field``.
    """
    import logging

    logger.setLevel(logging.INFO)

    namespace = get_parser().parse_args()
    pol_options = create_options_from_parser(
        parser_namespace=namespace,
        options_class=PolFieldOptions,
    )

    setup_run_process_science_field(
        cluster_config=namespace.cluster_config,
        flint_ms_directory=namespace.flint_ms_directory,
        pol_field_options=pol_options,
    )
# Run the command line interface only when executed as a script
if __name__ == "__main__":
    cli()