Train a Semantic Segmentation Model with TIMM Encoders¶
This notebook demonstrates how to train semantic segmentation models using PyTorch Image Models (timm) encoders. This approach combines:
- 1000+ TIMM Encoders: State-of-the-art backbones (ResNet, EfficientNet, ViT, ConvNeXt, etc.)
- 9 Architectures: U-Net, U-Net++, DeepLabV3, DeepLabV3+, FPN, PSPNet, LinkNet, MANet, PAN
- Multi-channel Support: RGB, RGBN, or any number of input channels
- Simplified API: Similar to train_segmentation_model for ease of use
Install packages¶
To use the new functionality, ensure the required packages are installed.
# %pip install geoai-py timm segmentation-models-pytorch lightning
Import libraries¶
import geoai
Explore Available Encoders¶
The timm library provides 1000+ encoders that can be used with segmentation architectures:
# List some popular encoders
print("ResNet encoders:", geoai.list_timm_models(filter="resnet", limit=5))
print("EfficientNet encoders:", geoai.list_timm_models(filter="efficientnet", limit=5))
print("ConvNeXt encoders:", geoai.list_timm_models(filter="convnext", limit=5))
Download Sample Data¶
We'll use the same NAIP building detection dataset as the train_segmentation_model example.
# Download NAIP aerial imagery and building footprints
train_raster_url = (
"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_rgb_train.tif"
)
train_vector_url = "https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_train_buildings.geojson"
test_raster_url = (
"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_test.tif"
)
train_raster_path = geoai.download_file(train_raster_url)
train_vector_path = geoai.download_file(train_vector_url)
test_raster_path = geoai.download_file(test_raster_url)
Visualize Sample Data¶
geoai.view_vector_interactive(train_vector_path, tiles=train_raster_path)
Create Training Data¶
Generate image chips and corresponding segmentation masks:
out_folder = "timm_buildings"
# Create training chips
tiles = geoai.export_geotiff_tiles(
in_raster=train_raster_path,
out_folder=out_folder,
in_class_data=train_vector_path,
tile_size=512,
stride=256,
buffer_radius=0,
)
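Before training, it can help to confirm that the exported image chips and label masks line up. A minimal sketch, assuming the images/ and labels/ subfolders created by the cell above:
import os
# Each exported image chip should have a matching label mask
images = sorted(os.listdir(f"{out_folder}/images"))
labels = sorted(os.listdir(f"{out_folder}/labels"))
print(f"{len(images)} image chips, {len(labels)} label masks")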
Train U-Net with ResNet50 Encoder¶
The train_timm_segmentation_model function provides a simplified interface similar to train_segmentation_model:
# Train U-Net with ResNet50 encoder
geoai.train_timm_segmentation_model(
images_dir=f"{out_folder}/images",
labels_dir=f"{out_folder}/labels",
output_dir=f"{out_folder}/unet_resnet50",
encoder_name="resnet50",
architecture="unet",
encoder_weights="imagenet",
num_channels=3,
num_classes=2, # background and building
batch_size=8,
num_epochs=20,
learning_rate=0.001,
val_split=0.2,
verbose=True,
)
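After training completes, the run directory should contain the Lightning checkpoints and training history used later in this notebook (last.ckpt and training_history.pth). A quick way to inspect it:
import os
# Lightning writes checkpoints and training_history.pth into the models/ folder
print(os.listdir(f"{out_folder}/unet_resnet50/models"))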
Train DeepLabV3+ with EfficientNet-B3 Encoder¶
EfficientNet encoders deliver strong accuracy with fewer parameters than comparable ResNet backbones:
# Train DeepLabV3+ with EfficientNet-B3
geoai.train_timm_segmentation_model(
images_dir=f"{out_folder}/images",
labels_dir=f"{out_folder}/labels",
output_dir=f"{out_folder}/deeplabv3plus_efficientnet_b3",
encoder_name="efficientnet-b3", # Note: use dash for SMP encoders
architecture="deeplabv3plus",
encoder_weights="imagenet",
num_channels=3,
num_classes=2,
batch_size=8,
num_epochs=20,
learning_rate=0.001,
val_split=0.2,
verbose=True,
)
Model Performance Analysis¶
Let's examine the training curves and model performance:
geoai.plot_performance_metrics(
history_path=f"{out_folder}/unet_resnet50/models/training_history.pth",
figsize=(15, 5),
verbose=True,
)
Performance Metrics¶
IoU (Intersection over Union) is the primary metric used to evaluate semantic segmentation performance.
🔸 IoU Definition¶
$$ \text{IoU} = \frac{|A \cap B|}{|A \cup B|} = \frac{TP}{TP + FP + FN} $$
- Measures the overlap between predicted region $A$ and ground truth region $B$ relative to their union
- Ranges from 0 (no overlap) to 1 (perfect overlap)
- Common in object detection and semantic segmentation benchmarks (e.g., COCO, Pascal VOC)
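To make the formula concrete, here is a tiny NumPy example that computes IoU for a pair of binary masks; this is purely illustrative and not part of the geoai API:
import numpy as np
# Toy 3x3 prediction and ground truth (1 = building, 0 = background)
pred = np.array([[1, 1, 0], [1, 0, 0], [0, 0, 0]])
truth = np.array([[1, 0, 0], [1, 1, 0], [0, 0, 0]])
tp = np.logical_and(pred == 1, truth == 1).sum()  # correctly predicted building pixels
fp = np.logical_and(pred == 1, truth == 0).sum()  # predicted building, actually background
fn = np.logical_and(pred == 0, truth == 1).sum()  # missed building pixels
iou = tp / (tp + fp + fn)
print(f"IoU = {iou:.3f}")  # 2 / (2 + 1 + 1) = 0.5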
The training curves show:
- Training Loss: How well the model fits the training data
- Validation Loss: How well the model generalizes to unseen data
- Validation IoU: The overlap accuracy on validation data
Note: Higher IoU is better (closer to 1.0), lower loss is better (closer to 0)
Fine-tune with a Frozen Encoder¶
Freezing the pretrained encoder trains only the decoder, which can speed up training and reduce overfitting on small datasets:
# Fine-tune with frozen encoder
geoai.train_timm_segmentation_model(
images_dir=f"{out_folder}/images",
labels_dir=f"{out_folder}/labels",
output_dir=f"{out_folder}/unet_resnet50_frozen",
encoder_name="resnet50",
architecture="unet",
encoder_weights="imagenet",
num_channels=3,
num_classes=2,
freeze_encoder=True, # Freeze encoder weights
batch_size=8,
num_epochs=10, # Fewer epochs needed
learning_rate=0.001,
val_split=0.2,
verbose=True,
)
Run Inference on Test Image¶
Use the trained model to segment the test image using the timm_semantic_segmentation function:
# Run inference
masks_path = "naip_test_timm_prediction.tif"
model_path = f"{out_folder}/unet_resnet50/models/last.ckpt"
geoai.timm_semantic_segmentation(
input_path=test_raster_path,
output_path=masks_path,
model_path=model_path,
encoder_name="resnet50",
architecture="unet",
num_channels=3,
num_classes=2,
window_size=512,
overlap=256,
batch_size=4,
)
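Before vectorizing, it can be worth sanity-checking the prediction raster, for example by listing the class values it contains. A minimal sketch, assuming rasterio is installed:
import numpy as np
import rasterio
# The mask should contain the class indices 0 (background) and 1 (building)
with rasterio.open(masks_path) as src:
    mask = src.read(1)
print("Classes present:", np.unique(mask))
print("Building pixels:", int((mask == 1).sum()))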
Vectorize and Visualize Results¶
Convert the segmentation mask to vector format and visualize:
# Vectorize the mask
output_vector_path = "naip_test_timm_prediction.geojson"
gdf = geoai.orthogonalize(masks_path, output_vector_path, epsilon=2)
# Add geometric properties
gdf_props = geoai.add_geometric_properties(gdf, area_unit="m2", length_unit="m")
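Summarizing the area distribution can help choose the threshold used in the filtering step below; a one-liner on the gdf_props GeoDataFrame from the previous cell:
# Inspect footprint areas to guide the area_m2 > 50 filter in the next cell
print(gdf_props["area_m2"].describe())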
# Visualize results
geoai.view_raster(masks_path, nodata=0, basemap=test_raster_path, backend="ipyleaflet")
# Filter buildings by area and visualize
gdf_filtered = gdf_props[gdf_props["area_m2"] > 50]
geoai.view_vector_interactive(gdf_filtered, column="area_m2", tiles=test_raster_path)
# Create split map comparison
geoai.create_split_map(
left_layer=gdf_filtered,
right_layer=test_raster_path,
left_args={"style": {"color": "red", "fillOpacity": 0.2}},
basemap=test_raster_path,
)
Hugging Face Hub Integration¶
The geoai library now supports loading models from and pushing models to the Hugging Face Hub. This enables:
- Loading Pre-trained Models: Use state-of-the-art segmentation models from HF Hub
- Sharing Your Models: Upload trained models to share with the community
- Model Versioning: Leverage HF Hub's versioning and collaboration features
Option 1: Load a Pre-trained Model from HF Hub¶
You can use complete segmentation models from Hugging Face Hub with the use_timm_model=True parameter:
Note: This example is commented out as it requires a specific HF Hub model. Uncomment if you have a compatible model available.
# # Example: Load and fine-tune a model from HF Hub
# geoai.train_timm_segmentation_model(
# images_dir=f"{out_folder}/images",
# labels_dir=f"{out_folder}/labels",
# output_dir=f"{out_folder}/hf_hub_model",
# use_timm_model=True,
# timm_model_name="hf-hub:username/model-name", # Replace with actual HF Hub model
# num_channels=3,
# num_classes=2,
# batch_size=8,
# num_epochs=10,
# learning_rate=0.0001, # Lower LR for fine-tuning
# val_split=0.2,
# verbose=True,
# )
Option 2: Load a Pushed Model for Inference¶
Once a model is pushed to HF Hub, you can load it for inference:
# # Example: Load model from HF Hub and run inference
# # First, download the model from HF Hub (this would happen automatically)
# # Then run inference with the downloaded model
# geoai.timm_semantic_segmentation(
# input_path=test_raster_path,
# output_path="naip_test_hf_hub_prediction.tif",
# model_path="path/to/downloaded/model.pth", # Path to downloaded HF Hub model
# encoder_name="resnet50",
# architecture="unet",
# num_channels=3,
# num_classes=2,
# use_timm_model=False, # Set to True if it's a pure timm model
# window_size=512,
# overlap=256,
# batch_size=4,
# )
Option 3: Push Your Trained Model to HF Hub¶
After training a model, you can share it on Hugging Face Hub:
Note: This requires you to be logged in to Hugging Face. Run huggingface-cli login in your terminal first.
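Alternatively, you can authenticate programmatically with the huggingface_hub package; a minimal sketch, kept commented out like the examples below (replace the token placeholder with your own):
# # Log in programmatically instead of using the CLI
# from huggingface_hub import login
# login(token="hf_...")  # replace with your own access token; never commit it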
# # Example: Push your trained model to HF Hub
# url = geoai.push_timm_model_to_hub(
# model_path=f"{out_folder}/unet_resnet50/models/last.ckpt",
# repo_id="your-username/building-segmentation-resnet50-unet", # Replace with your repo
# encoder_name="resnet50",
# architecture="unet",
# num_channels=3,
# num_classes=2,
# commit_message="Upload building segmentation model trained on NAIP imagery",
# private=False, # Set to True for private repository
# )
# print(f"Model uploaded to: {url}")
Summary¶
This notebook demonstrated:
- Simplified API: Using train_timm_segmentation_model(), similar to train_segmentation_model()
- Multiple Encoders: Training with ResNet50 and EfficientNet-B3
- Multiple Architectures: U-Net and DeepLabV3+
- Transfer Learning: Fine-tuning with a frozen encoder
- Inference: Using timm_semantic_segmentation() for predictions
- Post-processing: Vectorization and visualization
Supported Architectures¶
- U-Net: Classic encoder-decoder architecture
- U-Net++: Nested U-Net with dense skip connections
- DeepLabV3: Atrous Spatial Pyramid Pooling (ASPP)
- DeepLabV3+: DeepLabV3 with decoder
- FPN: Feature Pyramid Network
- PSPNet: Pyramid Scene Parsing Network
- LinkNet: Efficient architecture with skip connections
- MANet: Multi-scale Attention Network
- PAN: Pyramid Attention Network
Popular Encoders¶
- ResNet family: resnet18, resnet34, resnet50, resnet101, resnet152
- EfficientNet family: efficientnet_b0 to efficientnet_b7
- ConvNeXt family: convnext_tiny, convnext_small, convnext_base
- RegNet family: regnetx_002, regnetx_004, regnety_002, regnety_004
- MobileNet family: mobilenetv2_100, mobilenetv3_large_100
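Encoder availability depends on the installed timm version, so it can be worth checking directly with timm's own listing API. A minimal sketch, independent of geoai:
import timm
# fnmatch-style pattern; returns model names available in this timm install
print(timm.list_models("convnext*")[:5])
print("resnet50 available:", "resnet50" in timm.list_models())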
Key Advantages¶
- 1000+ Encoders: Access to state-of-the-art backbones from timm
- Simple API: Functions match the existing train_segmentation_model interface
- Automatic Preprocessing: Handles data loading and splitting automatically
- Lightning Integration: Built-in checkpointing, early stopping, and logging
- IoU Monitoring: Track IoU metrics during training
Next Steps¶
- Experiment with different encoder-architecture combinations
- Try modern encoders like ConvNeXt or Swin Transformer
- Use data augmentation for better generalization (see the sketch below)
- Apply to multi-class segmentation tasks
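For the augmentation suggestion above, a common approach is an albumentations pipeline that transforms an image chip and its mask together. This is a standalone sketch assuming albumentations is installed (pip install albumentations); it is not wired into train_timm_segmentation_model, so consult the geoai documentation for the supported way to pass augmentations:
import albumentations as A
import numpy as np
# Geometric and photometric augmentations applied identically to image and mask
transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.RandomBrightnessContrast(p=0.3),
    ]
)
# Demonstrate on a random chip; real usage would read the exported GeoTIFF tiles
image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
mask = np.random.randint(0, 2, (512, 512), dtype=np.uint8)
augmented = transform(image=image, mask=mask)
aug_image, aug_mask = augmented["image"], augmented["mask"]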