Skip to content

object_detect module

High-level functions for multi-class object detection.

This module provides convenience functions for training, evaluating, and running inference with Mask R-CNN models on COCO-format datasets, including support for the NWPU-VHR-10 remote sensing benchmark.

batch_multiclass_detection(image_paths, output_dir, model_path=None, model_name=None, num_classes=11, class_names=None, window_size=512, overlap=256, confidence_threshold=0.5, nms_threshold=0.3, batch_size=4, num_channels=3, device=None, visualize=True, cols=2, figsize=(16, 16), cleanup=True, output_path=None, repo_id=None, **kwargs)

Run multi-class object detection on multiple images.

Iterates over a list of image paths, calls :func:multiclass_detection for each, and optionally displays a grid of results with colored bounding boxes.

Parameters:

Name Type Description Default
image_paths list of str

Paths to input images.

required
output_dir str

Directory for intermediate detection output files.

required
model_path str

Path to trained model weights. If None, downloads the pretrained NWPU-VHR-10 model.

None
model_name str

Detection model architecture name. If None, auto-detected from checkpoint.

None
num_classes int

Number of classes including background. Defaults to 11.

11
class_names list

List of class names (index 0 = background).

None
window_size int

Sliding window size. Defaults to 512.

512
overlap int

Window overlap in pixels. Defaults to 256.

256
confidence_threshold float

Minimum detection score. Defaults to 0.5.

0.5
nms_threshold float

IoU threshold for NMS. Defaults to 0.3.

0.3
batch_size int

Inference batch size. Defaults to 4.

4
num_channels int

Number of input image channels. Defaults to 3.

3
device device

Compute device.

None
visualize bool

Whether to display a grid of results. Defaults to True.

True
cols int

Number of columns in the visualization grid. Defaults to 2.

2
figsize tuple

Figure size for the visualization grid. Defaults to (16, 16).

(16, 16)
cleanup bool

Whether to remove intermediate output files after visualization. Defaults to True.

True
output_path str

Path to save the visualization figure. If None, displays interactively.

None
repo_id str

HuggingFace Hub repository ID for downloading the model.

None
**kwargs Any

Additional keyword arguments passed to :func:multiclass_detection.

{}

Returns:

Type Description
List[Tuple[str, float, List[Dict]]]

list of tuple: Each tuple contains (result_path, inference_time,

List[Tuple[str, float, List[Dict]]]

detections_list) from :func:multiclass_detection.

Source code in geoai/object_detect.py
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
def batch_multiclass_detection(
    image_paths: List[str],
    output_dir: str,
    model_path: Optional[str] = None,
    model_name: Optional[str] = None,
    num_classes: int = 11,
    class_names: Optional[List[str]] = None,
    window_size: int = 512,
    overlap: int = 256,
    confidence_threshold: float = 0.5,
    nms_threshold: float = 0.3,
    batch_size: int = 4,
    num_channels: int = 3,
    device: Optional[torch.device] = None,
    visualize: bool = True,
    cols: int = 2,
    figsize: Tuple[int, int] = (16, 16),
    cleanup: bool = True,
    output_path: Optional[str] = None,
    repo_id: Optional[str] = None,
    **kwargs: Any,
) -> List[Tuple[str, float, List[Dict]]]:
    """Run multi-class object detection on multiple images.

    Iterates over a list of image paths, calls :func:`multiclass_detection`
    for each, and optionally displays a grid of results with colored
    bounding boxes.

    Args:
        image_paths (list of str): Paths to input images.
        output_dir (str): Directory for intermediate detection output files.
        model_path (str, optional): Path to trained model weights. If None,
            downloads the pretrained NWPU-VHR-10 model.
        model_name (str, optional): Detection model architecture name. If
            None, auto-detected from checkpoint.
        num_classes (int): Number of classes including background. Defaults
            to 11.
        class_names (list, optional): List of class names (index 0 =
            background).
        window_size (int): Sliding window size. Defaults to 512.
        overlap (int): Window overlap in pixels. Defaults to 256.
        confidence_threshold (float): Minimum detection score. Defaults to
            0.5.
        nms_threshold (float): IoU threshold for NMS. Defaults to 0.3.
        batch_size (int): Inference batch size. Defaults to 4.
        num_channels (int): Number of input image channels. Defaults to 3.
        device (torch.device, optional): Compute device.
        visualize (bool): Whether to display a grid of results. Defaults to
            True.
        cols (int): Number of columns in the visualization grid. Defaults
            to 2.
        figsize (tuple): Figure size for the visualization grid. Defaults to
            (16, 16).
        cleanup (bool): Whether to remove intermediate output files after
            visualization. Defaults to True.
        output_path (str, optional): Path to save the visualization figure.
            If None, displays interactively.
        repo_id (str, optional): HuggingFace Hub repository ID for
            downloading the model.
        **kwargs: Additional keyword arguments passed to
            :func:`multiclass_detection`.

    Returns:
        list of tuple: Each tuple contains (result_path, inference_time,
        detections_list) from :func:`multiclass_detection`.
    """
    from PIL import Image as PILImage

    os.makedirs(output_dir, exist_ok=True)

    results = []
    for img_path in image_paths:
        basename = os.path.splitext(os.path.basename(img_path))[0]
        out_path = os.path.join(output_dir, f"{basename}_detection.tif")

        result = multiclass_detection(
            input_path=img_path,
            output_path=out_path,
            model_path=model_path,
            model_name=model_name,
            num_classes=num_classes,
            class_names=class_names,
            window_size=window_size,
            overlap=overlap,
            confidence_threshold=confidence_threshold,
            nms_threshold=nms_threshold,
            batch_size=batch_size,
            num_channels=num_channels,
            device=device,
            repo_id=repo_id,
            **kwargs,
        )
        results.append(result)

    if visualize:
        # ``plt.cm.get_cmap(name, lut)`` was deprecated in Matplotlib 3.7 and
        # removed in 3.9. Index the tab10 colormap object directly instead;
        # it already has exactly 10 colors, so cmap(label % 10) is unchanged.
        cmap = plt.cm.tab10
        n = len(image_paths)
        rows = math.ceil(n / cols)
        fig, axes = plt.subplots(rows, cols, figsize=figsize)
        # plt.subplots returns a bare Axes only for a 1x1 grid and an ndarray
        # otherwise. The previous ``n == 1`` check wrapped a 1-row multi-column
        # Axes array in a list, making ``axes[0]`` an ndarray (crash on
        # .imshow). Normalize both cases to a flat list here.
        axes = list(axes.flatten()) if hasattr(axes, "flatten") else [axes]

        for idx, (img_path, (_, _, dets)) in enumerate(zip(image_paths, results)):
            img = PILImage.open(img_path)
            axes[idx].imshow(img)
            axes[idx].set_title(
                f"{os.path.basename(img_path)} ({len(dets)} detections)",
                fontsize=10,
            )
            axes[idx].axis("off")

            for det in dets:
                box = det["box"]
                label = det["label"]
                score = det["score"]
                color = cmap(label % 10)
                rect = plt.Rectangle(
                    (box[0], box[1]),
                    box[2] - box[0],
                    box[3] - box[1],
                    linewidth=2,
                    edgecolor=color,
                    facecolor="none",
                )
                axes[idx].add_patch(rect)
                # Fall back to the numeric label when no name list was given
                # or the label is out of range.
                name = (
                    class_names[label]
                    if class_names and label < len(class_names)
                    else str(label)
                )
                axes[idx].text(
                    box[0],
                    box[1] - 3,
                    f"{name}: {score:.2f}",
                    color="white",
                    fontsize=7,
                    bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7),
                )

        # Blank out unused grid cells (when n is not a multiple of cols).
        for ax_idx in range(n, len(axes)):
            axes[ax_idx].axis("off")

        plt.tight_layout()

        if output_path:
            plt.savefig(output_path, dpi=150, bbox_inches="tight")
            plt.close()
        else:
            plt.show()

    if cleanup:
        for result_path, _, _ in results:
            if os.path.exists(result_path):
                os.remove(result_path)

    return results

detections_to_geodataframe(detections, geotiff_path, class_names=None, use_mask_geometry=False, simplify_tolerance=0.0)

Convert detections to a GeoDataFrame with geospatial coordinates.

Converts pixel-space detections to geospatial coordinates using the CRS and transform from the source GeoTIFF. By default, bounding box rectangles are used as geometry. When use_mask_geometry=True, the actual instance mask is vectorized into polygon geometry instead.

Parameters:

Name Type Description Default
detections list

List of detection dicts, each with keys: mask (np.ndarray), score (float), box (list of pixel coords), and optionally label (int), instance_id (int), and mask_offset (tuple of y, x, h, w) for compact masks.

required
geotiff_path str

Path to the source GeoTIFF (for CRS and transform).

required
class_names list

List of class names (index 0 = background).

None
use_mask_geometry bool

If True, convert instance masks to polygon geometries using rasterio.features.shapes instead of using bounding boxes. Defaults to False.

False
simplify_tolerance float

Tolerance for polygon simplification in georeferenced units. Only used when use_mask_geometry=True. Set to 0 to disable simplification. Defaults to 0.0.

0.0

Returns:

Type Description
Any

geopandas.GeoDataFrame: GeoDataFrame with columns: geometry, class_id,

Any

class_name, score, instance_id, area_pixels.

Source code in geoai/object_detect.py
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
def detections_to_geodataframe(
    detections: List[Dict],
    geotiff_path: str,
    class_names: Optional[List[str]] = None,
    use_mask_geometry: bool = False,
    simplify_tolerance: float = 0.0,
) -> Any:
    """Convert detections to a GeoDataFrame with geospatial coordinates.

    Converts pixel-space detections to geospatial coordinates using the
    CRS and transform from the source GeoTIFF. By default, bounding box
    rectangles are used as geometry. When ``use_mask_geometry=True``, the
    actual instance mask is vectorized into polygon geometry instead.

    Args:
        detections (list): List of detection dicts, each with keys:
            mask (np.ndarray), score (float), box (list of pixel coords),
            and optionally label (int), instance_id (int), and
            mask_offset (tuple of y, x, h, w) for compact masks.
        geotiff_path (str): Path to the source GeoTIFF (for CRS and transform).
        class_names (list, optional): List of class names (index 0 = background).
        use_mask_geometry (bool): If True, convert instance masks to polygon
            geometries using rasterio.features.shapes instead of using
            bounding boxes. Defaults to False.
        simplify_tolerance (float): Tolerance for polygon simplification
            in georeferenced units. Only used when use_mask_geometry=True.
            Set to 0 to disable simplification. Defaults to 0.0.

    Returns:
        geopandas.GeoDataFrame: GeoDataFrame with columns: geometry, class_id,
        class_name, score, instance_id, area_pixels.
    """
    import geopandas as gpd
    import numpy as np
    import rasterio
    from shapely.geometry import Polygon

    # NOTE(review): this empty frame carries no CRS — callers should not rely
    # on ``gdf.crs`` when there were zero detections.
    if len(detections) == 0:
        return gpd.GeoDataFrame(
            columns=[
                "geometry",
                "class_id",
                "class_name",
                "score",
                "instance_id",
                "area_pixels",
            ]
        )

    # Only the affine transform and CRS are needed; the raster data itself
    # is never read.
    with rasterio.open(geotiff_path) as src:
        transform = src.transform
        crs = src.crs

    records = []
    for idx, det in enumerate(detections):
        bx = det["box"]
        # Default label 1 keeps single-class outputs (no "label" key) working.
        label = det.get("label", 1)
        score = det["score"]
        instance_id = det.get("instance_id", idx + 1)
        area_pixels = int(det["mask"].sum()) if "mask" in det else 0

        # Geometry is resolved below: mask polygon if requested and available,
        # otherwise the bounding-box rectangle fallback.
        geom = None

        if use_mask_geometry and "mask" in det:
            from rasterio.features import shapes as rasterio_shapes
            from rasterio.transform import Affine
            from shapely.geometry import shape
            from shapely.ops import unary_union

            mask = det["mask"]

            # Determine crop region from mask_offset (compact) or bbox
            if "mask_offset" in det:
                y_off, x_off, h, w = det["mask_offset"]
                cropped = mask[:h, :w].astype(np.uint8)
            else:
                # Use pixel bbox to crop (avoids full-image scan)
                c0, r0, c1, r1 = bx
                r0i = max(0, int(r0))
                c0i = max(0, int(c0))
                r1i = min(mask.shape[0], int(np.ceil(r1)))
                c1i = min(mask.shape[1], int(np.ceil(c1)))
                cropped = mask[r0i:r1i, c0i:c1i].astype(np.uint8)
                y_off, x_off = r0i, c0i

            if cropped.any():
                # Compose the raster transform with a pixel translation so the
                # cropped window vectorizes at its true geospatial position.
                crop_transform = transform * Affine.translation(x_off, y_off)

                # Collect all polygon components and union them
                parts = []
                for geom_dict, value in rasterio_shapes(
                    cropped, mask=cropped, transform=crop_transform
                ):
                    if value == 1:
                        candidate = shape(geom_dict)
                        if candidate.is_valid and not candidate.is_empty:
                            parts.append(candidate)
                if parts:
                    merged = unary_union(parts) if len(parts) > 1 else parts[0]
                    if simplify_tolerance > 0:
                        # Keep the unsimplified geometry if simplification
                        # produced an invalid result.
                        simplified = merged.simplify(
                            simplify_tolerance, preserve_topology=True
                        )
                        merged = simplified if simplified.is_valid else merged
                    geom = merged

        # Fallback to bounding box geometry
        if geom is None:
            c0, r0, c1, r1 = bx  # (xmin, ymin, xmax, ymax) in pixel coords
            pts = []
            # Corners in UL, UR, LR, LL order; ``transform * (col, row)``
            # maps pixel coordinates to georeferenced (x, y).
            for c, r in [(c0, r0), (c1, r0), (c1, r1), (c0, r1)]:
                x, y = transform * (c, r)
                pts.append((x, y))
            geom = Polygon(pts)

        # Resolve a display name: real name if in range, synthetic
        # "class_N" if a name list exists but is too short, else "unknown".
        name = "unknown"
        if class_names and label < len(class_names):
            name = class_names[label]
        elif class_names:
            name = f"class_{label}"

        records.append(
            {
                "geometry": geom,
                "class_id": label,
                "class_name": name,
                "score": score,
                "instance_id": instance_id,
                "area_pixels": area_pixels,
            }
        )

    gdf = gpd.GeoDataFrame(records, crs=crs)
    return gdf

download_nwpu_vhr10(output_dir='NWPU-VHR-10', overwrite=False)

Download and extract the NWPU-VHR-10 dataset.

The NWPU-VHR-10 dataset contains 800 VHR (Very High Resolution) remote sensing images with 10 object classes: airplane, ship, storage_tank, baseball_diamond, tennis_court, basketball_court, ground_track_field, harbor, bridge, and vehicle. It has 3,775 annotated instances in COCO format (bounding boxes and instance segmentation masks).

Parameters:

Name Type Description Default
output_dir str

Path for the downloaded ZIP file and extracted dataset directory. Defaults to "NWPU-VHR-10".

'NWPU-VHR-10'
overwrite bool

Whether to overwrite existing files. Defaults to False.

False

Returns:

Name Type Description
str str

Path to the extracted dataset directory.

Source code in geoai/object_detect.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def download_nwpu_vhr10(
    output_dir: str = "NWPU-VHR-10",
    overwrite: bool = False,
) -> str:
    """Download and extract the NWPU-VHR-10 dataset.

    The NWPU-VHR-10 dataset contains 800 VHR (Very High Resolution) remote
    sensing images with 10 object classes: airplane, ship, storage_tank,
    baseball_diamond, tennis_court, basketball_court, ground_track_field,
    harbor, bridge, and vehicle. It has 3,775 annotated instances in COCO
    format (bounding boxes and instance segmentation masks).

    Args:
        output_dir (str): Path for the downloaded ZIP file and extracted
            dataset directory. Defaults to "NWPU-VHR-10".
        overwrite (bool): Whether to overwrite existing files. Defaults to False.

    Returns:
        str: Path to the extracted dataset directory.
    """
    # The archive is fetched next to the target directory as "<output_dir>.zip".
    return download_file(
        NWPU_VHR10_URL,
        output_path=f"{output_dir}.zip",
        overwrite=overwrite,
    )

download_nwpu_vhr10_model(repo_id=NWPU_VHR10_HF_REPO, filename=NWPU_VHR10_HF_FILENAME)

Download the pretrained NWPU-VHR-10 Mask R-CNN model from HuggingFace Hub.

Downloads a Mask R-CNN (ResNet-50 FPN) model trained on the NWPU-VHR-10 dataset for 10-class object detection on remote sensing imagery.

The model achieves the following performance on the validation set:
  • mAP@0.5: 0.709
  • mAP@0.75: 0.518
  • mAP@[0.5:0.95]: 0.459

Parameters:

Name Type Description Default
repo_id str

HuggingFace Hub repository ID. Defaults to "giswqs/nwpu-vhr10-maskrcnn".

NWPU_VHR10_HF_REPO
filename str

Model filename in the repository. Defaults to "best_model.pth".

NWPU_VHR10_HF_FILENAME

Returns:

Name Type Description
str str

Local path to the downloaded model weights file.

Example

import geoai
model_path = geoai.download_nwpu_vhr10_model()
print(model_path)  # local cache path

Source code in geoai/object_detect.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
def download_nwpu_vhr10_model(
    repo_id: str = NWPU_VHR10_HF_REPO,
    filename: str = NWPU_VHR10_HF_FILENAME,
) -> str:
    """Download the pretrained NWPU-VHR-10 Mask R-CNN model from HuggingFace Hub.

    Downloads a Mask R-CNN (ResNet-50 FPN) model trained on the NWPU-VHR-10
    dataset for 10-class object detection on remote sensing imagery.

    The model achieves the following performance on the validation set:
        - mAP@0.5: 0.709
        - mAP@0.75: 0.518
        - mAP@[0.5:0.95]: 0.459

    Args:
        repo_id (str): HuggingFace Hub repository ID.
            Defaults to "giswqs/nwpu-vhr10-maskrcnn".
        filename (str): Model filename in the repository.
            Defaults to "best_model.pth".

    Returns:
        str: Local path to the downloaded model weights file.

    Example:
        >>> import geoai
        >>> model_path = geoai.download_nwpu_vhr10_model()
        >>> print(model_path)  # local cache path
    """
    # Imported lazily so huggingface_hub is only required when this is called.
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(repo_id=repo_id, filename=filename)
    logger.info(f"Model downloaded to: {local_path}")
    return local_path

evaluate_multiclass_detector(model_path=None, model_name=None, images_dir='', annotations_path='', num_classes=11, class_names=None, num_channels=3, batch_size=4, device=None, num_workers=None, repo_id=None, verbose=True)

Evaluate a trained multi-class detection model on a dataset.

Loads a trained model and computes COCO-style mAP metrics on the provided dataset.

If model_path is None, the pretrained NWPU-VHR-10 model is automatically downloaded from HuggingFace Hub.

Parameters:

Name Type Description Default
model_path str

Path to trained model weights. If None, downloads the pretrained NWPU-VHR-10 model.

None
model_name str

Detection model architecture name. If None, auto-detected from sidecar or defaults to "maskrcnn_resnet50_fpn".

None
images_dir str

Directory containing evaluation images.

''
annotations_path str

Path to COCO-format annotations JSON.

''
num_classes int

Number of classes including background. Defaults to 11 (NWPU-VHR-10).

11
class_names list

List of class names (excluding background).

None
num_channels int

Number of image channels. Defaults to 3.

3
batch_size int

Evaluation batch size. Defaults to 4.

4
device device

Compute device.

None
num_workers int

Number of data loading workers.

None
repo_id str

HuggingFace Hub repository ID for downloading the model. Defaults to "giswqs/nwpu-vhr10-maskrcnn".

None
verbose bool

Whether to print results. Defaults to True.

True

Returns:

Type Description
Dict[str, float]

Dict with mAP metrics.

Source code in geoai/object_detect.py
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
def evaluate_multiclass_detector(
    model_path: Optional[str] = None,
    model_name: Optional[str] = None,
    images_dir: str = "",
    annotations_path: str = "",
    num_classes: int = 11,
    class_names: Optional[List[str]] = None,
    num_channels: int = 3,
    batch_size: int = 4,
    device: Optional[torch.device] = None,
    num_workers: Optional[int] = None,
    repo_id: Optional[str] = None,
    verbose: bool = True,
) -> Dict[str, float]:
    """Evaluate a trained multi-class detection model on a dataset.

    Loads a trained model and computes COCO-style mAP metrics on the
    provided dataset.

    If ``model_path`` is None, the pretrained NWPU-VHR-10 model is
    automatically downloaded from HuggingFace Hub.

    Args:
        model_path (str, optional): Path to trained model weights.
            If None, downloads the pretrained NWPU-VHR-10 model.
        model_name (str, optional): Detection model architecture name.
            If None, auto-detected from sidecar or defaults to
            ``"maskrcnn_resnet50_fpn"``.
        images_dir (str): Directory containing evaluation images.
        annotations_path (str): Path to COCO-format annotations JSON.
        num_classes (int): Number of classes including background.
            Defaults to 11 (NWPU-VHR-10).
        class_names (list, optional): List of class names (excluding background).
        num_channels (int): Number of image channels. Defaults to 3.
        batch_size (int): Evaluation batch size. Defaults to 4.
        device (torch.device, optional): Compute device.
        num_workers (int, optional): Number of data loading workers.
        repo_id (str, optional): HuggingFace Hub repository ID for
            downloading the model. Defaults to ``"giswqs/nwpu-vhr10-maskrcnn"``.
        verbose (bool): Whether to print results. Defaults to True.

    Returns:
        Dict with mAP metrics.
    """
    import platform

    from torch.utils.data import DataLoader

    if device is None:
        device = get_device()

    # Handle pretrained model download
    # NOTE: when falling back to the pretrained model, ``num_classes`` is
    # overridden with the NWPU-VHR-10 class count regardless of the value
    # the caller passed in.
    if model_path is None:
        hf_repo = repo_id or NWPU_VHR10_HF_REPO
        model_path = download_nwpu_vhr10_model(
            repo_id=hf_repo, filename=NWPU_VHR10_HF_FILENAME
        )
        num_classes = len(NWPU_VHR10_CLASSES)
        if class_names is None:
            class_names = NWPU_VHR10_CLASSES[1:]  # Exclude background
        if model_name is None:
            model_name = "maskrcnn_resnet50_fpn"

    # Try to read model_name from sidecar
    # (a class_info.json written next to the checkpoint at training time).
    if model_name is None:
        class_info_path = os.path.join(os.path.dirname(model_path), "class_info.json")
        if os.path.exists(class_info_path):
            with open(class_info_path, "r") as f:
                class_info = json.load(f)
            model_name = class_info.get("model_name", None)
    # Final fallback for backward compatibility with older checkpoints.
    if model_name is None:
        model_name = "maskrcnn_resnet50_fpn"

    # Load model (architecture only; weights are loaded below)
    model = get_detection_model(
        model_name=model_name,
        num_classes=num_classes,
        num_channels=num_channels,
        pretrained=False,
    )

    # If the given path does not exist locally, treat it as a filename
    # inside the HuggingFace repository and download it from there.
    if not os.path.exists(model_path):
        hf_repo = repo_id or NWPU_VHR10_HF_REPO
        from huggingface_hub import hf_hub_download

        model_path = hf_hub_download(repo_id=hf_repo, filename=model_path)

    state_dict = torch.load(model_path, map_location=device)
    # Strip DataParallel "module." prefixes so the plain (unwrapped) model
    # can load checkpoints saved from a DataParallel-wrapped model.
    if any(key.startswith("module.") for key in state_dict.keys()):
        state_dict = {
            key.replace("module.", ""): value for key, value in state_dict.items()
        }
    model.load_state_dict(state_dict)
    model.to(device)

    # Create dataset and loader; masks are only computed when the chosen
    # architecture actually predicts them.
    dataset = COCODetectionDataset(
        coco_json_path=annotations_path,
        images_dir=images_dir,
        transforms=get_transform(train=False),
        num_channels=num_channels,
        compute_masks=model_has_masks(model_name),
    )

    # Multiprocessing data loading is unreliable on macOS/Windows (spawn
    # start method), so default to in-process loading there.
    if num_workers is None:
        num_workers = 0 if platform.system() in ["Darwin", "Windows"] else 4

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=num_workers,
    )

    # Evaluate
    results = evaluate_coco_metrics(
        model=model,
        data_loader=data_loader,
        device=device,
        class_names=class_names,
        verbose=verbose,
    )

    return results

multiclass_detection(input_path, output_path, model_path=None, model_name=None, num_classes=11, class_names=None, window_size=512, overlap=256, confidence_threshold=0.5, nms_threshold=0.3, batch_size=4, num_channels=3, device=None, repo_id=None, **kwargs)

Perform multi-class object detection on a GeoTIFF or image.

Loads a trained detection model and runs inference using a sliding window approach. Outputs a 2-band raster with class labels and instance IDs.

If model_path is None, the pretrained NWPU-VHR-10 model is automatically downloaded from HuggingFace Hub with default num_classes=11 and class_names set to NWPU-VHR-10 classes.

Parameters:

Name Type Description Default
input_path str

Path to input image (GeoTIFF, JPEG, PNG, etc.).

required
output_path str

Path to save output raster.

required
model_path str

Path to trained model weights (.pth file). If None, downloads the pretrained NWPU-VHR-10 model from HuggingFace Hub.

None
model_name str

Detection model architecture name. If None, auto-detected from class_info.json sidecar or checkpoint keys. Falls back to "maskrcnn_resnet50_fpn" for backward compatibility.

None
num_classes int

Number of classes including background. Defaults to 11 (NWPU-VHR-10).

11
class_names list

List of class names (index 0 = background).

None
window_size int

Sliding window size. Defaults to 512.

512
overlap int

Window overlap in pixels. Defaults to 256.

256
confidence_threshold float

Minimum detection score. Defaults to 0.5.

0.5
nms_threshold float

IoU threshold for NMS. Defaults to 0.3.

0.3
batch_size int

Inference batch size. Defaults to 4.

4
num_channels int

Number of input image channels. Defaults to 3.

3
device device

Compute device.

None
repo_id str

HuggingFace Hub repository ID for downloading the model. Defaults to "giswqs/nwpu-vhr10-maskrcnn".

None
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
str

Tuple of (output_path, inference_time, detections_list) where each

float

detection is a dict with mask, score, box, and label.

Source code in geoai/object_detect.py
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
def multiclass_detection(
    input_path: str,
    output_path: str,
    model_path: Optional[str] = None,
    model_name: Optional[str] = None,
    num_classes: int = 11,
    class_names: Optional[List[str]] = None,
    window_size: int = 512,
    overlap: int = 256,
    confidence_threshold: float = 0.5,
    nms_threshold: float = 0.3,
    batch_size: int = 4,
    num_channels: int = 3,
    device: Optional[torch.device] = None,
    repo_id: Optional[str] = None,
    **kwargs: Any,
) -> Tuple[str, float, List[Dict]]:
    """Perform multi-class object detection on a GeoTIFF or image.

    Loads a trained detection model and runs inference using a sliding
    window approach. Outputs a 2-band raster with class labels and
    instance IDs.

    If ``model_path`` is None, the pretrained NWPU-VHR-10 model is
    automatically downloaded from HuggingFace Hub with default
    ``num_classes=11`` and ``class_names`` set to NWPU-VHR-10 classes.

    Args:
        input_path (str): Path to input image (GeoTIFF, JPEG, PNG, etc.).
        output_path (str): Path to save output raster.
        model_path (str, optional): Path to trained model weights (.pth file).
            If None, downloads the pretrained NWPU-VHR-10 model from
            HuggingFace Hub.
        model_name (str, optional): Detection model architecture name. If
            None, auto-detected from ``class_info.json`` sidecar or
            checkpoint keys. Falls back to ``"maskrcnn_resnet50_fpn"``
            for backward compatibility.
        num_classes (int): Number of classes including background. Defaults
            to 11 (NWPU-VHR-10).
        class_names (list, optional): List of class names (index 0 = background).
        window_size (int): Sliding window size. Defaults to 512.
        overlap (int): Window overlap in pixels. Defaults to 256.
        confidence_threshold (float): Minimum detection score. Defaults to 0.5.
        nms_threshold (float): IoU threshold for NMS. Defaults to 0.3.
        batch_size (int): Inference batch size. Defaults to 4.
        num_channels (int): Number of input image channels. Defaults to 3.
        device (torch.device, optional): Compute device.
        repo_id (str, optional): HuggingFace Hub repository ID for
            downloading the model. Defaults to ``"giswqs/nwpu-vhr10-maskrcnn"``.
        **kwargs: Additional keyword arguments.

    Returns:
        Tuple of (output_path, inference_time, detections_list) where each
        detection is a dict with mask, score, box, and label.
    """
    import rasterio

    if device is None:
        device = get_device()

    # Handle pretrained model download: when no weights are supplied, fall
    # back to the published NWPU-VHR-10 checkpoint and its class list.
    use_pretrained = model_path is None
    if use_pretrained:
        hf_repo = repo_id or NWPU_VHR10_HF_REPO
        model_path = download_nwpu_vhr10_model(
            repo_id=hf_repo, filename=NWPU_VHR10_HF_FILENAME
        )
        if class_names is None:
            class_names = NWPU_VHR10_CLASSES
        num_classes = len(NWPU_VHR10_CLASSES)
        if model_name is None:
            model_name = "maskrcnn_resnet50_fpn"

    # Convert non-GeoTIFF images to a temporary GeoTIFF so the sliding-window
    # raster pipeline can consume them uniformly.
    temp_tif = None
    if not input_path.lower().endswith((".tif", ".tiff")):
        from PIL import Image as PILImage

        img = PILImage.open(input_path).convert("RGB")
        img_array = np.array(img)
        h, w = img_array.shape[:2]

        temp_tif = output_path.replace(
            os.path.splitext(output_path)[1], "_temp_input.tif"
        )
        profile = {
            "driver": "GTiff",
            "dtype": "uint8",
            "width": w,
            "height": h,
            "count": img_array.shape[2] if img_array.ndim == 3 else 1,
            "crs": None,
            "transform": rasterio.transform.from_bounds(0, 0, w, h, w, h),
        }
        with rasterio.open(temp_tif, "w", **profile) as dst:
            if img_array.ndim == 3:
                for band in range(img_array.shape[2]):
                    dst.write(img_array[:, :, band], band + 1)
            else:
                dst.write(img_array, 1)
        input_path = temp_tif

    try:
        # Resolve model path: a non-existent local path is treated as a
        # filename inside the Hub repository and downloaded.
        if not os.path.exists(model_path):
            hf_repo = repo_id or NWPU_VHR10_HF_REPO
            from huggingface_hub import hf_hub_download

            model_path = hf_hub_download(repo_id=hf_repo, filename=model_path)

        # Try to load the class_info.json sidecar saved during training; it
        # can override num_classes and supply class/model names.
        class_info_path = os.path.join(os.path.dirname(model_path), "class_info.json")
        if not use_pretrained and os.path.exists(class_info_path):
            with open(class_info_path, "r") as f:
                class_info = json.load(f)
            num_classes = class_info.get("num_classes", num_classes)
            if class_names is None:
                class_names = class_info.get("class_names", class_names)
            if model_name is None:
                model_name = class_info.get("model_name", None)

        # Load checkpoint. NOTE(review): torch.load without weights_only;
        # presumably checkpoints are trusted — confirm before hardening.
        state_dict = torch.load(model_path, map_location=device)
        # Strip the DataParallel "module." wrapper. Only the leading prefix is
        # removed so keys that happen to contain "module." elsewhere are not
        # corrupted (str.replace would strip every occurrence).
        if any(key.startswith("module.") for key in state_dict.keys()):
            state_dict = {
                (key[len("module.") :] if key.startswith("module.") else key): value
                for key, value in state_dict.items()
            }

        # Auto-detect model_name from checkpoint keys if not set.
        if model_name is None:
            if "roi_heads.mask_predictor.conv5_mask.weight" in state_dict:
                model_name = "maskrcnn_resnet50_fpn"
            elif "roi_heads.box_predictor.cls_score.weight" in state_dict:
                model_name = "fasterrcnn_resnet50_fpn_v2"
            elif "head.classification_head.cls_logits.weight" in state_dict:
                # Distinguish FCOS (anchor-free) from RetinaNet (anchor-based)
                # by checking for anchor_generator keys
                if any(k.startswith("anchor_generator.") for k in state_dict):
                    model_name = "retinanet_resnet50_fpn_v2"
                else:
                    model_name = "fcos_resnet50_fpn"
            else:
                model_name = "maskrcnn_resnet50_fpn"

        # Infer num_classes from checkpoint head shapes.
        if not use_pretrained:
            rcnn_cls_key = "roi_heads.box_predictor.cls_score.weight"
            retina_cls_key = "head.classification_head.cls_logits.weight"
            if rcnn_cls_key in state_dict:
                inferred = state_dict[rcnn_cls_key].shape[0]
                if inferred != num_classes:
                    num_classes = inferred
            elif retina_cls_key in state_dict:
                # For RetinaNet/FCOS: out_channels = num_anchors * num_classes
                # num_anchors is 9 for RetinaNet, 1 for FCOS
                out_channels = state_dict[retina_cls_key].shape[0]
                num_anchors = 9 if model_name == "retinanet_resnet50_fpn_v2" else 1
                inferred = out_channels // num_anchors
                if inferred != num_classes:
                    num_classes = inferred

        # Build the architecture and load the trained weights.
        model = get_detection_model(
            model_name=model_name,
            num_classes=num_classes,
            num_channels=num_channels,
            pretrained=False,
        )
        model.load_state_dict(state_dict)

        result = multiclass_detection_inference_on_geotiff(
            model=model,
            geotiff_path=input_path,
            output_path=output_path,
            class_names=class_names,
            window_size=window_size,
            overlap=overlap,
            confidence_threshold=confidence_threshold,
            nms_threshold=nms_threshold,
            batch_size=batch_size,
            num_channels=num_channels,
            device=device,
            **kwargs,
        )
    finally:
        # Remove the temporary GeoTIFF even when loading or inference raises,
        # so failed runs do not leak files next to the output path.
        if temp_tif and os.path.exists(temp_tif):
            os.remove(temp_tif)

    return result

plot_detection_training_history(history_path, figsize=(15, 4), output_path=None)

Plot training metrics from a detection model training history file.

Loads a training_history.pth file saved during :func:train_multiclass_detector and plots up to three subplots: training/validation loss, validation IoU, and learning rate schedule. Subplots are skipped if the corresponding keys are missing.

Parameters:

Name Type Description Default
history_path str

Path to the training_history.pth file.

required
figsize tuple

Figure size (width, height). Defaults to (15, 4).

(15, 4)
output_path str

Path to save the figure. If None, displays interactively.

None
Source code in geoai/object_detect.py
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
def plot_detection_training_history(
    history_path: str,
    figsize: Tuple[int, int] = (15, 4),
    output_path: Optional[str] = None,
) -> None:
    """Visualize the metrics recorded during detection-model training.

    Reads a ``training_history.pth`` file produced by
    :func:`train_multiclass_detector` and renders up to three panels:
    train/validation loss, validation IoU, and the learning-rate schedule.
    Panels whose metrics are absent from the history are omitted.

    Args:
        history_path (str): Path to the ``training_history.pth`` file.
        figsize (tuple): Figure size (width, height). Defaults to (15, 4).
        output_path (str, optional): Path to save the figure. If None,
            displays interactively.
    """
    if not os.path.exists(history_path):
        logger.warning(f"Training history not found: {history_path}")
        return

    history = torch.load(history_path, weights_only=True)
    epochs = history.get("epochs", [])

    # Decide which panels are plottable based on the keys present.
    available = [
        panel
        for panel, key in (("loss", "train_loss"), ("iou", "val_iou"), ("lr", "lr"))
        if key in history
    ]

    if not available:
        logger.warning("No plottable metrics found in training history.")
        return

    from matplotlib.ticker import MaxNLocator

    fig, axes = plt.subplots(1, len(available), figsize=figsize)
    # plt.subplots returns a bare Axes (not an array) for a single panel.
    if len(available) == 1:
        axes = [axes]
    next_axis = iter(axes)

    if "loss" in available:
        ax = next(next_axis)
        ax.plot(epochs, history["train_loss"], label="Train Loss")
        # Only draw the validation curve when it contains at least one
        # finite value (epochs without validation may store NaN/inf).
        has_val_loss = "val_loss" in history and any(
            math.isfinite(v) for v in history["val_loss"]
        )
        if has_val_loss:
            ax.plot(epochs, history["val_loss"], label="Val Loss")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title("Training & Validation Loss" if has_val_loss else "Training Loss")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.legend()

    if "iou" in available:
        ax = next(next_axis)
        ax.plot(epochs, history["val_iou"], label="Val IoU", color="green")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("IoU")
        ax.set_title("Validation IoU")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.legend()

    if "lr" in available:
        ax = next(next_axis)
        ax.plot(epochs, history["lr"], label="Learning Rate", color="orange")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("LR")
        ax.set_title("Learning Rate Schedule")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.legend()

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close()
    else:
        plt.show()

predict_detector_from_hub(input_path, output_path, repo_id, window_size=512, overlap=256, confidence_threshold=0.5, nms_threshold=0.3, batch_size=4, device=None, token=None, **kwargs)

Run object detection using a model downloaded from Hugging Face Hub.

Downloads model.pth and config.json from the specified Hub repository and delegates to :func:multiclass_detection for inference.

Parameters:

Name Type Description Default
input_path str

Path to input image (GeoTIFF, JPEG, PNG, etc.).

required
output_path str

Path to save output raster.

required
repo_id str

Hub repository in "username/repo-name" format.

required
window_size int

Sliding window size. Defaults to 512.

512
overlap int

Window overlap in pixels. Defaults to 256.

256
confidence_threshold float

Minimum detection score. Defaults to 0.5.

0.5
nms_threshold float

IoU threshold for NMS. Defaults to 0.3.

0.3
batch_size int

Inference batch size. Defaults to 4.

4
device device

Compute device.

None
token str

Hugging Face API token for private repositories. If None, the token stored by huggingface-cli login is used.

None
**kwargs Any

Additional keyword arguments passed to :func:multiclass_detection.

{}

Returns:

Type Description
Optional[Tuple[str, float, List[Dict]]]

Tuple of (output_path, inference_time, detections_list) where each detection is a dict with mask, score, box, and label, or None if huggingface_hub is not installed.

Source code in geoai/object_detect.py
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
def predict_detector_from_hub(
    input_path: str,
    output_path: str,
    repo_id: str,
    window_size: int = 512,
    overlap: int = 256,
    confidence_threshold: float = 0.5,
    nms_threshold: float = 0.3,
    batch_size: int = 4,
    device: Optional[torch.device] = None,
    token: Optional[str] = None,
    **kwargs: Any,
) -> Optional[Tuple[str, float, List[Dict]]]:
    """Run detection with a model hosted on Hugging Face Hub.

    Fetches ``model.pth`` and ``config.json`` from the given repository,
    reads the model metadata from the config, and hands inference off to
    :func:`multiclass_detection`.

    Args:
        input_path (str): Path to input image (GeoTIFF, JPEG, PNG, etc.).
        output_path (str): Path to save output raster.
        repo_id (str): Hub repository in ``"username/repo-name"`` format.
        window_size (int): Sliding window size. Defaults to 512.
        overlap (int): Window overlap in pixels. Defaults to 256.
        confidence_threshold (float): Minimum detection score. Defaults to
            0.5.
        nms_threshold (float): IoU threshold for NMS. Defaults to 0.3.
        batch_size (int): Inference batch size. Defaults to 4.
        device (torch.device, optional): Compute device.
        token (str, optional): Hugging Face API token for private
            repositories. If None, the token stored by
            ``huggingface-cli login`` is used.
        **kwargs: Additional keyword arguments passed to
            :func:`multiclass_detection`.

    Returns:
        Tuple of (output_path, inference_time, detections_list) where each
        detection is a dict with mask, score, box, and label, or None
        if ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        logger.error(
            "huggingface_hub is required. "
            "Install it with: pip install huggingface-hub"
        )
        return None

    # Download weights + metadata; any Hub failure is reported, not raised.
    try:
        logger.info(f"Downloading model from {repo_id}...")
        weights_path = hf_hub_download(
            repo_id=repo_id, filename="model.pth", token=token
        )
        config_path = hf_hub_download(
            repo_id=repo_id, filename="config.json", token=token
        )
    except Exception as e:
        logger.error(f"Failed to download model from Hub: {e}")
        return None

    with open(config_path) as fh:
        hub_config = json.load(fh)

    # Reconstruct the model from the metadata stored alongside the weights.
    return multiclass_detection(
        input_path=input_path,
        output_path=output_path,
        model_path=weights_path,
        model_name=hub_config.get("model_name", "maskrcnn_resnet50_fpn"),
        num_classes=hub_config.get("num_classes", 11),
        class_names=hub_config.get("class_names", None),
        num_channels=hub_config.get("num_channels", 3),
        window_size=window_size,
        overlap=overlap,
        confidence_threshold=confidence_threshold,
        nms_threshold=nms_threshold,
        batch_size=batch_size,
        device=device,
        **kwargs,
    )

prepare_nwpu_vhr10(data_dir, output_dir=None, val_split=0.2, seed=42)

Prepare NWPU-VHR-10 dataset for training.

Converts the original text-based annotations to COCO JSON format, then splits the dataset into train/val sets. The original dataset uses text files with (x1,y1),(x2,y2),class_id per line for bounding boxes.

Note: Only images with at least one annotation are included in the train/val splits. The 150 "negative" images in the NWPU-VHR-10 dataset (those without any target objects) are excluded from the splits.

Parameters:

Name Type Description Default
data_dir str

Path to the extracted NWPU-VHR-10 directory.

required
output_dir str

Output directory for organized data. If None, creates files alongside the original data.

None
val_split float

Fraction of data for validation. Defaults to 0.2.

0.2
seed int

Random seed for reproducibility. Defaults to 42.

42

Returns:

Type Description
Dict[str, Any]

Dict with keys: - 'images_dir': Path to images directory - 'annotations_path': Path to the full annotations JSON - 'train_annotations': Path to train split annotations JSON - 'val_annotations': Path to val split annotations JSON - 'train_image_ids': List of training image IDs - 'val_image_ids': List of validation image IDs - 'class_names': List of class names (including background) - 'num_classes': Number of classes (including background)

Source code in geoai/object_detect.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
def prepare_nwpu_vhr10(
    data_dir: str,
    output_dir: Optional[str] = None,
    val_split: float = 0.2,
    seed: int = 42,
) -> Dict[str, Any]:
    """Prepare NWPU-VHR-10 dataset for training.

    Converts the original text-based annotations to COCO JSON format,
    then splits the dataset into train/val sets. The original dataset uses
    text files with ``(x1,y1),(x2,y2),class_id`` per line for bounding boxes.

    Note: Only images with at least one annotation are included in the
    train/val splits. The 150 "negative" images in the NWPU-VHR-10 dataset
    (those without any target objects) are excluded from the splits.

    Args:
        data_dir (str): Path to the extracted NWPU-VHR-10 directory.
        output_dir (str, optional): Output directory for organized data.
            If None, creates files alongside the original data.
        val_split (float): Fraction of data for validation. Defaults to 0.2.
        seed (int): Random seed for reproducibility. Defaults to 42.

    Returns:
        Dict with keys:
            - 'images_dir': Path to images directory
            - 'annotations_path': Path to the full annotations JSON
            - 'train_annotations': Path to train split annotations JSON
            - 'val_annotations': Path to val split annotations JSON
            - 'train_image_ids': List of training image IDs
            - 'val_image_ids': List of validation image IDs
            - 'class_names': List of class names (including background)
            - 'num_classes': Number of classes (including background)

    Raises:
        FileNotFoundError: If the images directory cannot be located, or
            no annotations can be found or created.
    """
    from sklearn.model_selection import train_test_split

    if output_dir is None:
        output_dir = data_dir

    os.makedirs(output_dir, exist_ok=True)

    # Handle nested directory structure (NWPU-VHR-10/NWPU-VHR-10/)
    actual_dir = data_dir
    nested = os.path.join(data_dir, "NWPU-VHR-10")
    if os.path.isdir(nested):
        actual_dir = nested

    # Find images directory (naming varies across dataset mirrors).
    images_dir = None
    for candidate in ["positive image set", "positive_image_set", "images"]:
        path = os.path.join(actual_dir, candidate)
        if os.path.isdir(path):
            images_dir = path
            break

    if images_dir is None:
        raise FileNotFoundError(
            f"Could not find images directory in {actual_dir}. "
            "Expected 'positive image set' directory."
        )

    # Find ground truth directory
    gt_dir = None
    for candidate in ["ground truth", "ground_truth", "annotations", "labels"]:
        path = os.path.join(actual_dir, candidate)
        if os.path.isdir(path):
            gt_dir = path
            break

    # Check if COCO JSON already exists
    annotations_path = os.path.join(output_dir, "annotations.json")
    if os.path.exists(annotations_path):
        logger.info(f"Using existing COCO annotations: {annotations_path}")
    elif gt_dir is not None:
        # Convert text annotations to COCO JSON
        logger.info("Converting NWPU-VHR-10 text annotations to COCO JSON format...")
        _convert_nwpu_to_coco(images_dir, gt_dir, annotations_path)
        logger.info(f"COCO annotations saved to: {annotations_path}")
    else:
        # Look for existing JSON annotations
        for candidate_ann in [
            "annotations.json",
            "instances.json",
        ]:
            path = os.path.join(actual_dir, candidate_ann)
            if os.path.isfile(path):
                annotations_path = path
                break

    if not os.path.exists(annotations_path):
        raise FileNotFoundError(f"Could not find or create annotations for {data_dir}.")

    # Load annotations
    with open(annotations_path, "r") as f:
        coco_data = json.load(f)

    # Only images that actually carry annotations enter the splits
    # (excludes the 150 negative images).
    annotated_image_ids = {ann["image_id"] for ann in coco_data.get("annotations", [])}
    image_ids = sorted(annotated_image_ids)

    # Split into train and validation
    train_ids, val_ids = train_test_split(
        image_ids, test_size=val_split, random_state=seed
    )

    train_ann_path = os.path.join(output_dir, "train_annotations.json")
    val_ann_path = os.path.join(output_dir, "val_annotations.json")

    all_images = {img["id"]: img for img in coco_data["images"]}

    def _subset_coco(ids: List[int]) -> Dict[str, Any]:
        """Build a COCO dict restricted to the given image IDs."""
        id_set = set(ids)
        return {
            "images": [all_images[i] for i in ids if i in all_images],
            "annotations": [
                ann for ann in coco_data["annotations"] if ann["image_id"] in id_set
            ],
            "categories": coco_data.get("categories", []),
        }

    with open(train_ann_path, "w") as f:
        json.dump(_subset_coco(train_ids), f)

    with open(val_ann_path, "w") as f:
        json.dump(_subset_coco(val_ids), f)

    class_names = NWPU_VHR10_CLASSES

    logger.info("Dataset prepared:")
    logger.info(f"  Images directory: {images_dir}")
    logger.info(f"  Total annotated images: {len(image_ids)}")
    logger.info(f"  Total annotations: {len(coco_data['annotations'])}")
    logger.info(f"  Training images: {len(train_ids)}")
    logger.info(f"  Validation images: {len(val_ids)}")
    logger.info(f"  Classes: {class_names[1:]}")

    return {
        "images_dir": images_dir,
        "annotations_path": annotations_path,
        "train_annotations": train_ann_path,
        "val_annotations": val_ann_path,
        "train_image_ids": train_ids,
        "val_image_ids": val_ids,
        "class_names": class_names,
        "num_classes": len(class_names),
    }

push_detector_to_hub(model_path, repo_id, model_name='fasterrcnn_resnet50_fpn_v2', num_classes=11, num_channels=3, class_names=None, commit_message=None, private=False, token=None)

Push a trained detection model to Hugging Face Hub.

Uploads the model weights (model.pth) and a config.json file containing model metadata to the specified Hub repository. The repository is created automatically if it does not already exist.

Parameters:

Name Type Description Default
model_path str

Path to the trained model weights (.pth file).

required
repo_id str

Hub repository in "username/repo-name" format.

required
model_name str

Detection model architecture name. Stored in config.json so the model can be reconstructed on download. Defaults to "fasterrcnn_resnet50_fpn_v2".

'fasterrcnn_resnet50_fpn_v2'
num_classes int

Number of classes including background. Defaults to 11.

11
num_channels int

Number of input image channels. Defaults to 3.

3
class_names list of str

Ordered list of class name strings (index 0 should be "background"). Stored in config.json so downstream users do not need the original dataset.

None
commit_message str

Commit message for the Hub upload. Defaults to a descriptive string including model_name.

None
private bool

Whether to create a private repository. Defaults to False.

False
token str

Hugging Face API token with write access. If None, the token stored by huggingface-cli login is used.

None

Returns:

Name Type Description
str Optional[str]

URL of the uploaded repository on Hugging Face Hub, or None if huggingface_hub is not installed.

Source code in geoai/object_detect.py
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
def push_detector_to_hub(
    model_path: str,
    repo_id: str,
    model_name: str = "fasterrcnn_resnet50_fpn_v2",
    num_classes: int = 11,
    num_channels: int = 3,
    class_names: Optional[List[str]] = None,
    commit_message: Optional[str] = None,
    private: bool = False,
    token: Optional[str] = None,
) -> Optional[str]:
    """Push a trained detection model to Hugging Face Hub.

    Uploads the model weights (``model.pth``) and a ``config.json`` file
    containing model metadata to the specified Hub repository. The
    repository is created automatically if it does not already exist.

    Args:
        model_path (str): Path to the trained model weights (``.pth`` file).
        repo_id (str): Hub repository in ``"username/repo-name"`` format.
        model_name (str): Detection model architecture name. Stored in
            ``config.json`` so the model can be reconstructed on download.
            Defaults to ``"fasterrcnn_resnet50_fpn_v2"``.
        num_classes (int): Number of classes including background. Defaults
            to 11.
        num_channels (int): Number of input image channels. Defaults to 3.
        class_names (list of str, optional): Ordered list of class name
            strings (index 0 should be ``"background"``). Stored in
            ``config.json`` so downstream users do not need the original
            dataset.
        commit_message (str, optional): Commit message for the Hub upload.
            Defaults to a descriptive string including ``model_name``.
        private (bool): Whether to create a private repository. Defaults to
            False.
        token (str, optional): Hugging Face API token with write access. If
            None, the token stored by ``huggingface-cli login`` is used.

    Returns:
        str: URL of the uploaded repository on Hugging Face Hub, or None
        if ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import HfApi, create_repo
    except ImportError:
        logger.error(
            "huggingface_hub is required to push models. "
            "Install it with: pip install huggingface-hub"
        )
        return None

    # Load state dict. NOTE(review): torch.load without weights_only;
    # presumably local checkpoints are trusted — confirm before hardening.
    state_dict = torch.load(model_path, map_location="cpu")
    # Strip the DataParallel "module." wrapper. Only the leading prefix is
    # removed so keys that happen to contain "module." elsewhere are not
    # corrupted (str.replace would strip every occurrence).
    if any(key.startswith("module.") for key in state_dict.keys()):
        state_dict = {
            (key[len("module.") :] if key.startswith("module.") else key): value
            for key, value in state_dict.items()
        }

    # Build configuration dict persisted next to the weights so downloads
    # can reconstruct the architecture without the original dataset.
    config: Dict[str, Any] = {
        "model_type": "detection",
        "model_name": model_name,
        "num_classes": num_classes,
        "num_channels": num_channels,
        "class_names": class_names,
    }

    try:
        # Create Hub repository (no-op if it already exists)
        api = HfApi(token=token)
        create_repo(repo_id, private=private, token=token, exist_ok=True)

        if commit_message is None:
            commit_message = f"Upload {model_name} object detection model"

        with tempfile.TemporaryDirectory() as tmpdir:
            model_save_path = os.path.join(tmpdir, "model.pth")
            torch.save(state_dict, model_save_path)

            config_path = os.path.join(tmpdir, "config.json")
            with open(config_path, "w") as f:
                json.dump(config, f, indent=2)

            api.upload_folder(
                folder_path=tmpdir,
                repo_id=repo_id,
                commit_message=commit_message,
                token=token,
            )

        url = f"https://huggingface.co/{repo_id}"
        logger.info(f"Model successfully pushed to: {url}")
        return url
    except Exception as e:
        logger.error(f"Failed to push model to Hub: {e}")
        return None

train_multiclass_detector(images_dir, annotations_path, output_dir, model_name='fasterrcnn_resnet50_fpn_v2', class_names=None, num_channels=3, batch_size=4, num_epochs=50, learning_rate=0.005, val_split=0.2, seed=42, pretrained=True, pretrained_model_path=None, device=None, num_workers=None, verbose=True)

Train a multi-class object detection model using COCO-format annotations.

Supports multiple torchvision detection architectures including Faster R-CNN, RetinaNet, FCOS, and Mask R-CNN.

Parameters:

Name Type Description Default
images_dir str

Directory containing training images.

required
annotations_path str

Path to COCO-format annotations JSON file.

required
output_dir str

Directory for model outputs.

required
model_name str

Detection model architecture. One of "fasterrcnn_resnet50_fpn_v2" (default), "fasterrcnn_mobilenet_v3_large_fpn", "retinanet_resnet50_fpn_v2", "fcos_resnet50_fpn", or "maskrcnn_resnet50_fpn".

'fasterrcnn_resnet50_fpn_v2'
class_names list

List of class names including background. If None, extracted from annotations.

None
num_channels int

Number of image channels. Defaults to 3.

3
batch_size int

Training batch size. Defaults to 4.

4
num_epochs int

Number of training epochs. Defaults to 50.

50
learning_rate float

Initial learning rate. Defaults to 0.005.

0.005
val_split float

Validation split fraction. Defaults to 0.2.

0.2
seed int

Random seed. Defaults to 42.

42
pretrained bool

Whether to use pretrained backbone. Defaults to True.

True
pretrained_model_path str

Path to pretrained model.

None
device device

Compute device.

None
num_workers int

Number of data loading workers.

None
verbose bool

Whether to print progress. Defaults to True.

True

Returns:

Name Type Description
str str

Path to the best model checkpoint.

Source code in geoai/object_detect.py
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
def train_multiclass_detector(
    images_dir: str,
    annotations_path: str,
    output_dir: str,
    model_name: str = "fasterrcnn_resnet50_fpn_v2",
    class_names: Optional[List[str]] = None,
    num_channels: int = 3,
    batch_size: int = 4,
    num_epochs: int = 50,
    learning_rate: float = 0.005,
    val_split: float = 0.2,
    seed: int = 42,
    pretrained: bool = True,
    pretrained_model_path: Optional[str] = None,
    device: Optional[torch.device] = None,
    num_workers: Optional[int] = None,
    verbose: bool = True,
) -> str:
    """Train a multi-class object detection model on COCO-format annotations.

    Thin orchestration layer: the class count is derived from the annotation
    file, a ``class_info.json`` sidecar is written next to the checkpoints,
    and the actual training loop is delegated to :func:`train_MaskRCNN_model`.
    Supports multiple torchvision detection architectures including
    Faster R-CNN, RetinaNet, FCOS, and Mask R-CNN.

    Args:
        images_dir (str): Directory containing training images.
        annotations_path (str): Path to COCO-format annotations JSON file.
        output_dir (str): Directory for model outputs.
        model_name (str): Detection model architecture. One of
            ``"fasterrcnn_resnet50_fpn_v2"`` (default),
            ``"fasterrcnn_mobilenet_v3_large_fpn"``,
            ``"retinanet_resnet50_fpn_v2"``,
            ``"fcos_resnet50_fpn"``, or
            ``"maskrcnn_resnet50_fpn"``.
        class_names (list, optional): List of class names including background.
            If None, extracted from annotations (sorted by category id).
        num_channels (int): Number of image channels. Defaults to 3.
        batch_size (int): Training batch size. Defaults to 4.
        num_epochs (int): Number of training epochs. Defaults to 50.
        learning_rate (float): Initial learning rate. Defaults to 0.005.
        val_split (float): Validation split fraction. Defaults to 0.2.
        seed (int): Random seed. Defaults to 42.
        pretrained (bool): Whether to use pretrained backbone. Defaults to True.
        pretrained_model_path (str, optional): Path to pretrained model.
        device (torch.device, optional): Compute device.
        num_workers (int, optional): Number of data loading workers.
        verbose (bool): Whether to print progress. Defaults to True.

    Returns:
        str: Path to the best model checkpoint.
    """
    # The class count comes from the annotation file, not the caller:
    # one model output per COCO category plus one background slot.
    with open(annotations_path, "r") as f:
        coco_data = json.load(f)

    categories = coco_data.get("categories", [])
    num_classes = 1 + len(categories)  # +1 for background

    if class_names is None:
        # Order names by category id so label indices line up with model outputs.
        ordered = sorted(categories, key=lambda cat: cat["id"])
        class_names = ["background"] + [cat["name"] for cat in ordered]

    if verbose:
        logger.info(f"Training {model_name} with {num_classes} classes")
        logger.info(f"  Classes: {class_names[1:]}")

    # Persist the label mapping and architecture alongside the checkpoints so
    # inference code can reconstruct the model without the annotation file.
    os.makedirs(os.path.abspath(output_dir), exist_ok=True)
    with open(os.path.join(output_dir, "class_info.json"), "w") as f:
        json.dump(
            {
                "class_names": class_names,
                "num_classes": num_classes,
                "model_name": model_name,
            },
            f,
            indent=2,
        )

    # Delegate the full training loop; it writes best_model.pth to output_dir.
    train_MaskRCNN_model(
        images_dir=images_dir,
        labels_dir=annotations_path,
        output_dir=output_dir,
        input_format="coco_detection",
        num_channels=num_channels,
        num_classes=num_classes,
        pretrained=pretrained,
        pretrained_model_path=pretrained_model_path,
        batch_size=batch_size,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        val_split=val_split,
        seed=seed,
        device=device,
        num_workers=num_workers,
        verbose=verbose,
        model_name=model_name,
    )

    return os.path.join(output_dir, "best_model.pth")

visualize_coco_annotations(annotations_path, images_dir, num_samples=4, random=False, seed=None, figsize=(14, 14), cols=2, output_path=None)

Visualize sample images with their COCO-format bounding box annotations.

Loads a COCO JSON annotation file and displays a grid of sample images with colored bounding boxes and class labels overlaid.

Parameters:

Name Type Description Default
annotations_path str

Path to COCO-format annotations JSON file.

required
images_dir str

Directory containing the images referenced in the annotations file.

required
num_samples int

Number of sample images to display. Defaults to 4.

4
random bool

Whether to select images randomly instead of taking the first num_samples. Defaults to False.

False
seed int

Random seed for reproducibility when random=True.

None
figsize tuple

Figure size (width, height). Defaults to (14, 14).

(14, 14)
cols int

Number of columns in the grid layout. Defaults to 2.

2
output_path str

Path to save the figure. If None, displays interactively.

None
Source code in geoai/object_detect.py
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
def visualize_coco_annotations(
    annotations_path: str,
    images_dir: str,
    num_samples: int = 4,
    random: bool = False,
    seed: Optional[int] = None,
    figsize: Tuple[int, int] = (14, 14),
    cols: int = 2,
    output_path: Optional[str] = None,
) -> None:
    """Visualize sample images with their COCO-format bounding box annotations.

    Loads a COCO JSON annotation file and displays a grid of sample images
    with colored bounding boxes and class labels overlaid.

    Args:
        annotations_path (str): Path to COCO-format annotations JSON file.
        images_dir (str): Directory containing the images referenced in the
            annotations file.
        num_samples (int): Number of sample images to display. Defaults to 4.
        random (bool): Whether to select images randomly instead of taking
            the first ``num_samples``. Defaults to False.
        seed (int, optional): Random seed for reproducibility when
            ``random=True``.
        figsize (tuple): Figure size (width, height). Defaults to (14, 14).
        cols (int): Number of columns in the grid layout. Defaults to 2.
        output_path (str, optional): Path to save the figure. If None,
            displays interactively.
    """
    import random as random_module

    from PIL import Image as PILImage

    with open(annotations_path, "r") as f:
        coco_data = json.load(f)

    all_images = coco_data["images"]
    if random:
        rng = random_module.Random(seed)
        sample_images = rng.sample(all_images, min(num_samples, len(all_images)))
    else:
        sample_images = all_images[:num_samples]
    categories = {cat["id"]: cat["name"] for cat in coco_data["categories"]}
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; "tab10" is
    # already a 10-entry ListedColormap, so integer indexing is unchanged.
    cmap = plt.get_cmap("tab10")

    # Group annotations by image id once instead of rescanning the full
    # annotation list for every displayed image.
    anns_by_image: Dict[int, List[Dict]] = {}
    for ann in coco_data["annotations"]:
        anns_by_image.setdefault(ann["image_id"], []).append(ann)

    rows = math.ceil(num_samples / cols)
    fig, axes = plt.subplots(rows, cols, figsize=figsize)
    # plt.subplots returns a bare Axes for a 1x1 grid and an ndarray
    # otherwise; normalize to a flat array. (A plain `num_samples == 1`
    # check breaks when cols > 1, since subplots(1, 2) returns an array.)
    axes = np.atleast_1d(axes).ravel()

    for ax_idx, img_info in enumerate(sample_images):
        img_path = os.path.join(images_dir, img_info["file_name"])
        img = PILImage.open(img_path)
        axes[ax_idx].imshow(img)
        axes[ax_idx].set_title(img_info["file_name"], fontsize=10)
        axes[ax_idx].axis("off")

        for ann in anns_by_image.get(img_info["id"], []):
            x, y, w, h = ann["bbox"]  # COCO bbox: [x, y, width, height]
            cat_id = ann["category_id"]
            color = cmap(cat_id % 10)
            rect = plt.Rectangle(
                (x, y), w, h, linewidth=2, edgecolor=color, facecolor="none"
            )
            axes[ax_idx].add_patch(rect)
            axes[ax_idx].text(
                x,
                y - 3,
                categories.get(cat_id, str(cat_id)),
                color="white",
                fontsize=7,
                bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7),
            )

    # Hide unused axes (also covers the case where fewer images are
    # available than num_samples).
    for ax_idx in range(len(sample_images), len(axes)):
        axes[ax_idx].axis("off")

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close()
    else:
        plt.show()

visualize_multiclass_detections(image_path, detections, class_names=None, confidence_threshold=0.0, figsize=(15, 10), output_path=None, max_detections=200)

Visualize multi-class detections overlaid on an image.

Draws colored bounding boxes with class labels and confidence scores.

Parameters:

Name Type Description Default
image_path str

Path to the source image.

required
detections list

List of detection dicts with mask, score, box, label.

required
class_names list

List of class names (index 0 = background).

None
confidence_threshold float

Minimum score to display. Defaults to 0.0.

0.0
figsize tuple

Figure size (width, height). Defaults to (15, 10).

(15, 10)
output_path str

Path to save the figure. If None, displays.

None
max_detections int

Maximum detections to display. Defaults to 200.

200
Source code in geoai/object_detect.py
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
def visualize_multiclass_detections(
    image_path: str,
    detections: List[Dict],
    class_names: Optional[List[str]] = None,
    confidence_threshold: float = 0.0,
    figsize: Tuple[int, int] = (15, 10),
    output_path: Optional[str] = None,
    max_detections: int = 200,
) -> None:
    """Visualize multi-class detections overlaid on an image.

    Draws colored bounding boxes with class labels and confidence scores,
    plus a legend with one entry per detected class.

    Args:
        image_path (str): Path to the source image.
        detections (list): List of detection dicts with mask, score, box, label.
        class_names (list, optional): List of class names (index 0 = background).
        confidence_threshold (float): Minimum score to display. Defaults to 0.0.
        figsize (tuple): Figure size (width, height). Defaults to (15, 10).
        output_path (str, optional): Path to save the figure. If None, displays.
        max_detections (int): Maximum detections to display. Defaults to 200.
    """
    from PIL import Image as PILImage

    # Load image: GeoTIFFs go through rasterio (bands-first), everything
    # else through PIL as RGB.
    if image_path.lower().endswith((".tif", ".tiff")):
        import rasterio

        with rasterio.open(image_path) as src:
            image = src.read()
            if image.shape[0] >= 3:
                # Keep the first 3 bands and convert to (H, W, C) for imshow.
                image = image[:3].transpose(1, 2, 0)
            else:
                image = image[0]
    else:
        image = np.array(PILImage.open(image_path).convert("RGB"))

    # Normalize to uint8 for display; [0, 1] floats are rescaled, other
    # dtypes are cast as-is.
    if image.dtype != np.uint8:
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
        else:
            image = image.astype(np.uint8)

    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; "tab20" is
    # already a 20-entry ListedColormap, so integer indexing is unchanged.
    cmap = plt.get_cmap("tab20")

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.imshow(image)

    # Keep the highest-scoring detections above the threshold, capped at
    # max_detections to keep the plot readable.
    filtered = [d for d in detections if d["score"] >= confidence_threshold]
    filtered.sort(key=lambda x: x["score"], reverse=True)
    filtered = filtered[:max_detections]

    legend_entries = {}

    for det in filtered:
        box = det["box"]  # [x_min, y_min, x_max, y_max]
        label = det["label"]
        score = det["score"]

        # Stable per-class color (20 distinct hues, cycling beyond that).
        color = cmap(label % 20)[:3]

        # Draw bounding box
        rect = plt.Rectangle(
            (box[0], box[1]),
            box[2] - box[0],
            box[3] - box[1],
            linewidth=2,
            edgecolor=color,
            facecolor="none",
        )
        ax.add_patch(rect)

        # Resolve class name; fall back to a generic label for unknown ids.
        name = f"class_{label}"
        if class_names and label < len(class_names):
            name = class_names[label]

        ax.text(
            box[0],
            box[1] - 5,
            f"{name}: {score:.2f}",
            color="white",
            fontsize=8,
            fontweight="bold",
            bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7),
        )

        if name not in legend_entries:
            legend_entries[name] = color

    # Add legend with one swatch per displayed class.
    if legend_entries:
        from matplotlib.patches import Patch

        handles = [
            Patch(facecolor=c, edgecolor=c, label=n)
            for n, c in sorted(legend_entries.items())
        ]
        ax.legend(handles=handles, loc="upper right", fontsize=9)

    ax.set_title(f"Detections: {len(filtered)}")
    ax.axis("off")
    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close()
    else:
        plt.show()