Multi-Class Object Detection with NWPU-VHR-10¶
This notebook demonstrates end-to-end multi-class object detection using the NWPU-VHR-10 dataset, a benchmark for object detection in very high resolution (VHR) remote sensing imagery.
The dataset contains 800 images with 10 object classes:
- airplane, ship, storage tank, baseball diamond, tennis court
- basketball court, ground track field, harbor, bridge, vehicle
Install package¶
To use the geoai-py package, ensure it is installed in your environment. Uncomment the command below if needed.
# %pip install geoai-py
Import libraries¶
import json
import os
import geoai
Download NWPU-VHR-10 dataset¶
# Download the NWPU-VHR-10 dataset archive and return the local directory
# it was extracted to (contents inspected in the next cell).
data_dir = geoai.download_nwpu_vhr10()
Explore the dataset¶
# Show where the dataset was extracted and enumerate the 10 object classes.
print(f"Dataset directory: {data_dir}")
print(f"Contents: {os.listdir(data_dir)}")
# Plain string literal: the original used an f-string with no placeholders
# (ruff F541); output is unchanged.
print("\nNWPU-VHR-10 Classes:")
for i, name in enumerate(geoai.NWPU_VHR10_CLASSES):
    print(f"  {i}: {name}")
Prepare dataset¶
Split the dataset into training and validation sets.
# Create reproducible train/validation splits (80/20, fixed seed).
splits = geoai.prepare_nwpu_vhr10(data_dir, val_split=0.2, seed=42)

# Print a short summary of what the split produced, one labelled line each.
split_summary = {
    "Images directory": splits["images_dir"],
    "Number of classes": splits["num_classes"],
    "Class names": splits["class_names"],
    "Training images": len(splits["train_image_ids"]),
    "Validation images": len(splits["val_image_ids"]),
}
for label, value in split_summary.items():
    print(f"{label}: {value}")
Visualize sample annotations¶
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from PIL import Image
# Load the COCO-format annotation file for the prepared dataset.
with open(splits["annotations_path"], "r") as f:
    coco_data = json.load(f)

# Take up to four sample images and build a category-id -> name lookup.
sample_images = coco_data["images"][:4]
categories = {cat["id"]: cat["name"] for cat in coco_data["categories"]}

# Bug fix: plt.cm.get_cmap was deprecated in matplotlib 3.7 and removed in
# 3.9; plt.get_cmap accepts the same (name, lut) arguments.
cmap = plt.get_cmap("tab10", 10)

fig, axes = plt.subplots(2, 2, figsize=(14, 14))
axes = axes.flatten()

for ax_idx, img_info in enumerate(sample_images):
    img_path = os.path.join(splits["images_dir"], img_info["file_name"])
    img = Image.open(img_path)
    axes[ax_idx].imshow(img)
    axes[ax_idx].set_title(img_info["file_name"], fontsize=10)
    axes[ax_idx].axis("off")

    # Draw every ground-truth box belonging to this image.
    img_anns = [
        ann for ann in coco_data["annotations"] if ann["image_id"] == img_info["id"]
    ]
    for ann in img_anns:
        x, y, w, h = ann["bbox"]  # COCO bbox convention: x, y, width, height
        cat_id = ann["category_id"]
        color = cmap(cat_id % 10)
        rect = plt.Rectangle(
            (x, y), w, h, linewidth=2, edgecolor=color, facecolor="none"
        )
        axes[ax_idx].add_patch(rect)
        axes[ax_idx].text(
            x,
            y - 3,
            categories.get(cat_id, str(cat_id)),
            color="white",
            fontsize=7,
            bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7),
        )

# Robustness: blank out panels without an image if fewer than 4 samples exist.
for ax in axes[len(sample_images):]:
    ax.axis("off")

plt.tight_layout()
plt.show()
Use pretrained model from HuggingFace¶
A pretrained Mask R-CNN model for NWPU-VHR-10 is available on HuggingFace Hub. You can download it directly and run inference without training. If you prefer to train your own model, skip to the "Train multi-class detection model" section below.
# Download the pretrained NWPU-VHR-10 Mask R-CNN checkpoint (hosted on
# HuggingFace Hub per the note above) and get its local file path.
model_path = geoai.download_nwpu_vhr10_model()
Run inference on a sample image using the pretrained model. The multiclass_detection function will use the NWPU-VHR-10 class names automatically when using the pretrained model.
# Choose one dataset image and run the pretrained detector on it.
sample_img_path = os.path.join(splits["images_dir"], "012.jpg")
output_raster = "nwpu_pretrained_output.tif"

detect_kwargs = dict(
    input_path=sample_img_path,
    output_path=output_raster,
    model_path=model_path,
    confidence_threshold=0.5,
)
result_path, inference_time, detections = geoai.multiclass_detection(**detect_kwargs)

print(f"Inference time: {inference_time:.2f}s")
print(f"Total detections: {len(detections)}")

# Show the detections overlaid on the source image.
geoai.visualize_multiclass_detections(
    image_path=sample_img_path,
    detections=detections,
    confidence_threshold=0.5,
    figsize=(12, 10),
)
You can also call multiclass_detection without specifying model_path at all. It will automatically download the pretrained model and use the NWPU-VHR-10 class names.
# Same detection call, but with no model_path: geoai fetches the pretrained
# weights (and the NWPU-VHR-10 class names) automatically.
auto_kwargs = dict(
    input_path=sample_img_path,
    output_path="nwpu_auto_output.tif",
    confidence_threshold=0.5,
)
result_path, inference_time, detections = geoai.multiclass_detection(**auto_kwargs)

print(f"Inference time: {inference_time:.2f}s")
print(f"Total detections: {len(detections)}")
# Clean up temporary output files from the two inference runs above.
# EAFP: attempt the removal and ignore "already gone" instead of racing
# between an os.path.exists() check and os.remove().
for tmp_path in ("nwpu_pretrained_output.tif", "nwpu_auto_output.tif"):
    try:
        os.remove(tmp_path)
    except FileNotFoundError:
        pass
Train multi-class detection model (Optional)¶
Alternatively, you can train your own Mask R-CNN model from scratch on the NWPU-VHR-10 dataset. This section is optional if you are using the pretrained model above.
output_dir = "nwpu_output"

# Mask R-CNN training configuration — identical values to the original cell;
# 20 epochs with an internal 15% validation hold-out and a fixed seed.
trainer_config = {
    "images_dir": splits["images_dir"],
    "annotations_path": splits["train_annotations"],
    "output_dir": output_dir,
    "class_names": splits["class_names"],
    "num_channels": 3,
    "batch_size": 4,
    "num_epochs": 20,
    "learning_rate": 0.005,
    "val_split": 0.15,
    "seed": 42,
    "pretrained": True,
    "verbose": True,
}
model_path = geoai.train_multiclass_detector(**trainer_config)
Plot training metrics¶
import torch

# Plot loss / IoU / learning-rate curves if a training history was saved.
history_path = os.path.join(output_dir, "training_history.pth")
if os.path.exists(history_path):
    history = torch.load(history_path, weights_only=True)
    epochs = history["epochs"]

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    loss_ax, iou_ax, lr_ax = axes

    # Panel 1: training vs. validation loss.
    loss_ax.plot(epochs, history["train_loss"], label="Train Loss")
    loss_ax.plot(epochs, history["val_loss"], label="Val Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Training & Validation Loss")
    loss_ax.legend()

    # Panel 2: validation IoU.
    iou_ax.plot(epochs, history["val_iou"], label="Val IoU", color="green")
    iou_ax.set_xlabel("Epoch")
    iou_ax.set_ylabel("IoU")
    iou_ax.set_title("Validation IoU")
    iou_ax.legend()

    # Panel 3: learning-rate schedule.
    lr_ax.plot(epochs, history["lr"], label="Learning Rate", color="orange")
    lr_ax.set_xlabel("Epoch")
    lr_ax.set_ylabel("LR")
    lr_ax.set_title("Learning Rate Schedule")
    lr_ax.legend()

    plt.tight_layout()
    plt.show()
Evaluate model with COCO metrics¶
# Compute COCO-style detection metrics on the validation annotations.
foreground_names = splits["class_names"][1:]  # drop the background entry
metrics = geoai.evaluate_multiclass_detector(
    model_path=model_path,
    images_dir=splits["images_dir"],
    annotations_path=splits["val_annotations"],
    num_classes=splits["num_classes"],
    class_names=foreground_names,
    batch_size=4,
)
Run inference on sample images¶
# Load the validation annotations and pick an image for inference.
with open(splits["val_annotations"], "r") as f:
    val_data = json.load(f)

# Use the first validation image as the test case.
test_img_info = val_data["images"][0]
test_img_path = os.path.join(splits["images_dir"], test_img_info["file_name"])
print(f"Test image: {test_img_path}")

# Tiled inference: 512-pixel windows with 256-pixel overlap.
output_raster = "nwpu_detection_output.tif"
tiled_kwargs = dict(
    input_path=test_img_path,
    output_path=output_raster,
    model_path=model_path,
    num_classes=splits["num_classes"],
    class_names=splits["class_names"],
    window_size=512,
    overlap=256,
    confidence_threshold=0.5,
    batch_size=4,
    num_channels=3,
)
result_path, inference_time, detections = geoai.multiclass_detection(**tiled_kwargs)

print(f"\nInference time: {inference_time:.2f}s")
print(f"Total detections: {len(detections)}")
Visualize detections¶
# Overlay the detections on the test image, labelling boxes with the
# dataset class names; uses the same 0.5 confidence cutoff as the
# inference call above.
geoai.visualize_multiclass_detections(
    image_path=test_img_path,
    detections=detections,
    class_names=splits["class_names"],
    confidence_threshold=0.5,
    figsize=(12, 10),
)
Batch inference on multiple validation images¶
# Run inference on a few validation images and display results side by side.
num_samples = min(4, len(val_data["images"]))
fig, axes = plt.subplots(2, 2, figsize=(16, 16))
axes = axes.flatten()

for idx in range(num_samples):
    img_info = val_data["images"][idx]
    img_path = os.path.join(splits["images_dir"], img_info["file_name"])
    out_path = f"nwpu_detection_{idx}.tif"
    _, _, dets = geoai.multiclass_detection(
        input_path=img_path,
        output_path=out_path,
        model_path=model_path,
        num_classes=splits["num_classes"],
        class_names=splits["class_names"],
        confidence_threshold=0.5,
        num_channels=3,
    )

    # Show the source image with a per-image detection count in the title.
    img = Image.open(img_path)
    axes[idx].imshow(img)
    axes[idx].set_title(
        f"{img_info['file_name']} ({len(dets)} detections)", fontsize=10
    )
    axes[idx].axis("off")

    for det in dets:
        box = det["box"]  # [x1, y1, x2, y2] corner coordinates
        label = det["label"]
        score = det["score"]
        color = cmap(label % 10)
        rect = plt.Rectangle(
            (box[0], box[1]),
            box[2] - box[0],
            box[3] - box[1],
            linewidth=2,
            edgecolor=color,
            facecolor="none",
        )
        axes[idx].add_patch(rect)
        # Fall back to the numeric label if it is out of range of the names.
        name = (
            splits["class_names"][label]
            if label < len(splits["class_names"])
            else str(label)
        )
        axes[idx].text(
            box[0],
            box[1] - 3,
            f"{name}: {score:.2f}",
            color="white",
            fontsize=7,
            bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7),
        )

    # Clean up temp file; EAFP instead of an exists()/remove() race.
    try:
        os.remove(out_path)
    except FileNotFoundError:
        pass

# Bug fix: hide panels that received no image when fewer than 4 validation
# images exist, otherwise empty axes with ticks are rendered.
for ax in axes[num_samples:]:
    ax.axis("off")

plt.tight_layout()
plt.show()
Summary¶
In this notebook, we demonstrated:
- Downloading the NWPU-VHR-10 remote sensing object detection dataset
- Preparing train/validation splits from COCO-format annotations
- Using a pretrained model from HuggingFace Hub for instant inference
- Training a multi-class Mask R-CNN model for 10 object categories (optional)
- Evaluating the model using COCO-style mAP metrics
- Running inference on test images with multi-class detection
- Visualizing detection results with colored bounding boxes