"""
This module allows you to merge xarray Datasets/DataArrays
geospatially with the `rasterio.merge` module.
"""
from collections.abc import Sequence
from typing import Callable, Optional, Union
import numpy
from rasterio.crs import CRS
from rasterio.merge import merge as _rio_merge
from xarray import DataArray, Dataset, IndexVariable
from rioxarray.rioxarray import _get_nonspatial_coords, _make_coords
class RasterioDatasetDuck:
"""
This class is to provide the attributes and methods necessary
to make the :func:`rasterio.merge.merge` function think that
the :obj:`xarray.DataArray` is a :obj:`rasterio.io.DatasetReader`.
"""
# pylint: disable=too-many-instance-attributes
def __init__(self, xds: DataArray):
self._xds = xds
self.crs = xds.rio.crs
self.bounds = xds.rio.bounds(recalc=True)
self.count = int(xds.rio.count)
self.dtypes = [xds.dtype]
self.name = xds.name
self.nodatavals = [xds.rio.nodata]
res = xds.rio.resolution(recalc=True)
self.res = (abs(res[0]), abs(res[1]))
self.transform = xds.rio.transform(recalc=True)
# profile is only used for writing to a file.
# This never happens with rioxarray merge.
self.profile: dict = {}
def colormap(self, *args, **kwargs) -> None:
"""
colormap is only used for writing to a file.
This never happens with rioxarray merge.
"""
# pylint: disable=unused-argument
return None
def read(self, window, out_shape, *args, **kwargs) -> numpy.ma.MaskedArray:
# pylint: disable=unused-argument
"""
This method is meant to be used by the rasterio.merge.merge function.
"""
data_window = self._xds.rio.isel_window(window)
if data_window.shape != out_shape:
# in this section, the data is geographically the same
# however it is not the same dimensions as requested
# so need to resample to the requested shape
if len(out_shape) == 3:
_, out_height, out_width = out_shape
else:
out_height, out_width = out_shape
data_window = self._xds.rio.reproject(
self._xds.rio.crs,
transform=self.transform,
shape=(out_height, out_width),
)
nodata = self.nodatavals[0]
mask = False
fill_value = None
if nodata is not None and numpy.isnan(nodata):
mask = numpy.isnan(data_window)
elif nodata is not None:
mask = data_window == nodata
fill_value = nodata
# make sure the returned shape matches
# the expected shape. This can be the case
# when the xarray dataset was squeezed to 2D beforehand
if len(out_shape) == 3 and len(data_window.shape) == 2:
data_window = data_window.values.reshape((1, out_height, out_width))
return numpy.ma.array(
data_window, mask=mask, fill_value=fill_value, dtype=self.dtypes[0]
)
[docs]
def merge_arrays(
dataarrays: Sequence[DataArray],
*,
bounds: Optional[tuple] = None,
res: Optional[tuple] = None,
nodata: Optional[float] = None,
precision: Optional[float] = None,
method: Union[str, Callable, None] = None,
crs: Optional[CRS] = None,
parse_coordinates: bool = True,
) -> DataArray:
"""
Merge data arrays geospatially.
Uses :func:`rasterio.merge.merge`
.. versionadded:: 0.2 crs
Parameters
----------
dataarrays: list[xarray.DataArray]
List of multiple xarray.DataArray with all geo attributes.
The first one is assumed to have the same
CRS, dtype, and dimensions as the others in the array.
bounds: tuple, optional
Bounds of the output image (left, bottom, right, top).
If not set, bounds are determined from bounds of input DataArrays.
res: tuple, optional
Output resolution in units of coordinate reference system.
If not set, the resolution of the first DataArray is used.
If a single value is passed, output pixels will be square.
nodata: float, optional
nodata value to use in output file.
If not set, uses the nodata value in the first input DataArray.
precision: float, optional
Number of decimal points of precision when computing inverse transform.
method: str or callable, optional
See :func:`rasterio.merge.merge` for details.
crs: rasterio.crs.CRS, optional
Output CRS. If not set, the CRS of the first DataArray is used.
parse_coordinates: bool, optional
If False, it will disable loading spatial coordinates.
Returns
-------
:obj:`xarray.DataArray`:
The geospatially merged data.
"""
input_kwargs = {
"bounds": bounds,
"res": res,
"nodata": nodata,
"precision": precision,
"method": method,
}
if crs is None:
crs = dataarrays[0].rio.crs
if res is None:
res = tuple(abs(res_val) for res_val in dataarrays[0].rio.resolution())
# prepare the duck arrays
rioduckarrays = []
for dataarray in dataarrays:
da_res = tuple(abs(res_val) for res_val in dataarray.rio.resolution())
if da_res != res or dataarray.rio.crs != crs:
rioduckarrays.append(
RasterioDatasetDuck(
dataarray.rio.reproject(dst_crs=crs, resolution=res)
)
)
else:
rioduckarrays.append(RasterioDatasetDuck(dataarray))
# use rasterio to merge
merged_data, merged_transform = _rio_merge(
rioduckarrays,
**{key: val for key, val in input_kwargs.items() if val is not None},
)
# generate merged data array
representative_array = rioduckarrays[0]._xds
if parse_coordinates:
coords = _make_coords(
src_data_array=representative_array,
dst_affine=merged_transform,
dst_width=merged_data.shape[-1],
dst_height=merged_data.shape[-2],
)
if (
representative_array.rio.x_dim != "x"
and "x" in coords
and coords["x"].ndim == 1
):
coords[representative_array.rio.x_dim] = IndexVariable(
representative_array.rio.x_dim, coords.pop("x")
)
if (
representative_array.rio.y_dim != "y"
and "y" in coords
and coords["y"].ndim == 1
):
coords[representative_array.rio.y_dim] = IndexVariable(
representative_array.rio.y_dim, coords.pop("y")
)
else:
coords = _get_nonspatial_coords(representative_array)
# make sure the output merged data shape is 2D if the
# original data was 2D. this can happen if the
# xarray datasarray was squeezed.
if len(merged_data.shape) == 3 and len(representative_array.shape) == 2:
merged_data = merged_data.squeeze()
xda = DataArray(
name=representative_array.name,
data=merged_data,
coords=coords,
dims=tuple(representative_array.dims),
attrs=representative_array.attrs,
)
xda.encoding = representative_array.encoding.copy()
xda.rio.write_nodata(
nodata if nodata is not None else representative_array.rio.nodata, inplace=True
)
xda.rio.write_crs(
representative_array.rio.crs,
grid_mapping_name=representative_array.rio.grid_mapping,
inplace=True,
)
xda.rio.write_transform(
merged_transform,
grid_mapping_name=representative_array.rio.grid_mapping,
inplace=True,
)
return xda
[docs]
def merge_datasets(
datasets: Sequence[Dataset],
*,
bounds: Optional[tuple] = None,
res: Optional[tuple] = None,
nodata: Optional[float] = None,
precision: Optional[float] = None,
method: Union[str, Callable, None] = None,
crs: Optional[CRS] = None,
) -> Dataset:
"""
Merge datasets geospatially.
Uses :func:`rasterio.merge.merge`
.. versionadded:: 0.2 crs
Parameters
----------
datasets: list[xarray.Dataset]
List of multiple xarray.Dataset with all geo attributes.
The first one is assumed to have the same
CRS, dtype, dimensions, and data_vars as the others in the array.
bounds: tuple, optional
Bounds of the output image (left, bottom, right, top).
If not set, bounds are determined from bounds of input Dataset.
res: tuple, optional
Output resolution in units of coordinate reference system.
If not set, the resolution of the first Dataset is used.
If a single value is passed, output pixels will be square.
nodata: float, optional
nodata value to use in output file.
If not set, uses the nodata value in the first input Dataset.
precision: float, optional
Number of decimal points of precision when computing inverse transform.
method: str or callable, optional
See rasterio docs.
crs: rasterio.crs.CRS, optional
Output CRS. If not set, the CRS of the first DataArray is used.
Returns
-------
:obj:`xarray.Dataset`:
The geospatially merged data.
"""
representative_ds = datasets[0]
merged_data = {}
for iii, data_var in enumerate(representative_ds.data_vars):
merged_data[data_var] = merge_arrays(
[dataset[data_var] for dataset in datasets],
bounds=bounds,
res=res,
nodata=nodata,
precision=precision,
method=method,
crs=crs,
parse_coordinates=iii == 0,
)
data_var = list(representative_ds.data_vars)[0]
xds = Dataset(
merged_data,
attrs=representative_ds.attrs,
)
xds.rio.write_crs(merged_data[data_var].rio.crs, inplace=True)
return xds