fit_predict

geowombat.ml.fit_predict(data, clf, labels=None, col=None, targ_name='targ', targ_dim_name='sample', mask_nodataval=True)

Fits a classifier (using class labels, if provided) and predicts on a DataArray.

Parameters:
  • data (DataArray) – The data to predict on.

  • clf (object) – The classifier or classification pipeline.

  • labels (Optional[str | Path | GeoDataFrame]) – Class labels as polygon geometry.

  • col (Optional[str]) – The column in labels to take class values from. If None, a binary raster is created.

  • targ_name (Optional[str]) – The target name. Predictions are returned under this band name (e.g., band='targ' with the default).

  • targ_dim_name (Optional[str]) – The target coordinate name.

  • mask_nodataval (Optional[bool]) – If True, values equal to data.attrs["nodatavals"][0] are replaced with np.nan and the array is returned as type float.
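The mask_nodataval behaviour can be illustrated on a plain NumPy array. This is a minimal sketch, not geowombat's implementation: the helper mask_nodata is hypothetical, and the nodata value is passed in directly rather than read from data.attrs["nodatavals"][0]. The cast to float matters because integer arrays cannot hold NaN.

```python
import numpy as np


def mask_nodata(values, nodataval):
    """Hypothetical helper mirroring the documented mask_nodataval=True behaviour."""
    arr = np.asarray(values)
    # Locate nodata cells on the original values, before any type change
    is_nodata = arr == nodataval
    # Cast to float so NaN can be stored (integer dtypes have no NaN)
    out = arr.astype("float64")
    out[is_nodata] = np.nan
    return out


band = np.array([[0, 120], [85, 0]], dtype="uint16")
masked = mask_nodata(band, nodataval=0)
print(masked.dtype)                # float64
print(int(np.isnan(masked).sum())) # 2
```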

Returns:

Predictions shaped ('time' x 'band' x 'y' x 'x').

Return type:

xarray.DataArray
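One common way to get a ('time' x 'band' x 'y' x 'x') result from a scikit-learn style classifier is to flatten the pixels into a 2-D (samples x bands) table, predict one label per sample, then reshape back. The sketch below shows that pattern with NumPy only. It is an assumption about the general approach, not geowombat's internal code: predict_stack is a hypothetical helper, and the threshold function is a hypothetical stand-in for clf.predict.

```python
import numpy as np


def predict_stack(cube, predict):
    """Hypothetical sketch: classify a (time, band, y, x) cube pixel by pixel."""
    ntime, nband, ny, nx = cube.shape
    # Move bands last, then flatten time/y/x into one sample axis -> (samples, bands)
    samples = cube.transpose(0, 2, 3, 1).reshape(-1, nband)
    # One prediction per sample (a pixel at one time step)
    labels = predict(samples)
    # Restore the layout with a single output band -> (time, 1, y, x)
    return labels.reshape(ntime, ny, nx)[:, np.newaxis, :, :]


def threshold(X):
    """Hypothetical stand-in for clf.predict: class 1 where band 0 exceeds 0.5."""
    return (X[:, 0] > 0.5).astype("int64")


rng = np.random.default_rng(0)
cube = rng.random((2, 3, 4, 5))
pred = predict_stack(cube, threshold)
print(pred.shape)  # (2, 1, 4, 5)
```

Because the flatten and the reshape use the same (time, y, x) ordering, each prediction lands back on the pixel it came from.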

Example

>>> import geowombat as gw
>>> from geowombat.data import l8_224078_20200518, l8_224078_20200518_polygons
>>> from geowombat.ml import fit_predict
>>>
>>> import geopandas as gpd
>>> from sklearn_xarray.preprocessing import Featurizer
>>> from sklearn.pipeline import Pipeline
>>> from sklearn.preprocessing import StandardScaler, LabelEncoder
>>> from sklearn.decomposition import PCA
>>> from sklearn.naive_bayes import GaussianNB
>>> from sklearn.cluster import KMeans
>>>
>>> le = LabelEncoder()
>>>
>>> labels = gpd.read_file(l8_224078_20200518_polygons)
>>> labels['lc'] = le.fit_transform(labels['name'])
>>>
>>> # Use a supervised classification pipeline
>>> pl = Pipeline([('scaler', StandardScaler()),
...                ('pca', PCA()),
...                ('clf', GaussianNB())])
>>>
>>> with gw.open(l8_224078_20200518, nodata=0) as src:
...     y = fit_predict(src, pl, labels, col='lc')
...     y.isel(time=0).sel(band='targ').gw.imshow()
>>>
>>> with gw.open([l8_224078_20200518, l8_224078_20200518], nodata=0) as src:
...     y = fit_predict(src, pl, labels, col='lc')
...     y.isel(time=1).sel(band='targ').gw.imshow()
>>>
>>> # Use an unsupervised classification pipeline
>>> cl = Pipeline([('pca', PCA()),
...                ('cst', KMeans())])
>>> with gw.open(l8_224078_20200518, nodata=0) as src:
...     y2 = fit_predict(src, cl)
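The LabelEncoder step in the example converts the polygons' class names into integer codes for the 'lc' column. scikit-learn's LabelEncoder stores the sorted distinct labels in classes_ and encodes each label as its index in that list, so the same mapping can be reproduced with NumPy's np.unique. The class names below are hypothetical stand-ins for the contents of the name column.

```python
import numpy as np

# Hypothetical class names standing in for the polygons' `name` column
names = np.array(["water", "tree", "crop", "tree", "water"])

# np.unique sorts the distinct names; return_inverse gives each name's index
# into that sorted list, i.e. integer codes 0..n_classes-1
classes, codes = np.unique(names, return_inverse=True)
print(classes)  # ['crop' 'tree' 'water']
print(codes)    # [2 1 0 1 2]
```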