Source code for geowombat.detect.detectors

"""Object detectors for geowombat.

Wraps Ultralytics YOLO and TorchGeo detection models behind a single
geospatial-aware API: tiled windowed inference with cross-tile NMS and
georeferenced ``GeoDataFrame`` outputs.

Requires: ``pip install geowombat[detect]`` (YOLO/TorchGeo)
          ``pip install geowombat[sam]`` (SAMRefiner)

Example
-------
>>> import geowombat as gw
>>> from geowombat.detect import YOLODetector
>>> det = YOLODetector(weights='yolov8n.pt')
>>> with gw.open('aerial.tif') as src:
...     boxes = det.predict(src, tile_size=640, conf=0.25,
...                         band_indices=[2, 1, 0], scale=(0, 3000))
>>> boxes.to_file('detections.gpkg')
"""

import warnings
from pathlib import Path

import geopandas as gpd
import numpy as np
from shapely.geometry import Polygon, box as shapely_box

from ..ml._labels import resolve_band_indices
from ._tiling import overlapped_windows
from .data import _prepare_rgb_tile


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

def _pixel_box_to_polygon(affine, tile_x0, tile_y0, x1, y1, x2, y2):
    """Map an axis-aligned pixel box on a tile to a box polygon in map space.

    ``(x1, y1)`` and ``(x2, y2)`` are pixel corners relative to the tile's
    upper-left pixel ``(tile_x0, tile_y0)`` in the full raster. Both corners
    are pushed through ``affine``. The map-space extent is then ordered so
    that min <= max on each axis, because the affine may flip an axis
    (a negative y pixel size is typical for north-up rasters).
    """
    first = affine * (tile_x0 + x1, tile_y0 + y1)
    second = affine * (tile_x0 + x2, tile_y0 + y2)
    # Order each axis independently; which corner is "min" depends on the
    # sign of the affine's pixel size along that axis.
    left, right = sorted((first[0], second[0]))
    bottom, top = sorted((first[1], second[1]))
    return shapely_box(left, bottom, right, top)


def _pixel_quad_to_polygon(affine, tile_x0, tile_y0, quad_xy):
    """Map a pixel-space quadrilateral on a tile to a polygon in map space.

    ``quad_xy`` is an iterable of ``(x, y)`` pixel corners relative to the
    tile origin ``(tile_x0, tile_y0)`` (the oriented-box path passes a
    (4, 2) array). Each corner is offset into full-raster pixel space and
    transformed with ``affine``. The corners are kept in their input order.
    """
    # Offset every corner from tile-relative to raster-relative pixels.
    raster_px = ((tile_x0 + qx, tile_y0 + qy) for qx, qy in quad_xy)
    # Transform each corner and unpack it into a plain (x, y) tuple.
    ring = [(mx, my) for mx, my in (affine * pt for pt in raster_px)]
    return Polygon(ring)


def _iou_geom(a, b):
    """Return the intersection-over-union of two geometries.

    Both geometries are assumed to share a CRS. The result is ``0.0``
    when either geometry is empty, when they do not overlap with positive
    area, or when the union area is not positive.
    """
    if any(g.is_empty for g in (a, b)):
        return 0.0
    overlap = a.intersection(b).area
    if overlap <= 0:
        return 0.0
    union_area = a.area + b.area - overlap
    if union_area > 0:
        return overlap / union_area
    return 0.0


def _nms_geodataframe(gdf, iou_threshold=0.5):
    """Greedy, class-aware non-maximum suppression over a GeoDataFrame.

    Within each ``class_id`` group, rows are visited from highest to lowest
    ``score``. A row is dropped when its geometry's IoU with any
    already-kept geometry of the same class exceeds ``iou_threshold``.

    Expects columns: ``geometry``, ``score``, ``class_id``.

    Parameters
    ----------
    gdf : geopandas.GeoDataFrame
        Candidate detections. The index may be non-unique; selection is
        done positionally.
    iou_threshold : float
        Suppression threshold. Default 0.5.

    Returns
    -------
    geopandas.GeoDataFrame
        Kept rows, sorted by descending score, with a fresh RangeIndex.
        An empty input is returned unchanged.
    """
    if gdf.empty:
        return gdf
    # Select on a fresh positional index. With the caller's labels, a
    # duplicated label would make ``.loc`` return every row sharing it,
    # resurrecting suppressed detections.
    work = gdf.reset_index(drop=True)
    keep_pos = []
    for _, sub in work.groupby('class_id', sort=False):
        # A stable sort keeps tie-breaking between equal scores deterministic.
        sub = sub.sort_values('score', ascending=False, kind='stable')
        kept_geoms = []
        for pos, geom in zip(sub.index, sub['geometry']):
            if any(_iou_geom(geom, kg) > iou_threshold for kg in kept_geoms):
                continue
            keep_pos.append(pos)
            kept_geoms.append(geom)
    return work.loc[keep_pos].sort_values(
        'score', ascending=False, kind='stable',
    ).reset_index(drop=True)


def _resolve_device(device):
    """Turn a ``device`` argument into the device string handed to torch.

    Any value other than ``'auto'`` is returned untouched. For ``'auto'``,
    ``'cuda'`` is chosen when ``torch.cuda.is_available()``; otherwise
    ``'cpu'`` is returned. On that CPU path, if ``nvidia-smi`` is on PATH
    and reports at least one GPU name, a ``UserWarning`` suggests installing
    a CUDA-enabled torch build. That combination usually means a CPU-only
    torch wheel is installed.

    Raises
    ------
    ImportError
        If PyTorch cannot be imported. This is checked before ``device``
        is inspected, so it applies to every value.
    """
    try:
        import torch
    except ImportError as exc:
        raise ImportError(
            "Detectors require PyTorch. Install with "
            "`pip install geowombat[detect]`."
        ) from exc

    if device != 'auto':
        return device
    if torch.cuda.is_available():
        return 'cuda'

    import shutil
    import subprocess

    def _first_visible_gpu():
        # Best-effort probe: any failure simply means "nothing to report".
        if shutil.which('nvidia-smi') is None:
            return None
        try:
            proc = subprocess.run(
                ['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'],
                capture_output=True, text=True, timeout=2,
            )
        except (subprocess.SubprocessError, OSError):
            return None
        if proc.returncode != 0:
            return None
        listing = proc.stdout.strip()
        return listing.splitlines()[0] if listing else None

    gpu_name = _first_visible_gpu()
    if gpu_name is not None:
        warnings.warn(
            f"geowombat: NVIDIA GPU detected ({gpu_name}) but PyTorch "
            f"can't use it — falling back to CPU. You likely have a "
            f"CPU-only torch build. Install a CUDA build, e.g.:\n"
            f"  pip install torch --index-url "
            f"https://download.pytorch.org/whl/cu124\n"
            f"(see https://pytorch.org/get-started/locally/ for your "
            f"CUDA version).",
            UserWarning, stacklevel=2,
        )
    return 'cpu'


# ---------------------------------------------------------------------------
# Base class
# ---------------------------------------------------------------------------

class GeoWombatDetector:
    """Base class for geowombat object detectors.

    Subclasses must implement
    ``_detect_tile(rgb_array, conf=0.25, max_det=None)``. It returns a
    list of ``(x1, y1, x2, y2, score, class_id)`` tuples (axis-aligned)
    or ``(quad_xy_array, score, class_id)`` for oriented boxes. Pixel
    coordinates are relative to the tile's upper-left corner. Subclasses
    should also set ``class_names`` (list[str], indexed by class id) and
    ``oriented`` (bool).

    The default ``_detect_batch`` forwards ``conf`` and ``max_det`` to
    ``_detect_tile`` as keyword arguments, so ``_detect_tile`` must accept
    them.

    Subclasses can optionally override ``_detect_batch`` to run a list
    of tiles through the underlying model in one call (much faster on
    GPU, modestly faster on CPU). The default falls back to a per-tile
    loop.
    """
    class_names = None
    oriented = False

    def _detect_tile(self, rgb_array, conf=0.25, max_det=None):
        """Detect objects on one tile. Must be overridden by subclasses."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement _detect_tile()"
        )

    def _detect_batch(self, rgb_list, conf=0.25, max_det=None):
        """Default: per-tile loop. Override for true batched inference."""
        return [
            self._detect_tile(rgb, conf=conf, max_det=max_det)
            for rgb in rgb_list
        ]

    def _lookup_class_name(self, cls_id):
        """Return the display name for ``cls_id``.

        Falls back to the stringified id when ``class_names`` is unset or
        the id is outside ``[0, len(class_names))``. Negative ids are not
        wrapped to the end of the list.
        """
        idx = int(cls_id)
        if self.class_names and 0 <= idx < len(self.class_names):
            return self.class_names[idx]
        return str(idx)

    def predict(
        self,
        src,
        tile_size=640,
        overlap=0.2,
        conf=0.25,
        band_indices=None,
        scale=None,
        nms_iou=0.5,
        max_det=None,
        batch_size=4,
        progress=False,
    ):
        """Run tiled, georeferenced inference over a raster.

        Edge tiles smaller than ``tile_size`` are zero-padded to a full
        tile. Detections whose centre lies in the padding are discarded.
        Detections from all tiles are merged with class-aware greedy NMS.

        Parameters
        ----------
        src : xarray.DataArray
            Raster opened with ``gw.open()``.
        tile_size : int
            Square tile edge in pixels. Default 640.
        overlap : float
            Fractional overlap between adjacent tiles. Default 0.2.
        conf : float
            Confidence threshold for tile-level detections. Default 0.25.
        band_indices : list of int, optional
            Three band indices for R, G, B. See ``build_yolo_dataset``.
        scale : tuple of (lo, hi), optional
            Linear stretch before inference.
        nms_iou : float
            IoU threshold for cross-tile NMS. Default 0.5.
        max_det : int, optional
            Cap on detections per tile (passed to the underlying model
            when supported). Default unlimited.
        batch_size : int
            Number of tiles sent through the model per inference call.
            Larger batches keep the GPU busy and reduce per-tile Python
            overhead. On CPU, modest batching also helps. Default 4.
        progress : bool
            Show a tqdm progress bar (silently skipped if tqdm is not
            installed). Default False.

        Returns
        -------
        geopandas.GeoDataFrame
            One row per detection with columns ``geometry, class_id,
            class_name, score, tile_id`` in the source CRS.
        """
        affine = src.gw.affine
        crs = src.gw.crs_to_pyproj
        band_indices = resolve_band_indices(src, band_indices)

        tiles = list(overlapped_windows(src, tile_size, overlap))
        iterator = tiles
        if progress:
            try:
                from tqdm import tqdm
                iterator = tqdm(tiles, desc=self.__class__.__name__)
            except ImportError:
                pass

        records = []
        # Per-tile metadata buffered alongside the RGB array, so a flushed
        # batch knows where each tile's detections belong in the source CRS.
        batch_rgbs = []
        batch_meta = []  # list of (tile_id, x0, y0, w_block, h_block)

        def _flush(rgbs, meta):
            # Run one model call over the buffered tiles and convert each
            # detection to a georeferenced record.
            if not rgbs:
                return
            batch_dets = self._detect_batch(
                rgbs, conf=conf, max_det=max_det,
            )
            for (tile_id, x0, y0, w_block, h_block), dets in zip(
                meta, batch_dets,
            ):
                for det in dets:
                    if self.oriented:
                        quad, score, cls_id = det
                        # Reject detections whose center falls in padded
                        # region (outside the true image footprint).
                        cx = float(np.mean(quad[:, 0]))
                        cy = float(np.mean(quad[:, 1]))
                        if cx > w_block or cy > h_block:
                            continue
                        geom = _pixel_quad_to_polygon(
                            affine, x0, y0, quad,
                        )
                    else:
                        bx1, by1, bx2, by2, score, cls_id = det
                        cx = (bx1 + bx2) / 2.0
                        cy = (by1 + by2) / 2.0
                        if cx > w_block or cy > h_block:
                            continue
                        geom = _pixel_box_to_polygon(
                            affine, x0, y0, bx1, by1, bx2, by2,
                        )

                    records.append({
                        'geometry': geom,
                        'class_id': int(cls_id),
                        'class_name': self._lookup_class_name(cls_id),
                        'score': float(score),
                        'tile_id': int(tile_id),
                    })

        for tile_id, (r, c, win) in enumerate(iterator):
            y0 = win.row_off
            x0 = win.col_off
            y1 = y0 + win.height
            x1 = x0 + win.width
            block = src.isel(
                y=slice(y0, y1), x=slice(x0, x1),
            ).values
            # A 4-D read carries a leading time axis; only the first slice
            # is used.
            if block.ndim == 4:
                block = block[0]

            h_block, w_block = block.shape[1], block.shape[2]
            pad_h = tile_size - h_block
            pad_w = tile_size - w_block
            if pad_h > 0 or pad_w > 0:
                # Zero-pad edge tiles so every model input has the same size.
                padded = np.zeros(
                    (block.shape[0], tile_size, tile_size),
                    dtype=block.dtype,
                )
                padded[:, :h_block, :w_block] = block
                block = padded

            rgb = _prepare_rgb_tile(block, band_indices, scale)

            batch_rgbs.append(rgb)
            batch_meta.append((tile_id, x0, y0, w_block, h_block))

            if len(batch_rgbs) >= batch_size:
                _flush(batch_rgbs, batch_meta)
                batch_rgbs, batch_meta = [], []

        # Process whatever partial batch remains after the last tile.
        _flush(batch_rgbs, batch_meta)

        if not records:
            return gpd.GeoDataFrame(
                {
                    'geometry': gpd.GeoSeries([], crs=crs),
                    'class_id': [],
                    'class_name': [],
                    'score': [],
                    'tile_id': [],
                },
                geometry='geometry', crs=crs,
            )
        gdf = gpd.GeoDataFrame(records, geometry='geometry', crs=crs)
        gdf = _nms_geodataframe(gdf, iou_threshold=nms_iou)
        return gdf


# ---------------------------------------------------------------------------
# YOLO (Ultralytics)
# ---------------------------------------------------------------------------

class YOLODetector(GeoWombatDetector):
    """Ultralytics YOLO detector for georeferenced rasters.

    Supports axis-aligned (default) and oriented bounding boxes. The
    underlying model is any path accepted by ``ultralytics.YOLO`` —
    pretrained weights ('yolov8n.pt', 'yolo11n.pt', 'yolov8n-obb.pt') or
    a custom-trained checkpoint.

    NOTE: Ultralytics is licensed AGPL-3.0; ensure your use case is
    compatible before deploying.

    Parameters
    ----------
    weights : str or Path
        YOLO weights file. Default 'yolov8n.pt'.
    classes : list of str, optional
        Override class names. If None, names come from the model.
    oriented : bool, optional
        Set to True when using an OBB weight. If not specified, it is
        auto-detected: True when the weights filename stem contains
        ``'obb'`` (case-insensitive).
    device : str
        'cpu', 'cuda', or 'auto'. Default 'auto'.
    imgsz : int, optional
        Inference size passed to YOLO. Defaults to the larger edge of the
        incoming tile, i.e. ``tile_size`` in ``predict()``.

    Notes
    -----
    Tiles reach the model as RGB arrays (``predict`` builds them from
    R, G, B band indices). They are flipped to BGR before inference,
    because Ultralytics interprets numpy image inputs as OpenCV-style
    BGR and converts them to RGB internally.

    Example
    -------
    >>> det = YOLODetector(weights='yolov8n-obb.pt', oriented=True)
    >>> with gw.open('aerial.tif') as src:
    ...     gdf = det.predict(src, conf=0.3)
    """

    def __init__(self, weights='yolov8n.pt', classes=None, oriented=None,
                 device='auto', imgsz=None):
        try:
            from ultralytics import YOLO
        except ImportError as e:
            raise ImportError(
                "YOLODetector requires ultralytics. "
                "Install with: pip install geowombat[detect]"
            ) from e
        self.weights = str(weights)
        self.device = _resolve_device(device)
        self.imgsz = imgsz
        self._model = YOLO(self.weights)
        try:
            self._model.to(self.device)
        except Exception:
            # Best effort only: some model formats may not support ``.to``.
            # ``_detect_batch`` also passes ``device=`` on every predict
            # call.
            pass
        if oriented is None:
            stem = Path(self.weights).stem.lower()
            oriented = 'obb' in stem
        self.oriented = bool(oriented)
        if classes is not None:
            self.class_names = list(classes)
        else:
            self.class_names = self._model_class_names()

    def _model_class_names(self):
        """Return the model's class names as a list ordered by class id.

        ``model.names`` may be a ``{id: name}`` dict or a plain sequence.
        A dict is ordered by key; a sequence is taken as-is.
        """
        model_names = getattr(self._model, 'names', None) or {}
        if isinstance(model_names, dict):
            ordered = sorted(model_names.items(), key=lambda kv: kv[0])
            return [v for _, v in ordered]
        return list(model_names)

    def fit(self, dataset_yaml, epochs=50, imgsz=640, **kwargs):
        """Fine-tune YOLO on a dataset produced by ``build_yolo_dataset``.

        Thin wrapper around ``ultralytics.YOLO.train``. After training,
        ``class_names`` is refreshed from the model, which replaces any
        ``classes`` override given at construction.

        Parameters
        ----------
        dataset_yaml : str or Path
            Path to ``data.yaml`` written by ``build_yolo_dataset``.
        epochs : int
            Training epochs. Default 50.
        imgsz : int
            Training image size. Default 640.
        **kwargs
            Additional kwargs forwarded to ``YOLO.train``.

        Returns
        -------
        object
            Whatever ``YOLO.train`` returns.
        """
        results = self._model.train(
            data=str(dataset_yaml),
            epochs=epochs,
            imgsz=imgsz,
            device=self.device,
            **kwargs,
        )
        # Refresh class names in case the model was retrained with new ones.
        self.class_names = self._model_class_names()
        return results

    def _result_to_dets(self, r):
        """Convert a single ultralytics ``Results`` object to our tuple list."""
        out = []
        if self.oriented and getattr(r, 'obb', None) is not None:
            obb = r.obb
            try:
                xyxyxyxy = obb.xyxyxyxy.cpu().numpy()  # (N, 4, 2)
                confs = obb.conf.cpu().numpy()
                clses = obb.cls.cpu().numpy().astype(int)
            except AttributeError:
                return []
            for quad, score, cls_id in zip(xyxyxyxy, confs, clses):
                out.append((np.asarray(quad), float(score), int(cls_id)))
            return out
        boxes = getattr(r, 'boxes', None)
        if boxes is None or boxes.xyxy is None:
            return []
        try:
            xyxy = boxes.xyxy.cpu().numpy()
            confs = boxes.conf.cpu().numpy()
            clses = boxes.cls.cpu().numpy().astype(int)
        except AttributeError:
            return []
        for (x1, y1, x2, y2), score, cls_id in zip(xyxy, confs, clses):
            out.append(
                (float(x1), float(y1), float(x2), float(y2),
                 float(score), int(cls_id))
            )
        return out

    def _detect_tile(self, rgb_array, conf=0.25, max_det=None):
        # Single-image path for backwards compatibility. The batched
        # path in `predict()` goes through `_detect_batch` instead.
        return self._detect_batch(
            [rgb_array], conf=conf, max_det=max_det,
        )[0]

    def _detect_batch(self, rgb_list, conf=0.25, max_det=None):
        """Run a list of (H, W, 3) RGB tiles through YOLO in one call."""
        if not rgb_list:
            return []
        imgsz = self.imgsz or max(
            rgb_list[0].shape[0], rgb_list[0].shape[1],
        )
        # Ultralytics treats numpy inputs as BGR (OpenCV convention) and
        # flips them to RGB itself. Hand it BGR so the model sees true RGB.
        bgr_list = [np.ascontiguousarray(rgb[..., ::-1]) for rgb in rgb_list]
        predict_kwargs = dict(
            source=bgr_list,  # Ultralytics natively batches a list of arrays
            conf=conf,
            imgsz=imgsz,
            device=self.device,
            verbose=False,
        )
        if max_det is not None:
            predict_kwargs['max_det'] = int(max_det)
        results = self._model.predict(**predict_kwargs)
        if not results:
            return [[] for _ in rgb_list]
        # Ultralytics returns one Results per input image when given a list.
        return [self._result_to_dets(r) for r in results]
# --------------------------------------------------------------------------- # TorchGeo detection wrapper # ---------------------------------------------------------------------------
class TorchGeoDetector(GeoWombatDetector):
    """Detection wrapper around TorchGeo / torchvision detection models.

    Supports Faster R-CNN and RetinaNet via torchvision with optional
    pretrained weights from TorchGeo. Axis-aligned only.

    Parameters
    ----------
    model : {'faster-rcnn', 'retinanet'}
        Detection head. Default 'faster-rcnn'.
    weights : str or Path, optional
        Either a ``torchgeo.models`` weights reference written as
        ``'EnumName.MEMBER'`` (or just ``'EnumName'``), or a path to a
        state dict. If None, torchvision COCO pretrained weights are used.
        A string that resolves to neither triggers a warning and falls
        back to the torchvision defaults.
    num_classes : int, optional
        Number of classes including background. Required when loading a
        custom-trained checkpoint.
    classes : list of str, optional
        Class names corresponding to non-background ids 1..N. A
        ``'__background__'`` entry is prepended automatically (unless the
        list already starts with one), so ``class_names[label]`` resolves
        directly.
    device : str
        'cpu', 'cuda', or 'auto'. Default 'auto'.

    Notes
    -----
    Check ``torchgeo.models`` in your installed TorchGeo version for the
    available weights. TorchGeo weights are loaded with ``strict=False``,
    so keys that do not match the detector are skipped rather than
    raising an error.
    """
    oriented = False

    def __init__(self, model='faster-rcnn', weights=None, num_classes=None,
                 classes=None, device='auto'):
        try:
            import torch
            import torchvision  # noqa: F401  (availability check)
            from torchvision.models.detection import (
                fasterrcnn_resnet50_fpn,
                retinanet_resnet50_fpn,
            )
        except ImportError as e:
            raise ImportError(
                "TorchGeoDetector requires torch + torchvision. "
                "Install with: pip install geowombat[dl,detect]"
            ) from e
        self._torch = torch
        self.device = _resolve_device(device)
        self.model_name = model

        # Accept pathlib paths as well as strings; previously they were
        # silently ignored.
        if isinstance(weights, Path):
            weights = str(weights)

        torchgeo_weights = None
        state_dict_path = None
        if isinstance(weights, str):
            # Try TorchGeo enum first; fall back to a file path.
            torchgeo_weights = self._lookup_torchgeo_weights(weights)
            if torchgeo_weights is None and Path(weights).exists():
                state_dict_path = weights
            if torchgeo_weights is None and state_dict_path is None:
                warnings.warn(
                    f"Weights '{weights}' are neither a TorchGeo weights "
                    "reference nor an existing file; falling back to "
                    "torchvision defaults.",
                    UserWarning, stacklevel=2,
                )

        builders = {
            'faster-rcnn': fasterrcnn_resnet50_fpn,
            'retinanet': retinanet_resnet50_fpn,
        }
        if model not in builders:
            raise ValueError(
                f"Unknown model '{model}'. Use 'faster-rcnn' or 'retinanet'."
            )
        # Only pull torchvision's pretrained detector weights when nothing
        # else will be loaded on top of the network.
        use_default = torchgeo_weights is None and state_dict_path is None
        net = builders[model](
            weights='DEFAULT' if use_default else None,
            num_classes=num_classes,
        )

        if torchgeo_weights is not None:
            try:
                state = torchgeo_weights.get_state_dict(progress=True)
                net.load_state_dict(state, strict=False)
            except Exception as e:
                warnings.warn(
                    f"Could not load TorchGeo weights '{weights}': {e}. "
                    "Falling back to torchvision defaults."
                )
        elif state_dict_path is not None:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            state = torch.load(state_dict_path, map_location='cpu')
            if isinstance(state, dict) and 'state_dict' in state:
                state = state['state_dict']
            net.load_state_dict(state, strict=False)

        self._model = net.to(self.device).eval()

        if classes is not None:
            names = list(classes)
            # ``classes`` covers ids 1..N, but predict() indexes
            # ``class_names`` by the raw label id, so reserve slot 0.
            if not names or names[0] != '__background__':
                names = ['__background__'] + names
            self.class_names = names
        else:
            # COCO defaults if using torchvision weights
            self.class_names = self._coco_names()

    @staticmethod
    def _lookup_torchgeo_weights(name):
        """Resolve ``'EnumName.MEMBER'`` (or ``'EnumName'``) in torchgeo.

        Returns None when torchgeo is not installed or the name does not
        resolve to an attribute of ``torchgeo.models``.
        """
        try:
            import torchgeo.models as tgm
        except ImportError:
            return None
        enum_name, _, member = name.partition('.')
        try:
            wenum = getattr(tgm, enum_name, None)
            if wenum is None:
                return None
            if member:
                return getattr(wenum, member, None)
            return wenum
        except AttributeError:
            return None

    @staticmethod
    def _coco_names():
        # torchvision's COCO label indexing has 91 entries (the original
        # COCO category id space) with N/A gaps where ids 12, 26, 29, 30,
        # 45, 66, 68, 69, 71, 83 are unused. Matches the DEFAULT weights
        # of torchvision.models.detection.fasterrcnn_resnet50_fpn, so
        # cls_id can be looked up directly.
        return [
            '__background__', 'person', 'bicycle', 'car', 'motorcycle',
            'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
            'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench',
            'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
            'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella',
            'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
            'snowboard', 'sports ball', 'kite', 'baseball bat',
            'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
            'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
            'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
            'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
            'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A',
            'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
            'cell phone', 'microwave', 'oven', 'toaster', 'sink',
            'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors',
            'teddy bear', 'hair drier', 'toothbrush',
        ]

    def _output_to_dets(self, out, conf=0.25, max_det=None):
        """Convert a single torchvision detection output to tuple list."""
        boxes = out['boxes'].cpu().numpy()
        scores = out['scores'].cpu().numpy()
        labels = out['labels'].cpu().numpy().astype(int)
        keep = scores >= conf
        boxes = boxes[keep]
        scores = scores[keep]
        labels = labels[keep]
        if max_det is not None and len(scores) > max_det:
            order = np.argsort(-scores)[:max_det]
            boxes = boxes[order]
            scores = scores[order]
            labels = labels[order]
        return [
            (float(b[0]), float(b[1]), float(b[2]), float(b[3]),
             float(s), int(c))
            for b, s, c in zip(boxes, scores, labels)
        ]

    def _detect_tile(self, rgb_array, conf=0.25, max_det=None):
        return self._detect_batch(
            [rgb_array], conf=conf, max_det=max_det,
        )[0]

    def _detect_batch(self, rgb_list, conf=0.25, max_det=None):
        if not rgb_list:
            return []
        torch = self._torch
        # (H, W, 3) uint8 → (3, H, W) float [0, 1], moved to device
        tensors = [
            (torch.from_numpy(rgb).float().permute(2, 0, 1) / 255.0).to(
                self.device,
            )
            for rgb in rgb_list
        ]
        with torch.no_grad():
            outputs = self._model(tensors)
        return [
            self._output_to_dets(o, conf=conf, max_det=max_det)
            for o in outputs
        ]
# --------------------------------------------------------------------------- # SAM refinement # ---------------------------------------------------------------------------
class SAMRefiner:
    """Refine bounding-box detections into polygons using SAM.

    Each input box is used as a prompt to the Segment Anything model. The
    resulting mask is polygonized back into vector geometry in the source
    raster's CRS.

    Requires: ``pip install geowombat[sam]``

    Parameters
    ----------
    checkpoint : str or Path
        Path to a SAM checkpoint (``sam_vit_b.pth``, etc).
    model_type : {'vit_b', 'vit_l', 'vit_h'}
        SAM backbone size. Default 'vit_b'.
    device : str
        'cpu', 'cuda', or 'auto'. Default 'auto'.

    Example
    -------
    >>> ref = SAMRefiner(checkpoint='sam_vit_b.pth')
    >>> polys = ref.refine(src, boxes_gdf, band_indices=[2, 1, 0],
    ...                    scale=(0, 3000))
    """

    def __init__(self, checkpoint, model_type='vit_b', device='auto'):
        try:
            from segment_anything import sam_model_registry, SamPredictor
        except ImportError as e:
            raise ImportError(
                "SAMRefiner requires segment-anything. "
                "Install with: pip install geowombat[sam]"
            ) from e
        self.device = _resolve_device(device)
        sam = sam_model_registry[model_type](checkpoint=str(checkpoint))
        sam.to(self.device)
        self._predictor = SamPredictor(sam)

    def refine(self, src, boxes_gdf, band_indices=None, scale=None,
               pad_pixels=8, simplify_tolerance=0.5):
        """Refine boxes to polygon masks.

        Parameters
        ----------
        src : xarray.DataArray
            Source raster. ``boxes_gdf`` is reprojected to its CRS when
            the two differ.
        boxes_gdf : geopandas.GeoDataFrame
            Detector output. Each row's geometry should be the bbox. It
            must have a CRS set.
        band_indices : list of int, optional
            RGB bands for SAM input.
        scale : tuple of (lo, hi), optional
            Stretch range.
        pad_pixels : int
            Pad around each box when reading the chip. Default 8.
        simplify_tolerance : float
            Polygon simplification tolerance in CRS units. Default 0.5.

        Returns
        -------
        geopandas.GeoDataFrame
            Same columns as input, in the raster's CRS. Each geometry is
            replaced by the largest valid polygon from its SAM mask. The
            (possibly reprojected) box is kept when refinement yields
            nothing usable.

        Raises
        ------
        ValueError
            If ``boxes_gdf`` is non-empty and has no CRS.
        """
        try:
            from rasterio.features import shapes as rio_shapes
        except ImportError as e:
            raise ImportError(
                "SAMRefiner.refine requires rasterio."
            ) from e
        # ``affine`` ships with rasterio; import once, not per row.
        from affine import Affine

        if boxes_gdf.empty:
            return boxes_gdf.copy()

        src_crs = src.gw.crs_to_pyproj
        if boxes_gdf.crs is None:
            raise ValueError(
                "boxes_gdf has no CRS; set one (e.g. with "
                "GeoDataFrame.set_crs) before calling refine()."
            )
        # Compare full CRS objects. Comparing ``to_epsg()`` treats any two
        # CRSs without an EPSG code as equal (None == None).
        if boxes_gdf.crs != src_crs:
            boxes_gdf = boxes_gdf.to_crs(src_crs)

        band_indices = resolve_band_indices(src, band_indices)
        affine = src.gw.affine
        inv = ~affine

        out_geoms = []
        for _, row in boxes_gdf.iterrows():
            minx, miny, maxx, maxy = row.geometry.bounds
            # Map the box's corners into fractional pixel coordinates.
            col_ul, row_ul = inv * (minx, maxy)
            col_lr, row_lr = inv * (maxx, miny)
            x0 = int(np.floor(min(col_ul, col_lr))) - pad_pixels
            y0 = int(np.floor(min(row_ul, row_lr))) - pad_pixels
            x1 = int(np.ceil(max(col_ul, col_lr))) + pad_pixels
            y1 = int(np.ceil(max(row_ul, row_lr))) + pad_pixels
            x0 = max(0, x0)
            y0 = max(0, y0)
            x1 = min(src.gw.ncols, x1)
            y1 = min(src.gw.nrows, y1)
            if x1 <= x0 or y1 <= y0:
                # Box lies entirely outside the raster; keep it unchanged.
                out_geoms.append(row.geometry)
                continue

            chip = src.isel(
                y=slice(y0, y1), x=slice(x0, x1),
            ).values
            if chip.ndim == 4:
                chip = chip[0]
            rgb = _prepare_rgb_tile(chip, band_indices, scale)
            self._predictor.set_image(rgb)

            # Prompt box in chip-relative pixel coordinates.
            box_px = np.array([
                min(col_ul, col_lr) - x0,
                min(row_ul, row_lr) - y0,
                max(col_ul, col_lr) - x0,
                max(row_ul, row_lr) - y0,
            ], dtype=np.float32)
            masks, scores, _ = self._predictor.predict(
                box=box_px, multimask_output=False,
            )
            mask = masks[0].astype(np.uint8)

            # Affine of the chip: the source affine shifted to pixel
            # (x0, y0).
            chip_affine = Affine(
                affine.a, affine.b,
                affine.c + x0 * affine.a + y0 * affine.b,
                affine.d, affine.e,
                affine.f + x0 * affine.d + y0 * affine.e,
            )

            polys = []
            for geom_dict, val in rio_shapes(mask, transform=chip_affine):
                if val != 1:
                    continue
                coords = geom_dict.get('coordinates', [])
                if not coords:
                    continue
                outer = coords[0]
                holes = coords[1:] if len(coords) > 1 else None
                try:
                    p = Polygon(outer, holes=holes)
                except Exception:
                    continue
                if simplify_tolerance > 0:
                    p = p.simplify(simplify_tolerance)
                if not p.is_empty and p.is_valid:
                    polys.append(p)

            if polys:
                # Keep largest polygon — masks can occasionally produce
                # disconnected components near the prompt box edge.
                polys.sort(key=lambda g: g.area, reverse=True)
                out_geoms.append(polys[0])
            else:
                out_geoms.append(row.geometry)

        result = boxes_gdf.copy()
        result['geometry'] = out_geoms
        return result