Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions gridformat/common/field.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,18 +167,26 @@ class Field {
void _export_to(R& range,
std::span<const T> data,
std::size_t& offset) const {
if constexpr (mdrange_dimension<R> > 1)
if constexpr (mdrange_dimension<R> > 1) {
std::ranges::for_each(range, [&] (std::ranges::range auto& sub_range) {
_export_to(sub_range, data, offset);
});
else
std::ranges::for_each_n(
std::ranges::begin(range),
std::min(Ranges::size(range), data.size() - offset),
[&] (std::ranges::range_reference_t<R> value) {
value = static_cast<std::ranges::range_value_t<R>>(data[offset++]);
}
} else {
// Note: std::vector<bool> breaks the use of `std::ranges::copy` or similar,
// and seems to only work with recent compilers and/or c++23. Therefore, we
// use `std::copy` here, which actually may break for some range iterators?
auto converted_data = data | std::views::drop(offset)
| std::views::take(std::min(Ranges::size(range), data.size() - offset))
| std::views::transform([] (const T& value) {
return static_cast<std::ranges::range_value_t<R>>(value);
});
std::copy(
std::ranges::begin(converted_data),
std::ranges::end(converted_data),
std::ranges::begin(range)
);
offset += Ranges::size(converted_data);
}
}
};

Expand Down
43 changes: 33 additions & 10 deletions test/test_vtk_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _get_grid_and_space_dimension(filename: str) -> Tuple[int, int]:
return int(dim), int(space_dim)


def _check_vtk_file(vtk_reader,
def _check_vtk_file(vtk_grid,
points,
space_dim,
reference_function: Callable[[list], float],
Expand All @@ -72,9 +72,8 @@ def _check_vtk_file(vtk_reader,
rel_tol = 1e-5
abs_tol = 1e-3

output = vtk_reader.GetOutput()
if not skip_metadata:
field_data = output.GetFieldData()
field_data = vtk_grid.GetFieldData()
expected_field_data = ["literal", "string", "numbers"]
for i in range(field_data.GetNumberOfArrays()):
name = field_data.GetAbstractArray(i).GetName()
Expand All @@ -92,12 +91,12 @@ def _check_vtk_file(vtk_reader,
reference_function.set_time(time_value)

# precompute cell centers
num_cells = output.GetNumberOfCells()
num_cells = vtk_grid.GetNumberOfCells()
points = array(points)
cell_centers = ndarray(shape=(num_cells, 3))
for cell_id in range(num_cells):
ids = vtk.vtkIdList()
output.GetCellPoints(cell_id, ids)
vtk_grid.GetCellPoints(cell_id, ids)
corner_indices = [ids.GetId(_i) for _i in range(ids.GetNumberOfIds())]
cell_centers[cell_id] = np_sum(points[corner_indices], axis=0)
cell_centers[cell_id] /= float(ids.GetNumberOfIds())
Expand All @@ -115,7 +114,7 @@ def _compare_data_array(arr, position_call_back):
else:
assert isclose(0.0, value[comp], rel_tol=rel_tol, abs_tol=abs_tol)

point_data = output.GetPointData()
point_data = vtk_grid.GetPointData()
for i in range(point_data.GetNumberOfArrays()):
name = point_data.GetArrayName(i)
arr = point_data.GetArray(i)
Expand All @@ -125,15 +124,15 @@ def _compare_data_array(arr, position_call_back):
point_id = 0
for cell_id in range(num_cells):
ids = vtk.vtkIdList()
output.GetCellPoints(cell_id, ids)
vtk_grid.GetCellPoints(cell_id, ids)
for _ in range(ids.GetNumberOfIds()):
assert isclose(arr.GetTuple(point_id)[0], float(cell_id))
point_id += 1
else:
for i in range(arr.GetNumberOfTuples()):
_compare_data_array(arr, lambda i: points[i])

cell_data = output.GetCellData()
cell_data = vtk_grid.GetCellData()
for i in range(cell_data.GetNumberOfArrays()):
name = cell_data.GetArrayName(i)
arr = cell_data.GetArray(i)
Expand All @@ -154,16 +153,39 @@ def _read_pvd_pieces(filename: str) -> List[_TimeStep]:


def _test_vtk(filename: str, skip_metadata: bool, reference_function: Callable[[list], float]):
def _grid(reader):
return reader.GetOutput()

def _merged_partitioned_grid(reader):
output = reader.GetOutput()
merged = vtk.vtkAppendFilter()
iter = output.NewIterator()
iter.InitTraversal()
while not iter.IsDoneWithTraversal():
dataset = iter.GetCurrentDataObject()
if dataset:
merged.AddInputData(dataset)
iter.GoToNextItem()
merged.Update()
merged_grid = merged.GetOutput()
merged_grid.GetFieldData().PassData(reader.GetOutput().GetFieldData())
return merged_grid

def _get_points_from_grid(reader):
points = reader.GetOutput().GetPoints()
return array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())])

def _get_points_from_partioned_grid(reader):
points = _merged_partitioned_grid(reader).GetPoints()
return array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())])

def _get_rectilinear_points(reader):
output = reader.GetOutput()
return array([output.GetPoint(i) for i in range(output.GetNumberOfPoints())])

e = VTKErrorObserver()
ext = splitext(filename)[1]
get_grid = _grid
if ext == ".vtu":
reader = vtk.vtkXMLUnstructuredGridReader()
point_collector = _get_points_from_grid
Expand Down Expand Up @@ -199,7 +221,8 @@ def _get_rectilinear_points(reader):
point_collector = _get_rectilinear_points
elif ext == ".hdf" and "unstructured" in filename:
reader = vtk.vtkHDFReader()
point_collector = _get_points_from_grid
point_collector = _get_points_from_partioned_grid
get_grid = _merged_partitioned_grid
else:
raise NotImplementedError(f"Could not determine suitable reader {filename}")
reader.AddObserver("ErrorEvent", e)
Expand All @@ -210,7 +233,7 @@ def _get_rectilinear_points(reader):

_, space_dim = _get_grid_and_space_dimension(filename)
is_discontinuous = "discontinuous" in filename
_check_vtk_file(reader, point_collector(reader), space_dim, reference_function, skip_metadata, is_discontinuous)
_check_vtk_file(get_grid(reader), point_collector(reader), space_dim, reference_function, skip_metadata, is_discontinuous)


def test(filename: str, skip_metadata: bool = False) -> int | None:
Expand Down
Loading