diff --git a/xarray_sql/df.py b/xarray_sql/df.py
index a249d10..ef9e4c6 100644
--- a/xarray_sql/df.py
+++ b/xarray_sql/df.py
@@ -183,18 +183,15 @@ def read_xarray(ds: xr.Dataset, chunks: Chunks = None) -> pa.RecordBatchReader:
   Returns:
     A PyArrow Table, which is a table representation of the input Dataset.
   """
-  fst = next(iter(ds.values())).dims
-  assert all(
-      da.dims == fst for da in ds.values()
-  ), "All dimensions must be equal. Please filter data_vars in the Dataset."
-
-  blocks = list(block_slices(ds, chunks))
 
   def pivot_block(b: Block):
     return pivot(ds.isel(b))
 
-  schema = pa.Schema.from_pandas(pivot_block(blocks[0]))
-  last_schema = pa.Schema.from_pandas(pivot_block(blocks[-1]))
-  assert schema == last_schema, "Schemas must be consistent across blocks!"
+  fst = next(iter(ds.values())).dims
+  assert all(
+      da.dims == fst for da in ds.values()
+  ), "All dimensions must be equal. Please filter data_vars in the Dataset."
+  schema = _parse_schema(ds)
+  blocks = block_slices(ds, chunks)
 
   return from_map_batched(pivot_block, blocks, schema=schema)