Load Checkpoint and Resume Training for Semantic Segmentation Model
This notebook tests the new checkpoint-loading and resume-training functionality of the `train_segmentation_model`
function. It demonstrates how to:
- Train a model for a few epochs
- Stop training and save a checkpoint
- Resume training from the checkpoint
- Load only model weights without resuming training state
Install packages
To use the new functionality, ensure the required packages are installed.
# %pip install geoai-py
Import libraries
import geoai
import os
Download sample data
We'll use the same dataset as the main segmentation example.
# Remote sample data: the training raster, the building polygons used as
# labels, and a separate raster reserved for the inference test.
train_raster_url = "https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_rgb_train.tif"
train_vector_url = (
    "https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_train_buildings.geojson"
)
test_raster_url = "https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_test.tif"

# Download the three files in the same order as they are listed; the values
# returned by geoai.download_file are used below as the local file paths.
train_raster_path, train_vector_path, test_raster_path = [
    geoai.download_file(remote_url)
    for remote_url in (train_raster_url, train_vector_url, test_raster_url)
]
Create training data
# Root folder for everything this notebook produces: the exported tiles
# (read later from its images/ and labels/ subfolders) and each run's output.
out_folder = "checkpoint_test"

# Tiling parameters; the stride is half the tile size
# (presumably so neighbouring tiles overlap by 50% - confirm in geoai docs).
tiling_options = {
    "tile_size": 512,
    "stride": 256,
    "buffer_radius": 0,
}

# Cut the training raster and the vector labels into tiles under out_folder.
tiles = geoai.export_geotiff_tiles(
    in_raster=train_raster_path,
    out_folder=out_folder,
    in_class_data=train_vector_path,
    **tiling_options,
)
Initial Training (First 10 epochs)
First, we'll train a model for just 10 epochs and save checkpoints.
# Initial training: a short 10-epoch run whose checkpoint is reused by the
# resume and weights-only experiments further down.
initial_run_options = dict(
    images_dir=f"{out_folder}/images",
    labels_dir=f"{out_folder}/labels",
    output_dir=f"{out_folder}/initial_training",
    architecture="unet",
    encoder_name="resnet34",
    encoder_weights="imagenet",
    num_channels=3,
    num_classes=2,
    batch_size=4,  # small batch keeps the test run fast
    num_epochs=10,  # only 10 epochs: enough to produce a checkpoint
    learning_rate=0.001,
    val_split=0.2,
    save_best_only=False,  # also save periodic checkpoints (every 10 epochs)
    verbose=True,
    plot_curves=True,
)

print("Starting initial training for 10 epochs...")
geoai.train_segmentation_model(**initial_run_options)
print("Initial training completed!")
Resume Training from Checkpoint
Now we'll resume training from the checkpoint, continuing for another 10 epochs (total 20 epochs).
# Check that the checkpoint written by the initial run exists before it is
# handed to the resume / weights-only runs below.
checkpoint_path = f"{out_folder}/initial_training/checkpoint_epoch_10.pth"
if os.path.exists(checkpoint_path):
    print(f"Found checkpoint: {checkpoint_path}")
else:
    print(f"Checkpoint not found: {checkpoint_path}")
    # List what the initial run did write, to help spot a differently named
    # checkpoint. The directory itself may be missing (e.g. the training cell
    # was never run), in which case os.listdir would raise FileNotFoundError.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if os.path.isdir(checkpoint_dir):
        print("Available files:")
        for f in sorted(os.listdir(checkpoint_dir)):
            print(f" {f}")
    else:
        print(f"Output directory not found: {checkpoint_dir}")
# Resume training from the checkpoint: same data and model settings as the
# initial run, written to a separate output folder.
resume_run_options = dict(
    images_dir=f"{out_folder}/images",
    labels_dir=f"{out_folder}/labels",
    output_dir=f"{out_folder}/resumed_training",
    architecture="unet",
    encoder_name="resnet34",
    encoder_weights="imagenet",
    num_channels=3,
    num_classes=2,
    batch_size=4,
    num_epochs=20,  # total epoch count; training picks up again at epoch 10
    learning_rate=0.001,
    val_split=0.2,
    save_best_only=False,
    verbose=True,
    plot_curves=True,
    checkpoint_path=checkpoint_path,
    resume_training=True,  # restore the training state, not just the weights
)

print("Resuming training from checkpoint...")
geoai.train_segmentation_model(**resume_run_options)
print("Resumed training completed!")
Load Model Weights Only (No Training State Resume)
Finally, we'll test loading only the model weights without resuming the training state.
# Weights-only run: initialise the model from the checkpoint but do not
# restore the training state, so the epoch counter starts over.
weights_only_run_options = dict(
    images_dir=f"{out_folder}/images",
    labels_dir=f"{out_folder}/labels",
    output_dir=f"{out_folder}/weights_only_training",
    architecture="unet",
    encoder_name="resnet34",
    encoder_weights="imagenet",
    num_channels=3,
    num_classes=2,
    batch_size=4,
    num_epochs=15,  # 15 fresh epochs, counted from epoch 0
    learning_rate=0.001,
    val_split=0.2,
    save_best_only=False,
    verbose=True,
    plot_curves=True,
    checkpoint_path=checkpoint_path,
    resume_training=False,  # load the weights only; no training state
)

print("Loading model weights only (not resuming training state)...")
geoai.train_segmentation_model(**weights_only_run_options)
print("Weights-only training completed!")
Load Best Model for Inference
Test loading the best model for inference.
# Inference test: segment the held-out raster with the best model saved by
# the resumed run, if that model file was produced.
masks_path = f"{out_folder}/test_prediction.tif"
model_path = f"{out_folder}/resumed_training/best_model.pth"

if not os.path.exists(model_path):
    print(f"Model not found: {model_path}")
else:
    print(f"Running inference with model: {model_path}")
    # Model settings mirror the ones used for training above.
    geoai.semantic_segmentation(
        input_path=test_raster_path,
        output_path=masks_path,
        model_path=model_path,
        architecture="unet",
        encoder_name="resnet34",
        num_channels=3,
        num_classes=2,
        window_size=512,
        overlap=256,
        batch_size=4,
    )
    print(f"Inference completed. Results saved to: {masks_path}")
Compare Training Histories
Let's compare the training histories to verify that resuming worked correctly.
import torch
import matplotlib.pyplot as plt
# Locate the training history file inside each run's output folder.
initial_history_path = f"{out_folder}/initial_training/training_history.pth"
resumed_history_path = f"{out_folder}/resumed_training/training_history.pth"
weights_only_history_path = f"{out_folder}/weights_only_training/training_history.pth"

# Display name of each run -> its history file.
history_sources = {
    "Initial (10 epochs)": initial_history_path,
    "Resumed (10→20 epochs)": resumed_history_path,
    "Weights Only (15 epochs)": weights_only_history_path,
}

# Load whichever histories exist; runs without a history file are reported
# and left out of `histories`.
histories = {}
for run_name, history_file in history_sources.items():
    if not os.path.exists(history_file):
        print(f"History not found: {history_file}")
        continue
    run_history = torch.load(history_file)
    histories[run_name] = run_history
    print(f"Loaded {run_name}: {len(run_history['train_losses'])} epochs")
# Plot comparison - showing the continuation clearly

# Keys under which the three runs are stored in `histories`.
INITIAL_RUN = "Initial (10 epochs)"
RESUMED_RUN = "Resumed (10→20 epochs)"
WEIGHTS_ONLY_RUN = "Weights Only (15 epochs)"


def _plot_metric_comparison(
    ax, metric_key, title, y_label, initial, resumed, weights_only
):
    """Draw one metric for the initial, resumed and weights-only runs.

    Epoch numbers are derived from the length of each series rather than
    hard-coded, so the x and y data always have matching sizes.

    Args:
        ax: Matplotlib axes to draw on.
        metric_key: History key to plot ("train_losses", "val_ious" or
            "val_dices").
        title: Title of the axes.
        y_label: Label of the y axis.
        initial: History dict of the initial run.
        resumed: History dict of the resumed run.
        weights_only: History dict of the weights-only run.
    """
    initial_values = initial[metric_key]
    weights_values = weights_only[metric_key]
    n_initial = len(initial_values)
    n_weights = len(weights_values)

    # The resumed history starts with the same epochs as the initial run;
    # only the entries after those are the actual continuation.
    continuation = resumed[metric_key][n_initial:]
    first_resumed_epoch = n_initial + 1
    last_resumed_epoch = n_initial + len(continuation)

    ax.plot(
        range(1, n_initial + 1),
        initial_values,
        label=f"Initial (epochs 1-{n_initial})",
        marker="o",
        markersize=5,
        color="blue",
        linewidth=3,
        alpha=0.9,
        zorder=3,
    )

    if len(continuation) > 0:
        ax.plot(
            range(first_resumed_epoch, last_resumed_epoch + 1),
            continuation,
            label=f"Resumed (epochs {first_resumed_epoch}-{last_resumed_epoch})",
            marker="s",
            markersize=5,
            color="orange",
            linewidth=3,
            alpha=0.9,
            zorder=2,
        )

    ax.plot(
        range(1, n_weights + 1),
        weights_values,
        label=f"Weights Only (epochs 1-{n_weights})",
        marker="^",
        markersize=4,
        color="green",
        linewidth=2,
        alpha=0.7,
        zorder=1,
    )

    # Dashed bridge from the last initial epoch to the first resumed epoch;
    # it needs at least one point on each side.
    if n_initial > 0 and len(continuation) > 0:
        ax.plot(
            [n_initial, first_resumed_epoch],
            [initial_values[-1], continuation[0]],
            color="red",
            linewidth=2,
            linestyle="--",
            alpha=0.7,
            label="Continuation",
        )

    ax.set_title(title, fontsize=14)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.legend()
    ax.grid(True, alpha=0.3)


# The comparison figure needs all three runs; the summary below works with
# whichever histories were loaded.
missing_runs = [
    run_name
    for run_name in (INITIAL_RUN, RESUMED_RUN, WEIGHTS_ONLY_RUN)
    if run_name not in histories
]

if not histories:
    print("No training histories found for comparison")
else:
    if missing_runs:
        # Indexing a missing run would raise KeyError, so skip the figure.
        print(f"Skipping comparison plot; missing histories: {', '.join(missing_runs)}")
    else:
        # Get the individual histories
        initial_history = histories[INITIAL_RUN]
        resumed_history = histories[RESUMED_RUN]
        weights_only_history = histories[WEIGHTS_ONLY_RUN]

        fig, axes = plt.subplots(1, 3, figsize=(18, 5))

        # One panel per metric: (history key, panel title, y-axis label).
        panels = [
            ("train_losses", "Training Loss", "Loss"),
            ("val_ious", "Validation IoU", "IoU"),
            ("val_dices", "Validation Dice", "Dice"),
        ]
        for ax, (metric_key, title, y_label) in zip(axes, panels):
            _plot_metric_comparison(
                ax,
                metric_key,
                title,
                y_label,
                initial_history,
                resumed_history,
                weights_only_history,
            )

        plt.tight_layout()
        plt.savefig(
            f"{out_folder}/clear_training_comparison.png", dpi=150, bbox_inches="tight"
        )
        plt.show()

        # Print explanation
        print("\n" + "=" * 60)
        print("CHECKPOINT RESUME VISUALIZATION EXPLANATION")
        print("=" * 60)
        print("• Blue line: Initial training (epochs 1-10)")
        print("• Orange line: Resumed training continuation (epochs 11-20)")
        print("• Green line: Weights-only training (fresh epochs 1-15)")
        print("• Red dashed line: Shows seamless continuation")
        print()
        print("NOTE: In the original overlapping plot, the blue line was")
        print("completely covered by the orange line because resumed training")
        print("includes the exact same first 10 epochs plus 10 additional epochs.")
        print("=" * 60)

    # Print summary statistics for every history that was loaded
    print("\nTraining Summary:")
    for name, history in histories.items():
        final_iou = history["val_ious"][-1]
        final_dice = history["val_dices"][-1]
        final_loss = history["train_losses"][-1]
        epochs_trained = len(history["train_losses"])
        print(f"{name}: {epochs_trained} epochs")
        print(f" Final IoU: {final_iou:.4f}")
        print(f" Final Dice: {final_dice:.4f}")
        print(f" Final Loss: {final_loss:.4f}")
        print()

    # Only announce the figure when it was actually written.
    if not missing_runs:
        print("Training comparison plot saved to clear_training_comparison.png")
Summary
This notebook demonstrates the new checkpoint functionality:
- Initial Training: Trained for 10 epochs and saved checkpoints
- Resume Training: Successfully resumed from epoch 10 and continued to epoch 20
- Weights Only: Loaded model weights but started training from epoch 0
- Inference: Used the resumed model for inference
The key features tested:
- `checkpoint_path`: Path to the checkpoint file
- `resume_training=True`: Resume training state (epoch, optimizer, scheduler, metrics)
- `resume_training=False`: Load only model weights
This functionality is useful for:
- Long training jobs that might be interrupted
- Experimenting with different training parameters after initial training
- Transfer learning from partially trained models