diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bc52683..6fd9dc6 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -3,6 +3,7 @@ Changelog X.Y.Z (DD-MM-YYY) ----------------- +* Support choosing the UVW sign convention (:pr:`49`) * Add field and source dataset (:pr:`48`) * Add pre-computed UVW coordinates (:pr:`47`) * Test that xarray-kat and katdal produce the same data for the same data source (:pr:`45`) diff --git a/src/xarray_kat/datatree_factory.py b/src/xarray_kat/datatree_factory.py index b356c76..491ee18 100644 --- a/src/xarray_kat/datatree_factory.py +++ b/src/xarray_kat/datatree_factory.py @@ -5,7 +5,7 @@ import warnings from datetime import datetime, timezone from importlib.metadata import version as importlib_version -from typing import TYPE_CHECKING, Dict, Iterable, NamedTuple, Set +from typing import TYPE_CHECKING, Dict, Iterable, NamedTuple, Set, get_args import numpy as np import tensorstore as ts @@ -21,7 +21,7 @@ from xarray_kat.multiton import Multiton from xarray_kat.stores.vis_weight_flag_store_factory import VisWeightFlagFactory from xarray_kat.utils import corrprods_to_baseline_pols -from xarray_kat.xkat_types import VanVleckLiteralType +from xarray_kat.xkat_types import UvwSignConventionType, VanVleckLiteralType if TYPE_CHECKING: from katpoint import Target @@ -79,6 +79,7 @@ class DataTreeFactory: _scan_states: Set[str] _applycal: str | Iterable[str] _van_vleck: VanVleckLiteralType + _uvw_sign_convention: UvwSignConventionType _endpoint: str _token: str | None @@ -90,6 +91,7 @@ def __init__( data_products: Multiton[TelstateDataProducts], applycal: str | Iterable[str], scan_states: Iterable[str], + uvw_sign_convention: UvwSignConventionType, van_vleck: VanVleckLiteralType, endpoint: str, token: str | None = None, @@ -100,6 +102,7 @@ def __init__( self._data_products = data_products self._applycal = applycal self._scan_states = set(scan_states) + self._uvw_sign_convention = uvw_sign_convention self._van_vleck = van_vleck self._endpoint = endpoint self._token = token @@ -455,9 +458,16 @@ def WrappedArray(a): # Measurement Set `definition`_. # .. _CASA: https://casa.nrao.edu/Memos/CoordConvention.pdf # .. _definition: https://casa.nrao.edu/Memos/229.html#SECTION00064000000000000000 - uvw_coordinates = np.take(uvw_ant, ant1_index, axis=1) - np.take( - uvw_ant, ant2_index, axis=1 - ) + + if self._uvw_sign_convention == "fourier": + uvw = uvw_ant[:, ant2_index, :] - uvw_ant[:, ant1_index, :] + elif self._uvw_sign_convention == "casa": + uvw = uvw_ant[:, ant1_index, :] - uvw_ant[:, ant2_index, :] + else: + raise ValueError( + f"Invalid uvw sign convention {self._uvw_sign_convention} " + f"Should be one of {get_args(UvwSignConventionType)}" + ) flag_p_chunks = data_vars["FLAG"].encoding["preferred_chunks"] uvw_preferred_chunks = { @@ -468,7 +478,7 @@ def WrappedArray(a): data_vars["UVW"] = Variable( ("time", "baseline_id", "uvw_label"), - uvw_coordinates, + uvw, {"type": "uvw", "units": "m", "frame": "fk5"}, {"preferred_chunks": uvw_preferred_chunks}, ) diff --git a/src/xarray_kat/entrypoint.py b/src/xarray_kat/entrypoint.py index 1698cef..efb468e 100644 --- a/src/xarray_kat/entrypoint.py +++ b/src/xarray_kat/entrypoint.py @@ -18,7 +18,7 @@ from xarray_kat.datatree_factory import DataTreeFactory from xarray_kat.katdal_types import TelstateDataProducts, TelstateDataSource from xarray_kat.multiton import Multiton -from xarray_kat.xkat_types import VanVleckLiteralType +from xarray_kat.xkat_types import UvwSignConventionType, VanVleckLiteralType class KatStore(AbstractDataStore): @@ -33,6 +33,7 @@ class KatEntryPoint(BackendEntrypoint): "scan_states", "capture_block_id", "stream_name", + "uvw_sign_convention", "van_vleck", ] description = "Opens a MeerKAT data source" @@ -84,6 +85,7 @@ def open_datatree( scan_states: Iterable[str] = ("scan", "track"), capture_block_id: str | None = None, stream_name: str | None = None, + uvw_sign_convention: UvwSignConventionType = "casa", van_vleck: VanVleckLiteralType = "off", ): group_dicts = self.open_groups_as_dict( @@ -94,6 +96,7 @@ def open_datatree( scan_states=scan_states, capture_block_id=capture_block_id, stream_name=stream_name, + uvw_sign_convention=uvw_sign_convention, van_vleck=van_vleck, ) return DataTree.from_dict(group_dicts) @@ -108,6 +111,7 @@ def open_groups_as_dict( scan_states: Iterable[str] = ("scan", "track"), capture_block_id: str | None = None, stream_name: str | None = None, + uvw_sign_convention: UvwSignConventionType = "casa", van_vleck: VanVleckLiteralType = "off", ) -> Dict[str, Any]: url = str(filename_or_obj) @@ -143,6 +147,7 @@ def open_groups_as_dict( telstate_data_products, applycal, scan_states, + uvw_sign_convention, van_vleck, endpoint, token, diff --git a/src/xarray_kat/xkat_types.py b/src/xarray_kat/xkat_types.py index ce33110..bac2c8a 100644 --- a/src/xarray_kat/xkat_types.py +++ b/src/xarray_kat/xkat_types.py @@ -6,6 +6,7 @@ import numpy.typing as npt VanVleckLiteralType = Literal["off", "autocorr"] +UvwSignConventionType = Literal["fourier", "casa"] @dataclass(eq=True, unsafe_hash=True, slots=True, repr=True) diff --git a/tests/test_katdal.py b/tests/test_katdal.py index 45a6938..a27b595 100644 --- a/tests/test_katdal.py +++ b/tests/test_katdal.py @@ -1,5 +1,6 @@ import katdal import numpy as np +import pytest import xarray from pytest_httpserver import HTTPServer @@ -10,7 +11,10 @@ class TestKatdal: - def test_katdal_mock_server_basic(self, httpserver: HTTPServer, tmp_path): + @pytest.mark.parametrize("uvw_sign_convention", ["fourier", "casa"]) + def test_katdal_mock_server_basic( + self, httpserver: HTTPServer, uvw_sign_convention, tmp_path + ): """Tests that xarray-kat and katdal return the same data from the same datasource""" obs = SyntheticObservation("1234567890", ntime=8, nfreq=16, nants=4) obs.add_scan(range(0, 8), "track", "PKS1934") @@ -24,7 +28,9 @@ def test_katdal_mock_server_basic(self, httpserver: HTTPServer, tmp_path): rdb_url = f"{base_url}1234567890/1234567890_sdp_l0.full.rdb" ds = katdal.open(rdb_url) - dt = xarray.open_datatree(rdb_url, engine="xarray-kat") + dt = xarray.open_datatree( + rdb_url, engine="xarray-kat", uvw_sign_convention=uvw_sign_convention + ) def reorder_katdal_data(data): return ( @@ -48,6 +54,7 @@ def reorder_katdal_data(data): np.testing.assert_allclose(xarray_kat_flags, katdal_flags) xarray_kat_uvw = dt[children[0]].UVW.data - # Flip katdal sign convention to match CASA - katdal_uvw = np.stack([-ds.u, -ds.v, -ds.w], axis=2)[:, obs.corrprod_argsort] + katdal_uvw = np.stack([ds.u, ds.v, ds.w], axis=2)[:, obs.corrprod_argsort] + if uvw_sign_convention == "casa": + katdal_uvw = -katdal_uvw np.testing.assert_allclose(xarray_kat_uvw, katdal_uvw[:, :: obs.npol])