Self-Supervised Learning with Lightly Train¶
This notebook demonstrates how to use Lightly Train for self-supervised pretraining on unlabeled geospatial imagery. You'll learn how to:
- Train a self-supervised model using various methods (DINOv2, DINO, SimCLR)
- Load and use the pretrained model
- Generate embeddings for downstream tasks
Self-supervised learning is particularly useful for geospatial applications where labeled data is scarce or expensive to obtain.
Install packages¶
To use the geoai-py package with Lightly Train support, ensure it is installed in your environment. Uncomment the command below if needed.
# %pip install geoai-py lightly-train
Import libraries¶
import geoai
import os
import numpy as np
import torch
from pathlib import Path
Download sample data¶
We'll use unlabeled satellite imagery for self-supervised pretraining. For this example, we'll download some sample NAIP imagery.
# Download sample geospatial imagery for training
sample_urls = [
"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_rgb_train.tif",
"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/cars_7cm.tif",
]
# Create directories for our workflow
data_dir = "lightly_train_data"
output_dir = "lightly_train_output"
embeddings_dir = "lightly_embeddings"
os.makedirs(data_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)
os.makedirs(embeddings_dir, exist_ok=True)
# Download and prepare training images
image_paths = []
for url in sample_urls:
    image_path = geoai.download_file(url)
    image_paths.append(image_path)
    print(f"Downloaded: {image_path}")
Prepare training data¶
For self-supervised learning, we need to extract image patches from our geospatial imagery. These patches will be used as unlabeled training data.
# Extract image patches for training
for i, image_path in enumerate(image_paths):
    patch_output_dir = f"{data_dir}/patches_{i}"
    # Extract patches from the raster
    geoai.export_geotiff_tiles(
        in_raster=image_path,
        out_folder=patch_output_dir,
        tile_size=224,  # Common size for vision models
        stride=112,  # 50% overlap
    )
    print(f"Extracted patches from {image_path} to {patch_output_dir}")
# Count total training images
import glob
all_patches = glob.glob(f"{data_dir}/**/images/*.tif", recursive=True)
print(f"\nTotal training patches: {len(all_patches)}")
Visualize sample training data¶
Let's take a look at some of the extracted patches that will be used for training.
import matplotlib.pyplot as plt
from PIL import Image
# Display a few sample patches
sample_patches = all_patches[:6] # Take first 6 patches
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
fig.suptitle("Sample Training Patches for Self-Supervised Learning", fontsize=16)
for i, patch_path in enumerate(sample_patches):
    row, col = i // 3, i % 3
    img = Image.open(patch_path)
    axes[row, col].imshow(img)
    axes[row, col].set_title(f"Patch {i+1}")
    axes[row, col].axis("off")
plt.tight_layout()
plt.show()
print(f"Each patch is {img.size} pixels")
Train self-supervised model¶
Now we'll train a self-supervised model using Lightly Train. We'll use the SimCLR method, which works well with CNN models like ResNet.
Method compatibility:
- SimCLR or DINO: Works with CNN models (ResNet, EfficientNet, etc.)
- DINOv2: Requires Vision Transformer (ViT) models only
Note: Training will take some time depending on your hardware. For demonstration purposes, we're using a small number of epochs. In practice, you might want to train for 100+ epochs.
# Train a self-supervised model using Lightly Train
model_path = geoai.lightly_train_model(
    data_dir=f"{data_dir}/patches_0/images",  # Use patches from first image
    output_dir=output_dir,
    model="torchvision/resnet50",  # Base architecture
    method="simclr",  # Use SimCLR for CNN models like ResNet
    epochs=10,  # Small number for demo
    batch_size=32,  # Adjust based on your GPU
    optim_args={"lr": 1e-4},  # Pass learning rate through optim_args
    overwrite=True,
)
print(f"\nModel training completed! Pretrained model saved to: {model_path}")
Load pretrained model¶
Once training is complete, we can load the pretrained model for use in downstream tasks.
# Load the pretrained model
pretrained_model = geoai.load_lightly_pretrained_model(
    model_path=model_path, model_architecture="torchvision/resnet50"
)
print(f"Loaded pretrained model: {type(pretrained_model)}")
print(f"Model parameters: {sum(p.numel() for p in pretrained_model.parameters()):,}")
# The model is now ready for fine-tuning on your specific task
print("\nModel is ready for fine-tuning on downstream tasks!")
Generate embeddings¶
We can also use the pretrained model checkpoint to generate embeddings for our images. These embeddings capture rich representations learned through self-supervised training.
Note: We need to use the checkpoint file (.ckpt) for embedding generation, not the exported model (.pt).
# Generate embeddings for our images
# Note: We use the checkpoint file (.ckpt) for embedding generation
checkpoint_path = os.path.join(output_dir, "checkpoints", "last.ckpt")
embeddings_path = geoai.lightly_embed_images(
    data_dir=f"{data_dir}/patches_0/images",
    model_path=checkpoint_path,
    output_path=f"{embeddings_dir}/image_embeddings.pt",
    batch_size=32,
    overwrite=True,
)
print(f"Embeddings saved to: {embeddings_path}")
Analyze embeddings¶
Let's load and analyze the generated embeddings to understand what our model has learned.
# Load and analyze the embeddings
if os.path.exists(embeddings_path):
    # Load the saved embeddings (a tensor, or a dict containing one)
    embeddings = torch.load(embeddings_path)
    # Unwrap a dict payload, then convert to NumPy for analysis
    if isinstance(embeddings, dict):
        embeddings = embeddings["embeddings"]
    if isinstance(embeddings, torch.Tensor):
        embeddings_np = embeddings.cpu().numpy()
    else:
        embeddings_np = np.asarray(embeddings)
    print(f"Embeddings shape: {embeddings_np.shape}")
    print(f"Embedding dimension: {embeddings_np.shape[1]}")
    print(f"Number of images embedded: {embeddings_np.shape[0]}")
    # Basic statistics
    print("\nEmbedding statistics:")
    print(f"Mean: {embeddings_np.mean():.4f}")
    print(f"Std: {embeddings_np.std():.4f}")
    print(f"Min: {embeddings_np.min():.4f}")
    print(f"Max: {embeddings_np.max():.4f}")
else:
    print("Embeddings file not found. Check that the embedding step above completed successfully.")
Visualize embeddings with t-SNE¶
Let's use t-SNE to visualize the high-dimensional embeddings in 2D space. This helps us understand how well the model groups similar images.
# Visualize embeddings using t-SNE
if os.path.exists(embeddings_path):
    try:
        from sklearn.manifold import TSNE
        import matplotlib.pyplot as plt

        # Apply t-SNE for dimensionality reduction
        tsne = TSNE(
            n_components=2, random_state=42, perplexity=min(30, len(embeddings_np) - 1)
        )
        embeddings_2d = tsne.fit_transform(embeddings_np)
        # Create visualization
        plt.figure(figsize=(10, 8))
        scatter = plt.scatter(
            embeddings_2d[:, 0],
            embeddings_2d[:, 1],
            c=range(len(embeddings_2d)),
            cmap="viridis",
            alpha=0.7,
        )
        plt.colorbar(scatter)
        plt.title("t-SNE Visualization of Self-Supervised Embeddings")
        plt.xlabel("t-SNE Component 1")
        plt.ylabel("t-SNE Component 2")
        plt.grid(True, alpha=0.3)
        plt.show()
        print("t-SNE visualization shows how the model groups similar image patches.")
        print("Clusters indicate that the model has learned meaningful representations!")
    except ImportError:
        print("scikit-learn not available for t-SNE visualization.")
        print("Install with: pip install scikit-learn")
else:
    print("Embeddings not available for visualization.")
Next steps¶
Now that you have a pretrained self-supervised model, here are some ways you can use it:
1. Fine-tuning for specific tasks¶
# Load your pretrained model
model = geoai.load_lightly_pretrained_model(
    model_path="path/to/your/model.pt",
    model_architecture="torchvision/resnet50",
)
# Replace the final layer for your specific task
# For example, for binary classification:
import torch.nn as nn
model.fc = nn.Linear(model.fc.in_features, 2) # 2 classes
# Fine-tune with your labeled data using standard PyTorch training
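A minimal sketch of one such training epoch, assuming a hypothetical train_loader DataLoader that yields (image, label) batches for your task:
# Minimal fine-tuning loop sketch; train_loader is a hypothetical
# DataLoader yielding (image, label) batches
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model.train()
for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()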
2. Feature extraction¶
# Use the model as a feature extractor
model.eval()
# Remove the final classification layer
feature_extractor = nn.Sequential(*list(model.children())[:-1])
# Extract features for any new images
with torch.no_grad():
    features = feature_extractor(your_image_tensor)
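For a ResNet-50 this extractor ends at the global average-pooling layer, so features has shape (N, 2048, 1, 1); flatten it with features.flatten(1) before feeding it to a classifier or a similarity search.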
3. Similarity search¶
# Use embeddings for finding similar images
from sklearn.metrics.pairwise import cosine_similarity
# Find most similar images to a query image
similarities = cosine_similarity([query_embedding], embeddings)
most_similar_indices = similarities.argsort()[0][-5:] # Top 5 similar
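Here query_embedding and embeddings are placeholders for vectors produced by the pretrained model, for example one row of the embeddings_np array from earlier used as the query against the full array.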
4. Different architectures and methods¶
Try different combinations for your specific use case:
For CNN models (ResNet, EfficientNet):
# SimCLR method (recommended for CNNs)
model_path = geoai.lightly_train_model(
    data_dir="path/to/images",
    output_dir="output",
    model="torchvision/resnet50",
    method="simclr",
    epochs=100,
)
# Or DINO method (also works with CNNs)
model_path = geoai.lightly_train_model(
    data_dir="path/to/images",
    output_dir="output",
    model="timm/efficientnet_b0",
    method="dino",
    epochs=100,
)
For Vision Transformer (ViT) models:
# DINOv2 (recommended for ViTs, excellent for geospatial imagery)
model_path = geoai.lightly_train_model(
    data_dir="path/to/images",
    output_dir="output",
    model="timm/vit_base_patch16_224",
    method="dinov2",
    epochs=100,
)
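Note that timm/vit_base_patch16_224 expects 224×224 inputs, which matches the tile_size used when extracting patches earlier in this notebook.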
Model architectures:
- "torchvision/resnet50" - Good general-purpose CNN
- "torchvision/resnet101" - Larger CNN for more complex features
- "timm/efficientnet_b0" - Efficient CNN architecture
- "timm/vit_base_patch16_224" - Vision Transformer (use with dinov2)
Self-supervised methods:
- "simclr" - Contrastive learning, works with CNNs
- "dino" - Works with both CNNs and ViTs
- "dinov2" - Advanced method, requires ViT models
- "dinov2_distillation" - Enhanced DINOv2, requires ViT models
Cleanup (optional)¶
Remove temporary files if needed:
# Uncomment to clean up temporary files
# import shutil
# shutil.rmtree(data_dir, ignore_errors=True)
# shutil.rmtree(output_dir, ignore_errors=True)
# shutil.rmtree(embeddings_dir, ignore_errors=True)
# print("Temporary files cleaned up.")