Machine learning

Fit a classifier

In [1]: import geowombat as gw

In [2]: from geowombat.data import l8_224078_20200518, l8_224078_20200518_polygons

In [3]: from geowombat.ml import fit, predict, fit_predict

In [4]: import geopandas as gpd

In [5]: from sklearn.pipeline import Pipeline

In [6]: from sklearn.preprocessing import LabelEncoder, StandardScaler

In [7]: from sklearn.decomposition import PCA

In [8]: from sklearn.naive_bayes import GaussianNB

In [9]: from sklearn.cluster import KMeans

In [10]: le = LabelEncoder()

# Read the labeled training polygons
In [11]: labels = gpd.read_file(l8_224078_20200518_polygons)

# The class labels are string names, so convert them to integer codes in a new 'lc' column
In [12]: labels['lc'] = le.fit_transform(labels.name)

# Use a data pipeline: standardize the features, project them with PCA,
# then classify with a Gaussian naive Bayes model
In [13]: pl = Pipeline([('scaler', StandardScaler()),
   ....:                 ('pca', PCA()),
   ....:                 ('clf', GaussianNB())])
   ....: 

# Fit the classifier on the labeled polygons, using the integer class column 'lc'.
# fit returns X, Xy, and the fitted pipeline clf
In [14]: with gw.config.update(ref_res=100):
   ....:     with gw.open(l8_224078_20200518, chunks=128) as src:
   ....:         X, Xy, clf = fit(src, pl, labels, col='lc')
   ....: 

# In the fitted pipeline, each step is wrapped in an EstimatorWrapper (reshapes='band')
In [15]: print(clf)
Pipeline(steps=[('scaler',
                 EstimatorWrapper(copy=True, estimator=StandardScaler(),
                                  reshapes='band', with_mean=True,
                                  with_std=True)),
                ('pca',
                 EstimatorWrapper(copy=True, estimator=PCA(),
                                  iterated_power='auto', n_components=None,
                                  n_oversamples=10,
                                  power_iteration_normalizer='auto',
                                  random_state=None, reshapes='band',
                                  svd_solver='auto', tol=0.0, whiten=False)),
                ('clf',
                 EstimatorWrapper(estimator=GaussianNB(), priors=None,
                                  reshapes='band', var_smoothing=1e-09))])

Fit a classifier and predict on an array

In [16]: from geowombat.ml import fit_predict

In [17]: import matplotlib.pyplot as plt

In [18]: fig, ax = plt.subplots(dpi=200)

In [19]: with gw.config.update(ref_res=100):
   ....:     with gw.open(l8_224078_20200518 ) as src:
   ....:         y = fit_predict(src, pl, labels, col='lc')
   ....:         y.plot(robust=True, ax=ax)
   ....:         print(y)
   ....: 
<xarray.DataArray (band: 1, y: 558, x: 612)> Size: 3MB
dask.array<xarray-<this-array>, shape=(1, 558, 612), dtype=float64, chunksize=(1, 256, 256), chunktype=numpy.ndarray>
Coordinates:
  * x        (x) float64 5kB 7.174e+05 7.175e+05 ... 7.784e+05 7.785e+05
  * y        (y) float64 4kB -2.777e+06 -2.777e+06 ... -2.833e+06 -2.833e+06
    time     <U2 8B 't1'
    targ     (y, x) uint8 341kB dask.array<chunksize=(256, 256), meta=np.ndarray>
  * band     (band) <U4 16B 'targ'
Attributes: (12/13)
    transform:           (100.0, 0.0, 717345.0, 0.0, -100.0, -2776995.0)
    crs:                 32621
    res:                 (100.0, 100.0)
    is_tiled:            1
    nodatavals:          (nan, nan, nan)
    _FillValue:          nan
    ...                  ...
    offsets:             (0.0, 0.0, 0.0)
    filename:            /home/docs/checkouts/readthedocs.org/user_builds/geo...
    resampling:          nearest
    AREA_OR_POINT:       Area
    _data_are_separate:  0
    _data_are_stacked:   0

Fit a classifier with multiple dates

# Stack two images along a 'time' dimension and fit/predict in one call.
# Here the same scene is opened twice to stand in for dates 't1' and 't2'.
In [20]: with gw.config.update(ref_res=100):
   ....:     with gw.open(
   ....:         [l8_224078_20200518, l8_224078_20200518],
   ....:         time_names=['t1', 't2'],
   ....:         stack_dim='time',
   ....:         chunks=128
   ....:     ) as src:
   ....:         y = fit_predict(src, pl, labels, col='lc')
   ....:         print(y)
   ....: 
<xarray.DataArray (time: 2, band: 1, y: 558, x: 612)> Size: 5MB
dask.array<xarray-<this-array>, shape=(2, 1, 558, 612), dtype=float64, chunksize=(2, 1, 128, 128), chunktype=numpy.ndarray>
Coordinates:
  * x        (x) float64 5kB 7.174e+05 7.175e+05 ... 7.784e+05 7.785e+05
  * y        (y) float64 4kB -2.777e+06 -2.777e+06 ... -2.833e+06 -2.833e+06
  * time     (time) object 16B 't1' 't2'
    targ     (time, y, x) uint8 683kB dask.array<chunksize=(2, 128, 128), meta=np.ndarray>
  * band     (band) <U4 16B 'targ'
Attributes: (12/13)
    transform:           (100.0, 0.0, 717345.0, 0.0, -100.0, -2776995.0)
    crs:                 32621
    res:                 (100.0, 100.0)
    is_tiled:            1
    nodatavals:          (nan, nan, nan)
    _FillValue:          nan
    ...                  ...
    offsets:             (0.0, 0.0, 0.0)
    filename:            ['LC08_L1TP_224078_20200518_20200518_01_RT.TIF', 'LC...
    resampling:          nearest
    AREA_OR_POINT:       Area
    _data_are_separate:  1
    _data_are_stacked:   1

Train a supervised classifier and predict

In [21]: fig, ax = plt.subplots(dpi=200,figsize=(5,5))

# Fit the classifier, then predict with the fitted pipeline as a separate step
In [22]: with gw.config.update(ref_res=100):
   ....:     with gw.open(l8_224078_20200518, chunks=128) as src:
   ....:         X, Xy, clf = fit(src, pl, labels, col="lc")
   ....:         y = predict(src, X, clf)
   ....:         y.plot(robust=True, ax=ax)
   ....: 

In [23]: plt.tight_layout(pad=1)

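Train an unsupervised classifier and predict

Unsupervised classifiers (for example, KMeans) can also be used in a pipeline. Because they need no training labels, `fit` and `fit_predict` are called without the `labels` and `col` arguments.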

# Same preprocessing steps, but cluster into 3 groups with KMeans
# (random_state is fixed so the cluster assignment is reproducible)
In [24]: cl = Pipeline([ ('scaler', StandardScaler()),
   ....:                 ('pca', PCA()),
   ....:                 ('clf', KMeans(n_clusters=3, random_state=0))])
   ....: 

In [25]: fig, ax = plt.subplots(dpi=200,figsize=(5,5))

# Fit and predict the unsupervised classifier in two steps (no labels are passed)
In [26]: with gw.config.update(ref_res=300):
   ....:     with gw.open(l8_224078_20200518) as src:
   ....:         X, Xy, clf = fit(src, cl)
   ....:         y = predict(src, X, clf)
   ....:         y.plot(robust=True, ax=ax)
   ....: 

In [27]: plt.tight_layout(pad=1)

In [28]: fig, ax = plt.subplots(dpi=200,figsize=(5,5))

# Fit and predict the unsupervised classifier in a single fit_predict call
In [29]: with gw.config.update(ref_res=300):
   ....:     with gw.open(l8_224078_20200518) as src:
   ....:         y = fit_predict(src, cl)
   ....:         y.plot(robust=True, ax=ax)
   ....: 

In [30]: plt.tight_layout(pad=1)

Predict with cross-validation and parameter tuning

Cross-validation and parameter tuning are supported by combining scikit-learn's `GridSearchCV` with the `CrossValidatorWrapper` from `sklearn_xarray`.

In [31]: from sklearn.model_selection import GridSearchCV, KFold

In [32]: from sklearn_xarray.model_selection import CrossValidatorWrapper

In [33]: cv = CrossValidatorWrapper(KFold())

In [34]: gridsearch = GridSearchCV(
   ....:     pl,
   ....:     cv=cv,
   ....:     scoring='balanced_accuracy',
   ....:     param_grid={"pca__n_components": [1, 2, 3]}
   ....: )
   ....: 

In [35]: fig, ax = plt.subplots(dpi=200,figsize=(5,5))

In [36]: with gw.config.update(ref_res=300):
   ....:     with gw.open(l8_224078_20200518) as src:
   ....:         X, Xy, clf = fit(src, pl, labels, col="lc")
   ....: 

        # fit cross valiation and parameter tuning
        # NOTE: must unpack * object Xy
In [37]: gridsearch.fit(*Xy)
Out[37]: 
GridSearchCV(cv=<sklearn_xarray.model_selection.CrossValidatorWrapper object at 0x7f09aa9c8c40>,
             estimator=Pipeline(steps=[('scaler', StandardScaler()),
                                       ('pca', PCA()), ('clf', GaussianNB())]),
             param_grid={'pca__n_components': [1, 2, 3]},
             scoring='balanced_accuracy')

In [38]: print(gridsearch.best_params_)
{'pca__n_components': 1}

In [39]: print(gridsearch.best_score_)
0.7333333333333333

        # get set tuned parameters
        # Note: predict(gridsearch.best_model_) not currently supported
In [40]: clf.set_params(**gridsearch.best_params_)
Out[40]: 
Pipeline(steps=[('scaler',
                 EstimatorWrapper(copy=True, estimator=StandardScaler(),
                                  reshapes='band', with_mean=True,
                                  with_std=True)),
                ('pca',
                 EstimatorWrapper(copy=True, estimator=PCA(),
                                  iterated_power='auto', n_components=1,
                                  n_oversamples=10,
                                  power_iteration_normalizer='auto',
                                  random_state=None, reshapes='band',
                                  svd_solver='auto', tol=0.0, whiten=False)),
                ('clf',
                 EstimatorWrapper(estimator=GaussianNB(), priors=None,
                                  reshapes='band', var_smoothing=1e-09))])

In [41]: y = predict(src, X, clf)

In [42]: y.plot(robust=True, ax=ax)
Out[42]: <matplotlib.collections.QuadMesh at 0x7f09aa9c8bb0>

In [43]: plt.tight_layout(pad=1)

Save the prediction output

# Write the prediction `y` to 'output.tif', replacing the file if it already exists
y.gw.save('output.tif', overwrite=True)