Source code for geowombat.util.plotting
import logging
import matplotlib as mpl
import matplotlib.pyplot as plt
import xarray as xr
from ..handler import add_handler
# Module-level logger, wrapped by the package's ``add_handler`` helper.
logger = add_handler(logging.getLogger(__name__))
class Plotting(object):
    """Plotting helpers (used through the ``.gw`` accessor, e.g. ``ds.gw.imshow()``)."""

    @staticmethod
    def imshow(
        data,
        mask=False,
        nodata=0,
        flip=False,
        text_color='black',
        rot=30,
        **kwargs
    ):
        """Shows an image on a plot.

        Only 1-band and 3-band (RGB) arrays are supported; other band counts
        are logged as an error and then passed on to ``xarray`` unchanged.

        Note:
            This function updates the global ``matplotlib.pyplot.rcParams``
            (title size/pad, text, label and tick colors, ``figure.dpi``,
            ``savefig.bbox`` and ``savefig.pad_inches``), so these settings
            persist for figures created afterwards.

        Args:
            data (``xarray.DataArray`` or ``xarray.Dataset``): The data to plot.
            mask (Optional[bool]): Whether to mask 'no data' values (given by ``nodata``).
                For an ``xarray.Dataset``, pixels whose ``mask`` variable is
                not less than 3 are masked as well.
            nodata (Optional[int or float]): The 'no data' value.
            flip (Optional[bool]): Whether to flip an RGB array's band order.
                Only used for 3-band data.
            text_color (Optional[str]): The text color.
            rot (Optional[int]): The degree rotation for the x-axis tick labels.
            kwargs (Optional[dict]): Keyword arguments passed to ``xarray.plot.imshow``.
                Two keys are handled here: ``dpi`` sets ``rcParams['figure.dpi']``
                (default 150) and is NOT forwarded; ``ax`` is the axis to draw
                on (a new figure and axis are created when it is missing or ``None``).

        Returns:
            ``matplotlib`` axis object

        Examples:
            >>> import geowombat as gw
            >>>
            >>> # Open a 3-band image and plot the first band
            >>> with gw.open('image.tif') as ds:
            ...     ax = ds.sel(band=1).gw.imshow()
            >>>
            >>> # Open and plot a 3-band image
            >>> with gw.open('image.tif') as ds:
            ...     ax = ds.sel(band=['red', 'green', 'blue']).gw.imshow(
            ...         mask=True, nodata=0, vmin=0.1, vmax=0.9, robust=True
            ...     )
        """
        nbands = data.gw.nbands

        if nbands not in (1, 3):
            # ``logger.exception`` is meant for use inside an ``except`` block;
            # outside one it appends a meaningless "NoneType: None" traceback.
            logger.error(
                ' Only 1-band or 3-band arrays can be plotted.'
            )

        # ``dpi`` only configures rcParams; it must not be forwarded with the
        # remaining kwargs, which go to ``xarray.plot.imshow`` (and matplotlib).
        dpi = kwargs.pop('dpi', 150)

        plt.rcParams['axes.titlesize'] = 5
        plt.rcParams['axes.titlepad'] = 5
        plt.rcParams['text.color'] = text_color
        plt.rcParams['axes.labelcolor'] = text_color
        plt.rcParams['xtick.color'] = text_color
        plt.rcParams['ytick.color'] = text_color
        plt.rcParams['figure.dpi'] = dpi
        plt.rcParams['savefig.bbox'] = 'tight'
        plt.rcParams['savefig.pad_inches'] = 0.5

        # Create the axis after rcParams are set so the new figure picks up
        # the requested DPI.
        ax = kwargs.get('ax')
        if ax is None:
            _, ax = plt.subplots()
            kwargs['ax'] = ax

        if mask:
            # 1-band data is compared directly; multi-band data is masked where
            # the maximum over axis 0 (assumed to be the band axis) is nodata.
            if nbands == 1:
                valid = data != nodata
            else:
                valid = data.max(axis=0) != nodata
            if isinstance(data, xr.Dataset):
                valid = (data['mask'] < 3) & valid
            plot_data = data.where(valid)
        else:
            plot_data = data

        if nbands == 3:
            # xarray's RGB imshow expects the color dimension last.
            plot_data = plot_data.transpose('y', 'x', 'band')
            if flip:
                plot_data = plot_data[..., ::-1]
            plot_data.plot.imshow(rgb='band', **kwargs)
        else:
            plot_data.squeeze().plot.imshow(**kwargs)

        # Integer tick labels with thousands separators.
        ax.xaxis.set_major_formatter(mpl.ticker.StrMethodFormatter('{x:,.0f}'))
        ax.yaxis.set_major_formatter(mpl.ticker.StrMethodFormatter('{x:,.0f}'))
        for tick in ax.get_xticklabels():
            tick.set_rotation(rot)

        return ax