Skip to content

object_detect module

High-level functions for multi-class object detection.

This module provides convenience functions for training, evaluating, and running inference with Mask R-CNN models on COCO-format datasets, including support for the NWPU-VHR-10 remote sensing benchmark.

detections_to_geodataframe(detections, geotiff_path, class_names=None)

Convert detections to a GeoDataFrame with geospatial coordinates.

Converts pixel-space bounding boxes to geospatial coordinates using the CRS and transform from the source GeoTIFF.

Parameters:

Name Type Description Default
detections list

List of detection dicts, each with keys: mask, score, box (in pixel coords), label.

required
geotiff_path str

Path to the source GeoTIFF (for CRS and transform).

required
class_names list

List of class names (index 0 = background).

None

Returns:

Type Description
Any

geopandas.GeoDataFrame: GeoDataFrame with columns: geometry, class_id,

Any

class_name, score, area_pixels.

Source code in geoai/object_detect.py
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
def detections_to_geodataframe(
    detections: List[Dict],
    geotiff_path: str,
    class_names: Optional[List[str]] = None,
) -> Any:
    """Convert detections to a GeoDataFrame with geospatial coordinates.

    Converts pixel-space bounding boxes to geospatial coordinates using the
    CRS and transform from the source GeoTIFF.

    Args:
        detections (list): List of detection dicts, each with keys:
            mask, score, box (in pixel coords), label.
        geotiff_path (str): Path to the source GeoTIFF (for CRS and transform).
        class_names (list, optional): List of class names (index 0 = background).

    Returns:
        geopandas.GeoDataFrame: GeoDataFrame with columns: geometry, class_id,
        class_name, score, area_pixels.
    """
    import geopandas as gpd
    import rasterio
    from shapely.geometry import box as shapely_box

    # Nothing to convert: return an empty frame with the expected columns.
    if not detections:
        return gpd.GeoDataFrame(
            columns=["geometry", "class_id", "class_name", "score", "area_pixels"]
        )

    # Read georeferencing metadata from the source raster.
    with rasterio.open(geotiff_path) as src:
        transform = src.transform
        crs = src.crs

    rows = []
    for detection in detections:
        pixel_box = detection["box"]
        class_id = detection["label"]
        confidence = detection["score"]

        # Map pixel-space corners through the affine transform. For a typical
        # north-up raster the top-left pixel corner gives (x_min, y_max) and
        # the bottom-right corner gives (x_max, y_min).
        x_min, y_max = transform * (pixel_box[0], pixel_box[1])
        x_max, y_min = transform * (pixel_box[2], pixel_box[3])

        footprint = shapely_box(x_min, y_min, x_max, y_max)
        pixel_area = detection["mask"].sum() if "mask" in detection else 0

        # Resolve a human-readable class name; fall back to a synthetic
        # "class_N" label when the index is out of range, or "unknown" when
        # no name list was supplied at all.
        if not class_names:
            class_name = "unknown"
        elif class_id < len(class_names):
            class_name = class_names[class_id]
        else:
            class_name = f"class_{class_id}"

        rows.append(
            {
                "geometry": footprint,
                "class_id": class_id,
                "class_name": class_name,
                "score": confidence,
                "area_pixels": int(pixel_area),
            }
        )

    return gpd.GeoDataFrame(rows, crs=crs)

download_nwpu_vhr10(output_dir='NWPU-VHR-10', overwrite=False)

Download and extract the NWPU-VHR-10 dataset.

The NWPU-VHR-10 dataset contains 800 VHR (Very High Resolution) remote sensing images with 10 object classes: airplane, ship, storage_tank, baseball_diamond, tennis_court, basketball_court, ground_track_field, harbor, bridge, and vehicle. It has 3,775 annotated instances in COCO format (bounding boxes and instance segmentation masks).

Parameters:

Name Type Description Default
output_dir str

Path for the downloaded ZIP file and extracted dataset directory. Defaults to "NWPU-VHR-10".

'NWPU-VHR-10'
overwrite bool

Whether to overwrite existing files. Defaults to False.

False

Returns:

Name Type Description
str str

Path to the extracted dataset directory.

Source code in geoai/object_detect.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def download_nwpu_vhr10(
    output_dir: str = "NWPU-VHR-10",
    overwrite: bool = False,
) -> str:
    """Download and extract the NWPU-VHR-10 dataset.

    The NWPU-VHR-10 dataset contains 800 VHR (Very High Resolution) remote
    sensing images with 10 object classes: airplane, ship, storage_tank,
    baseball_diamond, tennis_court, basketball_court, ground_track_field,
    harbor, bridge, and vehicle. It has 3,775 annotated instances in COCO
    format (bounding boxes and instance segmentation masks).

    Args:
        output_dir (str): Path for the downloaded ZIP file and extracted
            dataset directory. Defaults to "NWPU-VHR-10".
        overwrite (bool): Whether to overwrite existing files. Defaults to False.

    Returns:
        str: Path to the extracted dataset directory.
    """
    # The archive is written next to (and named after) the target directory.
    archive_path = f"{output_dir}.zip"
    return download_file(NWPU_VHR10_URL, output_path=archive_path, overwrite=overwrite)

download_nwpu_vhr10_model(repo_id=NWPU_VHR10_HF_REPO, filename=NWPU_VHR10_HF_FILENAME)

Download the pretrained NWPU-VHR-10 Mask R-CNN model from HuggingFace Hub.

Downloads a Mask R-CNN (ResNet-50 FPN) model trained on the NWPU-VHR-10 dataset for 10-class object detection on remote sensing imagery.

The model achieves the following performance on the validation set:
  • mAP@0.5: 0.709
  • mAP@0.75: 0.518
  • mAP@[0.5:0.95]: 0.459

Parameters:

Name Type Description Default
repo_id str

HuggingFace Hub repository ID. Defaults to "giswqs/nwpu-vhr10-maskrcnn".

NWPU_VHR10_HF_REPO
filename str

Model filename in the repository. Defaults to "best_model.pth".

NWPU_VHR10_HF_FILENAME

Returns:

Name Type Description
str str

Local path to the downloaded model weights file.

Example

import geoai model_path = geoai.download_nwpu_vhr10_model() print(model_path) # local cache path

Source code in geoai/object_detect.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def download_nwpu_vhr10_model(
    repo_id: str = NWPU_VHR10_HF_REPO,
    filename: str = NWPU_VHR10_HF_FILENAME,
) -> str:
    """Download the pretrained NWPU-VHR-10 Mask R-CNN model from HuggingFace Hub.

    Downloads a Mask R-CNN (ResNet-50 FPN) model trained on the NWPU-VHR-10
    dataset for 10-class object detection on remote sensing imagery.

    The model achieves the following performance on the validation set:
        - mAP@0.5: 0.709
        - mAP@0.75: 0.518
        - mAP@[0.5:0.95]: 0.459

    Args:
        repo_id (str): HuggingFace Hub repository ID.
            Defaults to "giswqs/nwpu-vhr10-maskrcnn".
        filename (str): Model filename in the repository.
            Defaults to "best_model.pth".

    Returns:
        str: Local path to the downloaded model weights file.

    Example:
        >>> import geoai
        >>> model_path = geoai.download_nwpu_vhr10_model()
        >>> print(model_path)  # local cache path
    """
    from huggingface_hub import hf_hub_download

    # hf_hub_download caches the file locally and returns the cache path.
    local_path = hf_hub_download(repo_id=repo_id, filename=filename)
    print(f"Model downloaded to: {local_path}")
    return local_path

evaluate_multiclass_detector(model_path=None, images_dir='', annotations_path='', num_classes=11, class_names=None, num_channels=3, batch_size=4, device=None, num_workers=None, repo_id=None, verbose=True)

Evaluate a trained multi-class detection model on a dataset.

Loads a trained model and computes COCO-style mAP metrics on the provided dataset.

If model_path is None, the pretrained NWPU-VHR-10 model is automatically downloaded from HuggingFace Hub.

Parameters:

Name Type Description Default
model_path str

Path to trained model weights. If None, downloads the pretrained NWPU-VHR-10 model.

None
images_dir str

Directory containing evaluation images.

''
annotations_path str

Path to COCO-format annotations JSON.

''
num_classes int

Number of classes including background. Defaults to 11 (NWPU-VHR-10).

11
class_names list

List of class names (excluding background).

None
num_channels int

Number of image channels. Defaults to 3.

3
batch_size int

Evaluation batch size. Defaults to 4.

4
device device

Compute device.

None
num_workers int

Number of data loading workers.

None
repo_id str

HuggingFace Hub repository ID for downloading the model. Defaults to "giswqs/nwpu-vhr10-maskrcnn".

None
verbose bool

Whether to print results. Defaults to True.

True

Returns:

Type Description
Dict[str, float]

Dict with mAP metrics.

Source code in geoai/object_detect.py
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
def evaluate_multiclass_detector(
    model_path: Optional[str] = None,
    images_dir: str = "",
    annotations_path: str = "",
    num_classes: int = 11,
    class_names: Optional[List[str]] = None,
    num_channels: int = 3,
    batch_size: int = 4,
    device: Optional[torch.device] = None,
    num_workers: Optional[int] = None,
    repo_id: Optional[str] = None,
    verbose: bool = True,
) -> Dict[str, float]:
    """Evaluate a trained multi-class detection model on a dataset.

    Loads a trained model and computes COCO-style mAP metrics on the
    provided dataset.

    If ``model_path`` is None, the pretrained NWPU-VHR-10 model is
    automatically downloaded from HuggingFace Hub.

    Args:
        model_path (str, optional): Path to trained model weights.
            If None, downloads the pretrained NWPU-VHR-10 model.
        images_dir (str): Directory containing evaluation images.
        annotations_path (str): Path to COCO-format annotations JSON.
        num_classes (int): Number of classes including background.
            Defaults to 11 (NWPU-VHR-10).
        class_names (list, optional): List of class names (excluding background).
        num_channels (int): Number of image channels. Defaults to 3.
        batch_size (int): Evaluation batch size. Defaults to 4.
        device (torch.device, optional): Compute device.
        num_workers (int, optional): Number of data loading workers.
        repo_id (str, optional): HuggingFace Hub repository ID for
            downloading the model. Defaults to "giswqs/nwpu-vhr10-maskrcnn".
        verbose (bool): Whether to print results. Defaults to True.

    Returns:
        Dict with mAP metrics.
    """
    import platform

    from torch.utils.data import DataLoader

    if device is None:
        device = get_device()

    # Handle pretrained model download. Note this overrides num_classes to
    # match the pretrained checkpoint.
    if model_path is None:
        hf_repo = repo_id or NWPU_VHR10_HF_REPO
        model_path = download_nwpu_vhr10_model(
            repo_id=hf_repo, filename=NWPU_VHR10_HF_FILENAME
        )
        num_classes = len(NWPU_VHR10_CLASSES)
        if class_names is None:
            class_names = NWPU_VHR10_CLASSES[1:]  # Exclude background

    # Build the model skeleton; weights are loaded below.
    model = get_instance_segmentation_model(
        num_classes=num_classes, num_channels=num_channels, pretrained=False
    )

    # A non-existent local path is treated as a filename within the HF repo.
    if not os.path.exists(model_path):
        hf_repo = repo_id or NWPU_VHR10_HF_REPO
        from huggingface_hub import hf_hub_download

        model_path = hf_hub_download(repo_id=hf_repo, filename=model_path)

    state_dict = torch.load(model_path, map_location=device)
    # Strip the (Distributed)DataParallel "module." wrapper prefix. Only the
    # leading prefix is removed; a blanket str.replace would also corrupt
    # parameter names that happen to contain "module." mid-name.
    prefix = "module."
    if any(key.startswith(prefix) for key in state_dict):
        state_dict = {
            (key[len(prefix) :] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    model.load_state_dict(state_dict)
    model.to(device)

    # Create dataset and loader
    dataset = COCODetectionDataset(
        coco_json_path=annotations_path,
        images_dir=images_dir,
        transforms=get_transform(train=False),
        num_channels=num_channels,
    )

    # Spawn-based multiprocessing on macOS/Windows makes workers costly.
    if num_workers is None:
        num_workers = 0 if platform.system() in ["Darwin", "Windows"] else 4

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=num_workers,
    )

    # Evaluate
    results = evaluate_coco_metrics(
        model=model,
        data_loader=data_loader,
        device=device,
        class_names=class_names,
        verbose=verbose,
    )

    return results

multiclass_detection(input_path, output_path, model_path=None, num_classes=11, class_names=None, window_size=512, overlap=256, confidence_threshold=0.5, nms_threshold=0.3, batch_size=4, num_channels=3, device=None, repo_id=None, **kwargs)

Perform multi-class object detection on a GeoTIFF or image.

Loads a trained Mask R-CNN model and runs inference using a sliding window approach. Outputs a 2-band raster with class labels and instance IDs.

If model_path is None, the pretrained NWPU-VHR-10 model is automatically downloaded from HuggingFace Hub with default num_classes=11 and class_names set to NWPU-VHR-10 classes.

Parameters:

Name Type Description Default
input_path str

Path to input image (GeoTIFF, JPEG, PNG, etc.).

required
output_path str

Path to save output raster.

required
model_path str

Path to trained model weights (.pth file). If None, downloads the pretrained NWPU-VHR-10 model from HuggingFace Hub. If the path does not exist locally, it is treated as a filename to download from repo_id.

None
num_classes int

Number of classes including background. Defaults to 11 (NWPU-VHR-10).

11
class_names list

List of class names (index 0 = background). If None and using the pretrained model, defaults to NWPU-VHR-10 class names.

None
window_size int

Sliding window size. Defaults to 512.

512
overlap int

Window overlap in pixels. Defaults to 256.

256
confidence_threshold float

Minimum detection score. Defaults to 0.5.

0.5
nms_threshold float

IoU threshold for NMS. Defaults to 0.3.

0.3
batch_size int

Inference batch size. Defaults to 4.

4
num_channels int

Number of input image channels. Defaults to 3.

3
device device

Compute device.

None
repo_id str

HuggingFace Hub repository ID for downloading the model. Defaults to "giswqs/nwpu-vhr10-maskrcnn".

None
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
str

Tuple of (output_path, inference_time, detections_list) where each

float

detection is a dict with mask, score, box, and label.

Example

import geoai

Use pretrained model (auto-downloads from HuggingFace)

result_path, time, dets = geoai.multiclass_detection( ... input_path="image.tif", ... output_path="detections.tif", ... )

Use a custom trained model

result_path, time, dets = geoai.multiclass_detection( ... input_path="image.tif", ... output_path="detections.tif", ... model_path="my_model.pth", ... num_classes=5, ... )

Source code in geoai/object_detect.py
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
def multiclass_detection(
    input_path: str,
    output_path: str,
    model_path: Optional[str] = None,
    num_classes: int = 11,
    class_names: Optional[List[str]] = None,
    window_size: int = 512,
    overlap: int = 256,
    confidence_threshold: float = 0.5,
    nms_threshold: float = 0.3,
    batch_size: int = 4,
    num_channels: int = 3,
    device: Optional[torch.device] = None,
    repo_id: Optional[str] = None,
    **kwargs: Any,
) -> Tuple[str, float, List[Dict]]:
    """Perform multi-class object detection on a GeoTIFF or image.

    Loads a trained Mask R-CNN model and runs inference using a sliding
    window approach. Outputs a 2-band raster with class labels and
    instance IDs.

    If ``model_path`` is None, the pretrained NWPU-VHR-10 model is
    automatically downloaded from HuggingFace Hub with default
    ``num_classes=11`` and ``class_names`` set to NWPU-VHR-10 classes.

    Args:
        input_path (str): Path to input image (GeoTIFF, JPEG, PNG, etc.).
        output_path (str): Path to save output raster.
        model_path (str, optional): Path to trained model weights (.pth file).
            If None, downloads the pretrained NWPU-VHR-10 model from
            HuggingFace Hub. If the path does not exist locally, it is
            treated as a filename to download from ``repo_id``.
        num_classes (int): Number of classes including background. Defaults
            to 11 (NWPU-VHR-10).
        class_names (list, optional): List of class names (index 0 = background).
            If None and using the pretrained model, defaults to NWPU-VHR-10
            class names.
        window_size (int): Sliding window size. Defaults to 512.
        overlap (int): Window overlap in pixels. Defaults to 256.
        confidence_threshold (float): Minimum detection score. Defaults to 0.5.
        nms_threshold (float): IoU threshold for NMS. Defaults to 0.3.
        batch_size (int): Inference batch size. Defaults to 4.
        num_channels (int): Number of input image channels. Defaults to 3.
        device (torch.device, optional): Compute device.
        repo_id (str, optional): HuggingFace Hub repository ID for
            downloading the model. Defaults to "giswqs/nwpu-vhr10-maskrcnn".
        **kwargs: Additional keyword arguments.

    Returns:
        Tuple of (output_path, inference_time, detections_list) where each
        detection is a dict with mask, score, box, and label.

    Example:
        >>> import geoai
        >>> # Use pretrained model (auto-downloads from HuggingFace)
        >>> result_path, elapsed, dets = geoai.multiclass_detection(
        ...     input_path="image.tif",
        ...     output_path="detections.tif",
        ... )
        >>> # Use a custom trained model
        >>> result_path, elapsed, dets = geoai.multiclass_detection(
        ...     input_path="image.tif",
        ...     output_path="detections.tif",
        ...     model_path="my_model.pth",
        ...     num_classes=5,
        ... )
    """
    import rasterio

    if device is None:
        device = get_device()

    # Handle pretrained model download
    use_pretrained = model_path is None
    if use_pretrained:
        hf_repo = repo_id or NWPU_VHR10_HF_REPO
        model_path = download_nwpu_vhr10_model(
            repo_id=hf_repo, filename=NWPU_VHR10_HF_FILENAME
        )
        if class_names is None:
            class_names = NWPU_VHR10_CLASSES
        num_classes = len(NWPU_VHR10_CLASSES)

    # Convert non-GeoTIFF images to temporary GeoTIFF for processing
    temp_tif = None
    if not input_path.lower().endswith((".tif", ".tiff")):
        from PIL import Image as PILImage

        # Close the image handle promptly instead of relying on GC.
        with PILImage.open(input_path) as pil_img:
            img_array = np.array(pil_img.convert("RGB"))
        h, w = img_array.shape[:2]

        # Derive the temp path from the stem. Previously this used
        # str.replace on the extension substring, which could match the same
        # substring earlier in the path (e.g. a ".tif" directory name).
        temp_tif = os.path.splitext(output_path)[0] + "_temp_input.tif"
        profile = {
            "driver": "GTiff",
            "dtype": "uint8",
            "width": w,
            "height": h,
            "count": img_array.shape[2] if img_array.ndim == 3 else 1,
            "crs": None,
            # Pixel coordinates map 1:1 to map coordinates for plain images.
            "transform": rasterio.transform.from_bounds(0, 0, w, h, w, h),
        }
        with rasterio.open(temp_tif, "w", **profile) as dst:
            if img_array.ndim == 3:
                for band in range(img_array.shape[2]):
                    dst.write(img_array[:, :, band], band + 1)
            else:
                dst.write(img_array, 1)
        input_path = temp_tif

    try:
        # Resolve model path (download from the HF repo if not a local file)
        if not os.path.exists(model_path):
            hf_repo = repo_id or NWPU_VHR10_HF_REPO
            from huggingface_hub import hf_hub_download

            model_path = hf_hub_download(repo_id=hf_repo, filename=model_path)

        # Try to load class_info.json sidecar for num_classes / class_names
        class_info_path = os.path.join(os.path.dirname(model_path), "class_info.json")
        if not use_pretrained and os.path.exists(class_info_path):
            with open(class_info_path, "r") as f:
                class_info = json.load(f)
            num_classes = class_info.get("num_classes", num_classes)
            if class_names is None:
                class_names = class_info.get("class_names", class_names)

        state_dict = torch.load(model_path, map_location=device)
        # Strip the (Distributed)DataParallel "module." wrapper prefix. Only
        # the leading prefix is removed; a blanket str.replace would also
        # corrupt parameter names that contain "module." mid-name.
        prefix = "module."
        if any(key.startswith(prefix) for key in state_dict):
            state_dict = {
                (key[len(prefix) :] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }

        # The box predictor's cls_score weight shape is [num_classes, hidden_dim],
        # so the checkpoint itself reveals how many classes it was trained on.
        cls_key = "roi_heads.box_predictor.cls_score.weight"
        if not use_pretrained and cls_key in state_dict:
            inferred = state_dict[cls_key].shape[0]
            if inferred != num_classes:
                num_classes = inferred

        # Load model
        model = get_instance_segmentation_model(
            num_classes=num_classes, num_channels=num_channels, pretrained=False
        )
        model.load_state_dict(state_dict)

        result = multiclass_detection_inference_on_geotiff(
            model=model,
            geotiff_path=input_path,
            output_path=output_path,
            class_names=class_names,
            window_size=window_size,
            overlap=overlap,
            confidence_threshold=confidence_threshold,
            nms_threshold=nms_threshold,
            batch_size=batch_size,
            num_channels=num_channels,
            device=device,
            **kwargs,
        )
    finally:
        # Always remove the temporary GeoTIFF, even if inference fails.
        if temp_tif and os.path.exists(temp_tif):
            os.remove(temp_tif)

    return result

prepare_nwpu_vhr10(data_dir, output_dir=None, val_split=0.2, seed=42)

Prepare NWPU-VHR-10 dataset for training.

Converts the original text-based annotations to COCO JSON format, then splits the dataset into train/val sets. The original dataset uses text files with (x1,y1),(x2,y2),class_id per line for bounding boxes.

Note: Only images with at least one annotation are included in the train/val splits. The 150 "negative" images in the NWPU-VHR-10 dataset (those without any target objects) are excluded from the splits.

Parameters:

Name Type Description Default
data_dir str

Path to the extracted NWPU-VHR-10 directory.

required
output_dir str

Output directory for organized data. If None, creates files alongside the original data.

None
val_split float

Fraction of data for validation. Defaults to 0.2.

0.2
seed int

Random seed for reproducibility. Defaults to 42.

42

Returns:

Type Description
Dict[str, Any]

Dict with keys: - 'images_dir': Path to images directory - 'annotations_path': Path to the full annotations JSON - 'train_annotations': Path to train split annotations JSON - 'val_annotations': Path to val split annotations JSON - 'train_image_ids': List of training image IDs - 'val_image_ids': List of validation image IDs - 'class_names': List of class names (including background) - 'num_classes': Number of classes (including background)

Source code in geoai/object_detect.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
def prepare_nwpu_vhr10(
    data_dir: str,
    output_dir: Optional[str] = None,
    val_split: float = 0.2,
    seed: int = 42,
) -> Dict[str, Any]:
    """Prepare NWPU-VHR-10 dataset for training.

    Converts the original text-based annotations to COCO JSON format,
    then splits the dataset into train/val sets. The original dataset uses
    text files with ``(x1,y1),(x2,y2),class_id`` per line for bounding boxes.

    Note: Only images with at least one annotation are included in the
    train/val splits. The 150 "negative" images in the NWPU-VHR-10 dataset
    (those without any target objects) are excluded from the splits.

    Args:
        data_dir (str): Path to the extracted NWPU-VHR-10 directory.
        output_dir (str, optional): Output directory for organized data.
            If None, creates files alongside the original data.
        val_split (float): Fraction of data for validation. Defaults to 0.2.
        seed (int): Random seed for reproducibility. Defaults to 42.

    Returns:
        Dict with keys:
            - 'images_dir': Path to images directory
            - 'annotations_path': Path to the full annotations JSON
            - 'train_annotations': Path to train split annotations JSON
            - 'val_annotations': Path to val split annotations JSON
            - 'train_image_ids': List of training image IDs
            - 'val_image_ids': List of validation image IDs
            - 'class_names': List of class names (including background)
            - 'num_classes': Number of classes (including background)
    """
    from sklearn.model_selection import train_test_split

    if output_dir is None:
        output_dir = data_dir

    os.makedirs(output_dir, exist_ok=True)

    # Handle nested directory structure (NWPU-VHR-10/NWPU-VHR-10/)
    actual_dir = data_dir
    nested = os.path.join(data_dir, "NWPU-VHR-10")
    if os.path.isdir(nested):
        actual_dir = nested

    # Find images directory (naming varies between dataset distributions)
    images_dir = None
    for candidate in ["positive image set", "positive_image_set", "images"]:
        path = os.path.join(actual_dir, candidate)
        if os.path.isdir(path):
            images_dir = path
            break

    if images_dir is None:
        raise FileNotFoundError(
            f"Could not find images directory in {actual_dir}. "
            "Expected 'positive image set' directory."
        )

    # Find ground truth directory (optional; absent if already converted)
    gt_dir = None
    for candidate in ["ground truth", "ground_truth", "annotations", "labels"]:
        path = os.path.join(actual_dir, candidate)
        if os.path.isdir(path):
            gt_dir = path
            break

    # Check if COCO JSON already exists
    annotations_path = os.path.join(output_dir, "annotations.json")
    if os.path.exists(annotations_path):
        print(f"Using existing COCO annotations: {annotations_path}")
    elif gt_dir is not None:
        # Convert text annotations to COCO JSON
        print("Converting NWPU-VHR-10 text annotations to COCO JSON format...")
        _convert_nwpu_to_coco(images_dir, gt_dir, annotations_path)
        print(f"COCO annotations saved to: {annotations_path}")
    else:
        # Look for existing JSON annotations shipped with the data
        for candidate_ann in [
            "annotations.json",
            "instances.json",
        ]:
            path = os.path.join(actual_dir, candidate_ann)
            if os.path.isfile(path):
                annotations_path = path
                break

    if not os.path.exists(annotations_path):
        raise FileNotFoundError(f"Could not find or create annotations for {data_dir}.")

    # Load annotations
    with open(annotations_path, "r") as f:
        coco_data = json.load(f)

    # Get image IDs that have annotations (negatives are excluded)
    annotated_image_ids = set()
    for ann in coco_data.get("annotations", []):
        annotated_image_ids.add(ann["image_id"])

    image_ids = sorted(annotated_image_ids)

    # Split into train and validation
    train_ids, val_ids = train_test_split(
        image_ids, test_size=val_split, random_state=seed
    )

    # Create split annotation files
    train_ann_path = os.path.join(output_dir, "train_annotations.json")
    val_ann_path = os.path.join(output_dir, "val_annotations.json")

    # Sets for O(1) membership tests when filtering annotations below
    train_ids_set = set(train_ids)
    val_ids_set = set(val_ids)

    all_images = {img["id"]: img for img in coco_data["images"]}

    train_coco = {
        "images": [all_images[img_id] for img_id in train_ids if img_id in all_images],
        "annotations": [
            ann for ann in coco_data["annotations"] if ann["image_id"] in train_ids_set
        ],
        "categories": coco_data.get("categories", []),
    }

    val_coco = {
        "images": [all_images[img_id] for img_id in val_ids if img_id in all_images],
        "annotations": [
            ann for ann in coco_data["annotations"] if ann["image_id"] in val_ids_set
        ],
        "categories": coco_data.get("categories", []),
    }

    with open(train_ann_path, "w") as f:
        json.dump(train_coco, f)

    with open(val_ann_path, "w") as f:
        json.dump(val_coco, f)

    class_names = NWPU_VHR10_CLASSES

    # Fixed: extraneous f-string prefix on a literal with no placeholders
    print("Dataset prepared:")
    print(f"  Images directory: {images_dir}")
    print(f"  Total annotated images: {len(image_ids)}")
    print(f"  Total annotations: {len(coco_data['annotations'])}")
    print(f"  Training images: {len(train_ids)}")
    print(f"  Validation images: {len(val_ids)}")
    print(f"  Classes: {class_names[1:]}")

    return {
        "images_dir": images_dir,
        "annotations_path": annotations_path,
        "train_annotations": train_ann_path,
        "val_annotations": val_ann_path,
        "train_image_ids": train_ids,
        "val_image_ids": val_ids,
        "class_names": class_names,
        "num_classes": len(class_names),
    }

train_multiclass_detector(images_dir, annotations_path, output_dir, class_names=None, num_channels=3, batch_size=4, num_epochs=50, learning_rate=0.005, val_split=0.2, seed=42, pretrained=True, pretrained_model_path=None, device=None, num_workers=None, verbose=True)

Train a multi-class object detection model using COCO-format annotations.

This is a convenience wrapper around train_MaskRCNN_model that automatically sets up the COCODetectionDataset with proper class mapping.

Parameters:

Name Type Description Default
images_dir str

Directory containing training images.

required
annotations_path str

Path to COCO-format annotations JSON file.

required
output_dir str

Directory for model outputs.

required
class_names list

List of class names including background. If None, extracted from annotations.

None
num_channels int

Number of image channels. Defaults to 3.

3
batch_size int

Training batch size. Defaults to 4.

4
num_epochs int

Number of training epochs. Defaults to 50.

50
learning_rate float

Initial learning rate. Defaults to 0.005.

0.005
val_split float

Validation split fraction. Defaults to 0.2.

0.2
seed int

Random seed. Defaults to 42.

42
pretrained bool

Whether to use pretrained backbone. Defaults to True.

True
pretrained_model_path str

Path to pretrained model.

None
device device

Compute device.

None
num_workers int

Number of data loading workers.

None
verbose bool

Whether to print progress. Defaults to True.

True

Returns:

Name Type Description
str str

Path to the best model checkpoint.

Source code in geoai/object_detect.py
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
def train_multiclass_detector(
    images_dir: str,
    annotations_path: str,
    output_dir: str,
    class_names: Optional[List[str]] = None,
    num_channels: int = 3,
    batch_size: int = 4,
    num_epochs: int = 50,
    learning_rate: float = 0.005,
    val_split: float = 0.2,
    seed: int = 42,
    pretrained: bool = True,
    pretrained_model_path: Optional[str] = None,
    device: Optional[torch.device] = None,
    num_workers: Optional[int] = None,
    verbose: bool = True,
) -> str:
    """Train a multi-class object detection model using COCO-format annotations.

    This is a convenience wrapper around train_MaskRCNN_model that
    automatically sets up the COCODetectionDataset with proper class mapping.

    Args:
        images_dir (str): Directory containing training images.
        annotations_path (str): Path to COCO-format annotations JSON file.
        output_dir (str): Directory for model outputs.
        class_names (list, optional): List of class names including background.
            If None, extracted from annotations.
        num_channels (int): Number of image channels. Defaults to 3.
        batch_size (int): Training batch size. Defaults to 4.
        num_epochs (int): Number of training epochs. Defaults to 50.
        learning_rate (float): Initial learning rate. Defaults to 0.005.
        val_split (float): Validation split fraction. Defaults to 0.2.
        seed (int): Random seed. Defaults to 42.
        pretrained (bool): Whether to use pretrained backbone. Defaults to True.
        pretrained_model_path (str, optional): Path to pretrained model.
        device (torch.device, optional): Compute device.
        num_workers (int, optional): Number of data loading workers.
        verbose (bool): Whether to print progress. Defaults to True.

    Returns:
        str: Path to the best model checkpoint.

    Raises:
        ValueError: If ``class_names`` is provided but its length does not
            equal the number of annotation categories plus one (background).
    """
    # Determine num_classes from the annotation categories (+1 for background).
    with open(annotations_path, "r") as f:
        coco_data = json.load(f)

    categories = coco_data.get("categories", [])
    num_classes = len(categories) + 1  # +1 for background

    if class_names is None:
        # Sort by category id so class indices line up with annotation ids.
        class_names = ["background"] + [
            cat["name"] for cat in sorted(categories, key=lambda c: c["id"])
        ]
    elif len(class_names) != num_classes:
        # A silent mismatch here would build a model head whose size disagrees
        # with the class_info.json saved below, so fail fast before any I/O.
        raise ValueError(
            f"class_names has {len(class_names)} entries but the annotations "
            f"define {num_classes} classes (including background)."
        )

    if verbose:
        print(f"Training multi-class detector with {num_classes} classes")
        print(f"  Classes: {class_names[1:]}")

    # Save class names to output directory for later use (e.g. inference).
    os.makedirs(os.path.abspath(output_dir), exist_ok=True)
    class_info = {
        "class_names": class_names,
        "num_classes": num_classes,
    }
    with open(os.path.join(output_dir, "class_info.json"), "w") as f:
        json.dump(class_info, f, indent=2)

    train_MaskRCNN_model(
        images_dir=images_dir,
        labels_dir=annotations_path,
        output_dir=output_dir,
        input_format="coco_detection",
        num_channels=num_channels,
        num_classes=num_classes,
        pretrained=pretrained,
        pretrained_model_path=pretrained_model_path,
        batch_size=batch_size,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        val_split=val_split,
        seed=seed,
        device=device,
        num_workers=num_workers,
        verbose=verbose,
    )

    return os.path.join(output_dir, "best_model.pth")

visualize_multiclass_detections(image_path, detections, class_names=None, confidence_threshold=0.0, figsize=(15, 10), output_path=None, max_detections=200)

Visualize multi-class detections overlaid on an image.

Draws colored bounding boxes with class labels and confidence scores.

Parameters:

Name Type Description Default
image_path str

Path to the source image.

required
detections list

List of detection dicts with mask, score, box, label.

required
class_names list

List of class names (index 0 = background).

None
confidence_threshold float

Minimum score to display. Defaults to 0.0.

0.0
figsize tuple

Figure size (width, height). Defaults to (15, 10).

(15, 10)
output_path str

Path to save the figure. If None, displays.

None
max_detections int

Maximum detections to display. Defaults to 200.

200
Source code in geoai/object_detect.py
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
def visualize_multiclass_detections(
    image_path: str,
    detections: List[Dict],
    class_names: Optional[List[str]] = None,
    confidence_threshold: float = 0.0,
    figsize: Tuple[int, int] = (15, 10),
    output_path: Optional[str] = None,
    max_detections: int = 200,
) -> None:
    """Visualize multi-class detections overlaid on an image.

    Draws colored bounding boxes with class labels and confidence scores.

    Args:
        image_path (str): Path to the source image.
        detections (list): List of detection dicts with mask, score, box, label.
        class_names (list, optional): List of class names (index 0 = background).
        confidence_threshold (float): Minimum score to display. Defaults to 0.0.
        figsize (tuple): Figure size (width, height). Defaults to (15, 10).
        output_path (str, optional): Path to save the figure. If None, displays.
        max_detections (int): Maximum detections to display. Defaults to 200.
    """
    from PIL import Image as PILImage

    # Load image: GeoTIFFs via rasterio (first 3 bands, HWC order),
    # everything else via PIL as RGB.
    if image_path.lower().endswith((".tif", ".tiff")):
        import rasterio

        with rasterio.open(image_path) as src:
            image = src.read()
            if image.shape[0] >= 3:
                image = image[:3].transpose(1, 2, 0)
            else:
                image = image[0]
    else:
        image = np.array(PILImage.open(image_path).convert("RGB"))

    # Normalize to uint8 for display.
    if image.dtype != np.uint8:
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
        else:
            image = image.astype(np.uint8)

    # Color map for classes. ``plt.cm.get_cmap`` was deprecated in
    # Matplotlib 3.7 and removed in 3.9; use the colormap registry instead,
    # falling back to the legacy accessor on very old Matplotlib.
    try:
        from matplotlib import colormaps

        cmap = colormaps["tab20"]
    except ImportError:  # Matplotlib < 3.5
        cmap = plt.cm.get_cmap("tab20", 20)

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.imshow(image)

    # Keep only confident detections, highest scores first, capped at
    # max_detections so crowded scenes stay readable.
    filtered = [d for d in detections if d["score"] >= confidence_threshold]
    filtered.sort(key=lambda x: x["score"], reverse=True)
    filtered = filtered[:max_detections]

    legend_entries = {}

    for det in filtered:
        box = det["box"]
        label = det["label"]
        score = det["score"]

        # tab20 has 20 colors; cycle labels through them.
        color = cmap(label % 20)[:3]

        # Draw bounding box (box is [x1, y1, x2, y2] in pixel coords).
        rect = plt.Rectangle(
            (box[0], box[1]),
            box[2] - box[0],
            box[3] - box[1],
            linewidth=2,
            edgecolor=color,
            facecolor="none",
        )
        ax.add_patch(rect)

        # Resolve class name, falling back to a generic label.
        name = f"class_{label}"
        if class_names and label < len(class_names):
            name = class_names[label]

        ax.text(
            box[0],
            box[1] - 5,
            f"{name}: {score:.2f}",
            color="white",
            fontsize=8,
            fontweight="bold",
            bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7),
        )

        if name not in legend_entries:
            legend_entries[name] = color

    # Add one legend entry per distinct class shown.
    if legend_entries:
        from matplotlib.patches import Patch

        handles = [
            Patch(facecolor=c, edgecolor=c, label=n)
            for n, c in sorted(legend_entries.items())
        ]
        ax.legend(handles=handles, loc="upper right", fontsize=9)

    ax.set_title(f"Detections: {len(filtered)}")
    ax.axis("off")
    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close()
    else:
        plt.show()