Train an Image Classifier with TIMM Models¶
This notebook demonstrates how to train an image classification model using the PyTorch Image Models (timm) library. The geoai.timm_train module provides a high-level API for training state-of-the-art computer vision models on remote sensing imagery.
Key Features¶
- 1000+ Pre-trained Models: Access to ResNet, EfficientNet, Vision Transformers (ViT), ConvNeXt, and more
- Multi-channel Support: Train on RGB, RGBN (RGB + NIR), or any number of channels
- PyTorch Lightning Integration: Automatic training loops, checkpointing, and early stopping
- Transfer Learning: Fine-tune pretrained models or train from scratch
Install packages¶
Uncomment the following line to install the required packages if they are not already installed.
# %pip install geoai-py timm lightning datasets
Import libraries¶
import os
import geoai
from geoai.timm_train import (
    list_timm_models,
    get_timm_model,
    RemoteSensingDataset,
    train_timm_classifier,
    predict_with_timm,
)
Explore Available Models¶
The timm library provides over 1000 pretrained models. Let's explore some popular architectures:
# List ResNet models
resnet_models = list_timm_models(filter="resnet", limit=10)
print("ResNet models:", resnet_models)
# List EfficientNet models
efficientnet_models = list_timm_models(filter="efficientnet", limit=10)
print("EfficientNet models:", efficientnet_models)
# List Vision Transformer models
vit_models = list_timm_models(filter="vit", limit=10)
print("Vision Transformer models:", vit_models)
Download Sample Data¶
For this example, we'll use the EuroSAT RGB dataset from Hugging Face. This dataset contains Sentinel-2 satellite RGB images in 10 land use/land cover classes:
- AnnualCrop
- Forest
- HerbaceousVegetation
- Highway
- Industrial
- Pasture
- PermanentCrop
- Residential
- River
- SeaLake
from datasets import load_dataset
import tempfile
import shutil
from PIL import Image
# Load EuroSAT RGB dataset from Hugging Face
print("Loading EuroSAT dataset from Hugging Face...")
dataset = load_dataset("timm/eurosat-rgb", split="train")
# Create a temporary directory to save images
temp_dir = tempfile.mkdtemp(prefix="eurosat_")
print(f"Saving images to: {temp_dir}")
# Save images to disk organized by class
class_names = dataset.features["label"].names
print(f"Classes: {class_names}")
for idx, sample in enumerate(dataset):
    img = sample["image"]
    label = sample["label"]
    class_name = class_names[label]
    # Create class directory
    class_dir = os.path.join(temp_dir, class_name)
    os.makedirs(class_dir, exist_ok=True)
    # Save image as JPEG
    img_path = os.path.join(class_dir, f"{idx:05d}.jpg")
    img.save(img_path)
print(f"Saved {len(dataset)} images to {temp_dir}")
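Because tempfile.mkdtemp never deletes the directory it creates, the exported JPEGs stay on disk until they are removed explicitly. A minimal cleanup sketch follows. It uses a throwaway directory rather than the real temp_dir, and the demo_dir name and placeholder file are illustrative only:

```python
import os
import shutil
import tempfile

# Create a throwaway directory the same way the notebook does
demo_dir = tempfile.mkdtemp(prefix="eurosat_demo_")

# Write a placeholder file so the directory is not empty
with open(os.path.join(demo_dir, "placeholder.txt"), "w") as f:
    f.write("dummy")

# Remove the directory and everything inside it
shutil.rmtree(demo_dir, ignore_errors=True)
existed_after_cleanup = os.path.exists(demo_dir)
print("Directory still exists:", existed_after_cleanup)
```

Calling shutil.rmtree(temp_dir) in the same way, once training and inference are finished, would free the space used by the exported images.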
Prepare Training Data¶
Now we'll load all image paths and create train/val/test splits.
import glob
from sklearn.model_selection import train_test_split
# Get all image paths and labels
image_paths = []
labels = []
for class_idx, class_name in enumerate(class_names):
    class_dir = os.path.join(temp_dir, class_name)
    class_images = sorted(glob.glob(os.path.join(class_dir, "*.jpg")))
    image_paths.extend(class_images)
    labels.extend([class_idx] * len(class_images))
print(f"Total images: {len(image_paths)}")
print(f"Number of classes: {len(class_names)}")
print("Class distribution:")
for class_idx, class_name in enumerate(class_names):
    count = labels.count(class_idx)
    print(f"  {class_name}: {count}")
Split data into train, validation, and test sets¶
# Hold out 20% of all images for testing
train_paths, test_paths, train_labels, test_labels = train_test_split(
    image_paths, labels, test_size=0.2, random_state=42, stratify=labels
)
# Split the remaining 80% again: 20% of it becomes the validation set
train_paths, val_paths, train_labels, val_labels = train_test_split(
    train_paths, train_labels, test_size=0.2, random_state=42, stratify=train_labels
)
print(f"Training samples: {len(train_paths)}")
print(f"Validation samples: {len(val_paths)}")
print(f"Test samples: {len(test_paths)}")
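Because the second split is applied to the 80% left over from the first, the effective proportions are roughly 64% train, 16% validation, and 20% test. The sketch below checks that arithmetic for a hypothetical dataset of 1,000 images. The split_sizes helper is illustrative and not part of geoai or scikit-learn. It assumes the held-out portion is rounded up, which is how train_test_split is understood to treat a fractional test_size.

```python
import math


def split_sizes(n_samples, test_frac=0.2, val_frac=0.2):
    """Return (train, val, test) counts for a two-stage hold-out split."""
    # Stage 1: hold out the test set from the full dataset
    n_test = math.ceil(n_samples * test_frac)
    n_remaining = n_samples - n_test
    # Stage 2: hold out the validation set from what remains
    n_val = math.ceil(n_remaining * val_frac)
    n_train = n_remaining - n_val
    return n_train, n_val, n_test


n_train, n_val, n_test = split_sizes(1000)
print(n_train, n_val, n_test)  # 640 160 200
```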
import matplotlib.pyplot as plt
from PIL import Image
# Show one sample from each class
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
for idx, class_name in enumerate(class_names):
    ax = axes[idx // 5, idx % 5]
    # Find first image of this class
    img_idx = labels.index(idx)
    img = Image.open(image_paths[img_idx])
    ax.imshow(img)
    ax.set_title(class_name, fontsize=12)
    ax.axis("off")
plt.tight_layout()
plt.show()
Create Datasets¶
The RemoteSensingDataset class handles loading images with support for multi-channel imagery.
# Create datasets
train_dataset = RemoteSensingDataset(
    image_paths=train_paths,
    labels=train_labels,
    num_channels=3,  # RGB images
)
val_dataset = RemoteSensingDataset(
    image_paths=val_paths,
    labels=val_labels,
    num_channels=3,
)
test_dataset = RemoteSensingDataset(
    image_paths=test_paths,
    labels=test_labels,
    num_channels=3,
)
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")
Train a ResNet50 Classifier¶
Let's train a ResNet50 model with pretrained ImageNet weights for transfer learning on the 10-class EuroSAT dataset.
# Train ResNet50 classifier
output_dir = "timm_output/resnet50"
model = train_timm_classifier(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
    model_name="resnet50",
    num_classes=len(class_names),  # 10 classes
    in_channels=3,
    pretrained=True,
    output_dir=output_dir,
    batch_size=32,
    num_epochs=20,
    learning_rate=1e-3,
    weight_decay=1e-4,
    num_workers=4,
    freeze_backbone=False,
    monitor_metric="val_acc",
    mode="max",
    patience=5,
    save_top_k=1,
)
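The patience argument sets the early stopping patience. The sketch below is a simplified illustration of that idea, not the actual PyTorch Lightning implementation. It assumes early stopping watches the same metric as monitor_metric in the direction given by mode, and the accuracy values are made up. Training stops once the metric has failed to improve for patience consecutive validation checks.

```python
def early_stop_epoch(history, patience=5, mode="max"):
    """Return the epoch index at which training would stop, or None."""
    best = None
    epochs_without_improvement = 0
    for epoch, value in enumerate(history):
        if best is None:
            improved = True
        elif mode == "max":
            improved = value > best
        else:
            improved = value < best
        if improved:
            # New best value: reset the counter
            best = value
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Made-up validation accuracies, one per epoch (epochs are 0-indexed)
val_acc = [0.80, 0.85, 0.84, 0.86, 0.85, 0.85, 0.86, 0.84, 0.85]
stop_epoch = early_stop_epoch(val_acc, patience=3, mode="max")
print(stop_epoch)  # 6
```

In this example the best value (0.86) is reached at epoch 3. Epochs 4, 5, and 6 bring no strict improvement, so a patience of 3 stops training at epoch 6.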
Train an EfficientNet-B0 Classifier¶
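EfficientNet-B0 is a much smaller network than ResNet50 (roughly 5 million parameters versus roughly 25 million in their ImageNet configurations), which makes it a useful lower-cost baseline for comparison.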
# Train EfficientNet-B0 classifier
output_dir = "timm_output/efficientnet_b0"
model = train_timm_classifier(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
    model_name="efficientnet_b0",
    num_classes=len(class_names),
    in_channels=3,
    pretrained=True,
    output_dir=output_dir,
    batch_size=32,
    num_epochs=20,
    learning_rate=1e-3,
    weight_decay=1e-4,
    num_workers=4,
    freeze_backbone=False,
    monitor_metric="val_acc",
    mode="max",
    patience=5,
    save_top_k=1,
)
Fine-tuning with Frozen Backbone¶
For faster training, you can freeze the backbone and only train the classification head:
# Fine-tune only the classifier head
output_dir = "timm_output/resnet50_frozen"
model_frozen = train_timm_classifier(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
    model_name="resnet50",
    num_classes=len(class_names),
    in_channels=3,
    pretrained=True,
    freeze_backbone=True,  # Freeze backbone weights
    output_dir=output_dir,
    batch_size=32,
    num_epochs=10,  # Fewer epochs needed
    learning_rate=1e-3,
    monitor_metric="val_acc",
    mode="max",
)
Make Predictions¶
Use the trained model to make predictions on test images.
# Load a trained model checkpoint
from geoai.timm_train import TimmClassifier
import torch
# Path to the last checkpoint saved during training
# (this is not necessarily the best-scoring checkpoint)
checkpoint_path = "timm_output/resnet50/models/last.ckpt"
# Load model
model = TimmClassifier.load_from_checkpoint(checkpoint_path)
# Make predictions
predictions, probabilities = predict_with_timm(
    model=model,
    image_paths=test_paths[:20],  # Predict on first 20 test images
    batch_size=8,
    return_probabilities=True,
)
print(f"Predictions shape: {predictions.shape}")
print(f"Probabilities shape: {probabilities.shape}")
print(f"Sample predictions: {[class_names[p] for p in predictions[:5]]}")
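Because the true labels for these images are available in test_labels, the predictions can also be scored directly. The sketch below uses made-up label lists, and the accuracy_by_class helper is hypothetical, not a geoai function:

```python
from collections import Counter


def accuracy_by_class(y_true, y_pred):
    """Return overall accuracy and a dict of per-class accuracy."""
    # Number of samples per true class
    total = Counter(y_true)
    # Number of correctly predicted samples per true class
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    overall = sum(correct.values()) / len(y_true)
    per_class = {label: correct[label] / n for label, n in total.items()}
    return overall, per_class


# Made-up labels for illustration only
y_true = [0, 0, 1, 1, 2, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 2, 0, 2]
overall, per_class = accuracy_by_class(y_true, y_pred)
print(overall)    # 0.75
print(per_class)  # {0: 0.5, 1: 1.0, 2: 0.75}
```

With the notebook's variables this could be called as accuracy_by_class(test_labels[:20], [int(p) for p in predictions]), assuming predictions holds integer class indices.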
Visualize Predictions¶
import matplotlib.pyplot as plt
from PIL import Image
# Visualize predictions
fig, axes = plt.subplots(4, 5, figsize=(20, 16))
for idx, ax in enumerate(axes.flat):
    if idx >= len(test_paths[:20]):
        break
    # Load and display image
    img = Image.open(test_paths[idx])
    ax.imshow(img)
    pred_class = class_names[predictions[idx]]
    true_class = class_names[test_labels[idx]]
    confidence = probabilities[idx][predictions[idx]] * 100
    color = "green" if predictions[idx] == test_labels[idx] else "red"
    ax.set_title(
        f"Pred: {pred_class}\nTrue: {true_class}\n({confidence:.1f}%)",
        color=color,
        fontsize=10,
    )
    ax.axis("off")
plt.tight_layout()
plt.show()
Using Class Weights for Imbalanced Datasets¶
When dealing with imbalanced datasets, you can provide class weights to the loss function:
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
# Compute class weights
class_weights = compute_class_weight(
    class_weight="balanced", classes=np.unique(train_labels), y=train_labels
)
print(f"Class weights: {class_weights}")
# Train with class weights
output_dir = "timm_output/resnet50_weighted"
model_weighted = train_timm_classifier(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    model_name="resnet50",
    num_classes=len(class_names),
    in_channels=3,
    pretrained=True,
    output_dir=output_dir,
    batch_size=32,
    num_epochs=20,
    learning_rate=1e-3,
    class_weights=class_weights.tolist(),  # Pass class weights
    monitor_metric="val_acc",
    mode="max",
)
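With class_weight="balanced", scikit-learn weights each class by n_samples / (n_classes * count), so rarer classes receive larger weights. The plain-Python sketch below reproduces that formula on a made-up label list. The balanced_weights helper is illustrative only:

```python
from collections import Counter


def balanced_weights(labels):
    """Compute n_samples / (n_classes * class_count) for each class."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}


# Made-up imbalanced labels: six samples of class 0, two of class 1
weights = balanced_weights([0, 0, 0, 0, 0, 0, 1, 1])
print(weights)  # {0: 0.6666666666666666, 1: 2.0}
```

A perfectly balanced dataset gives every class a weight of 1.0, in which case passing class weights has no effect on the loss.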
Summary¶
This notebook demonstrated:
- Model Selection: Exploring 1000+ available timm models (ResNet, EfficientNet, ViT)
- Data Loading: Using the EuroSAT RGB dataset from Hugging Face
- Training: Training various architectures on 10-class land cover classification
- Transfer Learning: Fine-tuning pretrained models with frozen backbones
- Inference: Making predictions on test images and visualizing the results
- Class Weighting: Handling imbalanced datasets
Key Parameters¶
- model_name: Choose from 1000+ timm models
- num_classes: Number of output classes
- in_channels: Number of input channels (3 for RGB, 4 for RGBN, etc.)
- pretrained: Use ImageNet pretrained weights for transfer learning
- freeze_backbone: Freeze backbone for faster fine-tuning
- class_weights: Handle imbalanced datasets
- monitor_metric: Track 'val_loss' or 'val_acc' for checkpointing
- patience: Early stopping patience
Next Steps¶
- Experiment with different model architectures (ConvNeXt, Swin Transformer, etc.)
- Try data augmentation for improved performance
- Use learning rate schedulers for better convergence
- Deploy models for inference on satellite imagery