predict

geowombat.ml.predict(data, X, clf, targ_name='targ', targ_dim_name='sample', mask_nodataval=True, temporal_mode='panel')

Predicts on a DataArray using a fitted classifier.

Parameters:
  • data (DataArray) – The data to predict on.

  • X (DataArray) – Data array generated by geowombat.ml.fit.

  • clf (object) – The classifier or classification pipeline.

  • targ_name (Optional[str]) – The target name.

  • targ_dim_name (Optional[str]) – The target coordinate name.

  • mask_nodataval (Optional[bool]) – If True, values equal to data.attrs["nodatavals"][0] are replaced with np.nan and the array is returned as type float.

  • temporal_mode (Optional[str]) – How to handle time-dimensioned data. 'panel': each pixel-time is an independent sample. 'flatten': time is flattened into the band dimension. Must match the temporal_mode used in fit().
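
The difference between the two temporal_mode options can be illustrated with plain NumPy. This is only a sketch of the two sample layouts described above, not geowombat's internal code; the toy array, its dimension order, and the variable names are assumptions made for illustration.

```python
import numpy as np

# Hypothetical toy stack: 2 times, 3 bands, 4 rows, 5 columns.
n_time, n_band, n_y, n_x = 2, 3, 4, 5
stack = np.arange(n_time * n_band * n_y * n_x, dtype="float64")
stack = stack.reshape(n_time, n_band, n_y, n_x)

# 'panel'-style layout: one row per pixel-time, one column per band.
panel = stack.transpose(0, 2, 3, 1).reshape(n_time * n_y * n_x, n_band)

# 'flatten'-style layout: one row per pixel, with time folded into
# the feature axis (time-major, then band).
flat = stack.transpose(2, 3, 0, 1).reshape(n_y * n_x, n_time * n_band)

print(panel.shape)  # (40, 3)
print(flat.shape)   # (20, 6)
```

Because the two layouts give the classifier a different number of features (3 versus 6 here), a model fitted in one mode cannot sensibly predict in the other, which is consistent with the requirement that temporal_mode match the value used in fit().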

Returns:

Predictions shaped (‘time’ x ‘band’ x ‘y’ x ‘x’)

Return type:

xarray.DataArray
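
The following sketch shows why masking nodata produces a float result, using NumPy alone. It assumes a pixel counts as nodata when any input band equals the nodata value. That rule, the toy arrays, and the variable names are illustrative assumptions and may not match the library's exact behavior.

```python
import numpy as np

nodata = 0

# Hypothetical 2-band, 2 x 2 input; pixel (0, 0) is nodata in band 0.
bands = np.array([[[0, 5], [3, 4]],
                  [[7, 6], [2, 1]]])

# Hypothetical integer class predictions shaped (band, y, x).
pred = np.array([[[2, 1], [0, 1]]])

# Assumption: mask a pixel when any band equals the nodata value.
mask = np.any(bands == nodata, axis=0, keepdims=True)

# NaN cannot be stored in an integer array, so the result is float.
masked = np.where(mask, np.nan, pred)

print(masked.dtype)  # float64
```

Note that a predicted class of 0 at a valid pixel is kept as 0.0; only pixels flagged by the input mask become NaN.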

Example

>>> import geowombat as gw
>>> from geowombat.data import l8_224078_20200518, l8_224078_20200518_polygons
>>> from geowombat.ml import fit, predict
>>> import geopandas as gpd
>>> from sklearn.pipeline import Pipeline
>>> from sklearn.preprocessing import LabelEncoder, StandardScaler
>>> from sklearn.decomposition import PCA
>>> from sklearn.naive_bayes import GaussianNB
>>> from sklearn.cluster import KMeans
>>> le = LabelEncoder()
>>> labels = gpd.read_file(l8_224078_20200518_polygons)
>>> labels["lc"] = le.fit(labels.name).transform(labels.name)
>>> # Use a data pipeline
>>> pl = Pipeline([('scaler', StandardScaler()),
...                ('pca', PCA()),
...                ('clf', GaussianNB())])
>>> # Fit and predict the classifier
>>> with gw.config.update(ref_res=100):
...     with gw.open(l8_224078_20200518, nodata=0) as src:
...         X, Xy, clf = fit(src, pl, labels, col="lc")
...         y = predict(src, X, clf)
...         print(y)
>>> # Fit and predict an unsupervised classifier
>>> cl = Pipeline([('pca', PCA()),
...                ('cst', KMeans())])
>>> with gw.open(l8_224078_20200518) as src:
...     X, Xy, clf = fit(src, cl)
...     y1 = predict(src, X, clf)