fit

geowombat.ml.fit(data, clf, labels=None, col=None, targ_name='targ', targ_dim_name='sample', temporal_mode='panel')

Fits a classifier given class labels.

Parameters:
  • data (DataArray) – The data to fit the classifier on.

  • clf (object) – The classifier or classification pipeline.

  • labels (Optional[str | Path | GeoDataFrame]) – Class labels as polygon geometry, given as a file path or a GeoDataFrame. Leave as None for unsupervised classifiers.

  • col (Optional[str]) – The column in labels you want to assign values from. If None, creates a binary raster.

  • targ_name (Optional[str]) – The target name.

  • targ_dim_name (Optional[str]) – The target coordinate name.

  • temporal_mode (Optional[str]) – How to handle data with a time dimension, where B is the number of bands and T is the number of time steps. With 'panel', each pixel-time observation is an independent sample with B features, and the output keeps the time dimension with one prediction per time step. With 'flatten', time is flattened into the band dimension, giving T*B features per pixel, and the output has no time dimension and one prediction per pixel. Ignored when data has no time dimension.
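
The difference between the two temporal_mode options is easiest to see in the shape of the sample-by-feature table that reaches the classifier. The sketch below uses plain NumPy, not geowombat, and assumes a (time, band, y, x) dimension order; the array sizes and variable names are hypothetical, and the reshaping is only an illustration of the two layouts, not the library's internal implementation.

```python
import numpy as np

# Hypothetical stack: T time steps, B bands, H rows, W columns
# (the (time, band, y, x) order is an assumption for this sketch).
T, B, H, W = 3, 4, 2, 5
stack = np.arange(T * B * H * W, dtype=float).reshape(T, B, H, W)

# 'panel' layout: every (time, pixel) pair is its own sample with B features.
panel = stack.transpose(0, 2, 3, 1).reshape(T * H * W, B)

# 'flatten' layout: every pixel is one sample with T*B features.
flatten = stack.transpose(2, 3, 0, 1).reshape(H * W, T * B)

print(panel.shape)    # (30, 4)
print(flatten.shape)  # (10, 12)
```

In the panel layout a classifier would return T*H*W predictions, which can be folded back into a (time, y, x) array. In the flatten layout it would return H*W predictions, one per pixel.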

Returns:

Tuple of (X, Xna, clf), where X is the original xarray.DataArray augmented to accept a prediction dimension; Xna is a tuple (xarray.DataArray, sklearn_xarray.Target) of the reshaped feature data (with NAs removed for supervised classifiers, retained for unsupervised) and the target array (None for unsupervised); and clf is the fitted sklearn pipeline.
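
As a rough illustration of what "NAs removed for supervised classifiers" means for the reshaped feature data, the sketch below drops incomplete samples from a small table with plain NumPy. The arrays, the variable names, and the use of NaN to mark samples without a label are all assumptions made for this example; it does not reproduce how geowombat builds Xna.

```python
import numpy as np

# Hypothetical feature table: 5 samples x 3 features, with missing values.
features = np.array([
    [0.1, 0.2, 0.3],
    [np.nan, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, np.nan, 1.2],
    [1.3, 1.4, 1.5],
])

# Hypothetical targets; NaN is assumed to mark samples with no class label.
targets = np.array([1.0, 2.0, np.nan, 1.0, 2.0])

# Keep a sample only when all of its features and its target are present.
keep = ~np.isnan(features).any(axis=1) & ~np.isnan(targets)
X_train, y_train = features[keep], targets[keep]

print(X_train.shape)  # (2, 3)
print(y_train)        # [1. 2.]
```

An unsupervised classifier has no target array to filter on, which is consistent with the target being None in that case.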

Example

>>> import geowombat as gw
>>> from geowombat.data import l8_224078_20200518, l8_224078_20200518_polygons
>>> from geowombat.ml import fit
>>>
>>> import geopandas as gpd
>>> from sklearn.pipeline import Pipeline
>>> from sklearn.preprocessing import StandardScaler, LabelEncoder
>>> from sklearn.decomposition import PCA
>>> from sklearn.naive_bayes import GaussianNB
>>> from sklearn.cluster import KMeans
>>>
>>> le = LabelEncoder()
>>>
>>> labels = gpd.read_file(l8_224078_20200518_polygons)
>>> labels['lc'] = le.fit(labels.name).transform(labels.name)
>>>
>>> # Use supervised classification pipeline
>>> pl = Pipeline([('scaler', StandardScaler()),
...                ('pca', PCA()),
...                ('clf', GaussianNB())])
>>>
>>> with gw.open(l8_224078_20200518) as src:
...     X, Xy, clf = fit(src, pl, labels, col='lc')
>>>
>>> # Fit an unsupervised classifier
>>> cl = Pipeline([('pca', PCA()),
...                ('cst', KMeans())])
>>>
>>> with gw.open(l8_224078_20200518) as src:
...     X, Xy, clf = fit(src, cl)