Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Changelog

X.Y.Z (DD-MM-YYY)
-----------------
* Support choosing the UVW sign convention (:pr:`49`)
* Add field and source dataset (:pr:`48`)
* Add pre-computed UVW coordinates (:pr:`47`)
* Test that xarray-kat and katdal produce the same data for the same data source (:pr:`45`)
Expand Down
22 changes: 16 additions & 6 deletions src/xarray_kat/datatree_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings
from datetime import datetime, timezone
from importlib.metadata import version as importlib_version
from typing import TYPE_CHECKING, Dict, Iterable, NamedTuple, Set
from typing import TYPE_CHECKING, Dict, Iterable, NamedTuple, Set, get_args

import numpy as np
import tensorstore as ts
Expand All @@ -21,7 +21,7 @@
from xarray_kat.multiton import Multiton
from xarray_kat.stores.vis_weight_flag_store_factory import VisWeightFlagFactory
from xarray_kat.utils import corrprods_to_baseline_pols
from xarray_kat.xkat_types import VanVleckLiteralType
from xarray_kat.xkat_types import UvwSignConventionType, VanVleckLiteralType

if TYPE_CHECKING:
from katpoint import Target
Expand Down Expand Up @@ -79,6 +79,7 @@ class DataTreeFactory:
_scan_states: Set[str]
_applycal: str | Iterable[str]
_van_vleck: VanVleckLiteralType
_uvw_sign_convention: UvwSignConventionType
_endpoint: str
_token: str | None

Expand All @@ -90,6 +91,7 @@ def __init__(
data_products: Multiton[TelstateDataProducts],
applycal: str | Iterable[str],
scan_states: Iterable[str],
uvw_sign_convention: UvwSignConventionType,
van_vleck: VanVleckLiteralType,
endpoint: str,
token: str | None = None,
Expand All @@ -100,6 +102,7 @@ def __init__(
self._data_products = data_products
self._applycal = applycal
self._scan_states = set(scan_states)
self._uvw_sign_convention = uvw_sign_convention
self._van_vleck = van_vleck
self._endpoint = endpoint
self._token = token
Expand Down Expand Up @@ -455,9 +458,16 @@ def WrappedArray(a):
# Measurement Set `definition`_.
# .. _CASA: https://casa.nrao.edu/Memos/CoordConvention.pdf
# .. _definition: https://casa.nrao.edu/Memos/229.html#SECTION00064000000000000000
uvw_coordinates = np.take(uvw_ant, ant1_index, axis=1) - np.take(
uvw_ant, ant2_index, axis=1
)

if self._uvw_sign_convention == "fourier":
uvw = uvw_ant[:, ant2_index, :] - uvw_ant[:, ant1_index, :]
elif self._uvw_sign_convention == "casa":
uvw = uvw_ant[:, ant1_index, :] - uvw_ant[:, ant2_index, :]
else:
raise ValueError(
f"Invalid uvw sign convention {self._uvw_sign_convention} "
f"Should be one of {get_args(UvwSignConventionType)}"
)

flag_p_chunks = data_vars["FLAG"].encoding["preferred_chunks"]
uvw_preferred_chunks = {
Expand All @@ -468,7 +478,7 @@ def WrappedArray(a):

data_vars["UVW"] = Variable(
("time", "baseline_id", "uvw_label"),
uvw_coordinates,
uvw,
{"type": "uvw", "units": "m", "frame": "fk5"},
{"preferred_chunks": uvw_preferred_chunks},
)
Expand Down
7 changes: 6 additions & 1 deletion src/xarray_kat/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from xarray_kat.datatree_factory import DataTreeFactory
from xarray_kat.katdal_types import TelstateDataProducts, TelstateDataSource
from xarray_kat.multiton import Multiton
from xarray_kat.xkat_types import VanVleckLiteralType
from xarray_kat.xkat_types import UvwSignConventionType, VanVleckLiteralType


class KatStore(AbstractDataStore):
Expand All @@ -33,6 +33,7 @@ class KatEntryPoint(BackendEntrypoint):
"scan_states",
"capture_block_id",
"stream_name",
"uvw_sign_convention",
"van_vleck",
]
description = "Opens a MeerKAT data source"
Expand Down Expand Up @@ -84,6 +85,7 @@ def open_datatree(
scan_states: Iterable[str] = ("scan", "track"),
capture_block_id: str | None = None,
stream_name: str | None = None,
uvw_sign_convention: UvwSignConventionType = "casa",
van_vleck: VanVleckLiteralType = "off",
):
group_dicts = self.open_groups_as_dict(
Expand All @@ -94,6 +96,7 @@ def open_datatree(
scan_states=scan_states,
capture_block_id=capture_block_id,
stream_name=stream_name,
uvw_sign_convention=uvw_sign_convention,
van_vleck=van_vleck,
)
return DataTree.from_dict(group_dicts)
Expand All @@ -108,6 +111,7 @@ def open_groups_as_dict(
scan_states: Iterable[str] = ("scan", "track"),
capture_block_id: str | None = None,
stream_name: str | None = None,
uvw_sign_convention: UvwSignConventionType = "casa",
van_vleck: VanVleckLiteralType = "off",
) -> Dict[str, Any]:
url = str(filename_or_obj)
Expand Down Expand Up @@ -143,6 +147,7 @@ def open_groups_as_dict(
telstate_data_products,
applycal,
scan_states,
uvw_sign_convention,
van_vleck,
endpoint,
token,
Expand Down
1 change: 1 addition & 0 deletions src/xarray_kat/xkat_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy.typing as npt

VanVleckLiteralType = Literal["off", "autocorr"]
UvwSignConventionType = Literal["fourier", "casa"]


@dataclass(eq=True, unsafe_hash=True, slots=True, repr=True)
Expand Down
15 changes: 11 additions & 4 deletions tests/test_katdal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import katdal
import numpy as np
import pytest
import xarray
from pytest_httpserver import HTTPServer

Expand All @@ -10,7 +11,10 @@


class TestKatdal:
def test_katdal_mock_server_basic(self, httpserver: HTTPServer, tmp_path):
@pytest.mark.parametrize("uvw_sign_convention", ["fourier", "casa"])
def test_katdal_mock_server_basic(
self, httpserver: HTTPServer, uvw_sign_convention, tmp_path
):
"""Tests that xarray-kat and katdal return the same data from the same datasource"""
obs = SyntheticObservation("1234567890", ntime=8, nfreq=16, nants=4)
obs.add_scan(range(0, 8), "track", "PKS1934")
Expand All @@ -24,7 +28,9 @@ def test_katdal_mock_server_basic(self, httpserver: HTTPServer, tmp_path):
rdb_url = f"{base_url}1234567890/1234567890_sdp_l0.full.rdb"

ds = katdal.open(rdb_url)
dt = xarray.open_datatree(rdb_url, engine="xarray-kat")
dt = xarray.open_datatree(
rdb_url, engine="xarray-kat", uvw_sign_convention=uvw_sign_convention
)

def reorder_katdal_data(data):
return (
Expand All @@ -48,6 +54,7 @@ def reorder_katdal_data(data):
np.testing.assert_allclose(xarray_kat_flags, katdal_flags)

xarray_kat_uvw = dt[children[0]].UVW.data
# Flip katdal sign convention to match CASA
katdal_uvw = np.stack([-ds.u, -ds.v, -ds.w], axis=2)[:, obs.corrprod_argsort]
katdal_uvw = np.stack([ds.u, ds.v, ds.w], axis=2)[:, obs.corrprod_argsort]
if uvw_sign_convention == "casa":
katdal_uvw = -katdal_uvw
np.testing.assert_allclose(xarray_kat_uvw, katdal_uvw[:, :: obs.npol])