Geospatial Embedding Datasets with TorchGeo¶
Overview¶
TorchGeo v0.9.0 introduces Earth Embeddings — pre-computed representations from geospatial foundation models that encode satellite imagery into compact vectors. Because the embeddings are already computed, they enable rapid analysis without requiring GPU compute to run the foundation models themselves.
This notebook demonstrates how to use the geoai embeddings module to:
- Browse available embedding datasets
- Load patch-based embedding datasets (Clay Foundation Model)
- Visualize high-dimensional embeddings using PCA
- Cluster embeddings to discover spatial patterns
- Search for similar locations using cosine similarity
- Classify land use types using lightweight classifiers on embeddings
Embedding Dataset Types¶
| Type | Format | Examples | Use Case |
|---|---|---|---|
| Patch-based | GeoParquet | Clay, Major TOM, Earth Index | Global-scale analysis, classification |
| Pixel-based | GeoTIFF | Google Satellite, Tessera, Presto | High-resolution mapping, change detection |
Install Package¶
Uncomment the command below if needed.
# %pip install geoai-py scikit-learn
Import Libraries¶
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download
import geoai
1. Browse Available Embedding Datasets¶
The geoai package provides a registry of all embedding datasets available in TorchGeo v0.9.0. Use list_embedding_datasets() to see what's available.
# List all embedding datasets
df = geoai.list_embedding_datasets(verbose=False)
df
# Filter by type: patch-based datasets
geoai.list_embedding_datasets(kind="patch", verbose=False)
# Filter by type: pixel-based datasets
geoai.list_embedding_datasets(kind="pixel", verbose=False)
# Get detailed info about a specific dataset
info = geoai.get_embedding_info("google_satellite")
for key, value in info.items():
    print(f"{key}: {value}")
2. Download Clay Embeddings (SF Bay Area)¶
We'll use Clay Foundation Model embeddings for the San Francisco Bay Area from HuggingFace. This dataset contains 768-dimensional embeddings computed from NAIP aerial imagery across 20 tiles, along with labeled locations for baseball fields (class 0) and marinas (class 1).
Download all embedding tiles¶
repo_id = "made-with-clay/classify-embeddings-sf-baseball-marinas"
# List all embedding GeoParquet files
api = HfApi()
embedding_files = [
    f.path
    for f in api.list_repo_tree(repo_id, repo_type="dataset")
    if f.path.endswith(".gpq")
]
print(f"Found {len(embedding_files)} embedding tiles")
# Download all tiles and concatenate into a single GeoDataFrame
all_gdfs = []
for f in embedding_files:
    path = hf_hub_download(repo_id, f, repo_type="dataset")
    gdf = gpd.read_parquet(path)
    all_gdfs.append(gdf)
embeddings_gdf = pd.concat(all_gdfs, ignore_index=True)
embeddings_gdf = gpd.GeoDataFrame(
    embeddings_gdf, geometry="geometry", crs=all_gdfs[0].crs
)
print(f"Combined: {len(embeddings_gdf)} patches")
print(f"Bounds: {embeddings_gdf.total_bounds}")
print(f"Embedding dimension: {len(embeddings_gdf.iloc[0]['embeddings'])}")
# Download the labeled locations (baseball fields and marinas)
labels_file = hf_hub_download(repo_id, "baseball.geojson", repo_type="dataset")
labels_gdf = gpd.read_file(labels_file)
print(f"Labeled locations: {len(labels_gdf)}")
print("Class distribution:")
print(labels_gdf["class"].value_counts())
Extract embedding vectors¶
Convert the embedding column to a NumPy array and extract coordinates for analysis.
# Extract embeddings, coordinates from the GeoParquet
embeddings = np.stack(embeddings_gdf["embeddings"].values)
centroids = embeddings_gdf.geometry.centroid
coords_x = centroids.x.values
coords_y = centroids.y.values
print(f"Embeddings shape: {embeddings.shape}")
print(f"X range: [{coords_x.min():.4f}, {coords_x.max():.4f}]")
print(f"Y range: [{coords_y.min():.4f}, {coords_y.max():.4f}]")
# Plot a few embedding vectors to see their patterns
fig, axes = plt.subplots(1, 3, figsize=(15, 3))
for i, ax in enumerate(axes):
    idx = i * (len(embeddings) // 3)
    ax.plot(embeddings[idx], linewidth=0.5)
    ax.set_title(f"Patch {idx} ({coords_y[idx]:.3f}°N, {abs(coords_x[idx]):.3f}°W)")
    ax.set_xlabel("Dimension")
    ax.set_ylabel("Value")
plt.tight_layout()
plt.show()
3. PCA Projection of All Embeddings¶
# Visualize the embedding space using PCA
fig = geoai.visualize_embeddings(
    embeddings,
    method="pca",
    figsize=(8, 8),
    s=3,
    alpha=0.4,
    title="PCA of Clay Embeddings (SF Bay Area)",
)
plt.show()
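Under the hood, a PCA projection like the one above can be computed directly with scikit-learn. A minimal sketch, using synthetic 768-dimensional vectors as a stand-in for the Clay embedding matrix:

```python
import numpy as np
from sklearn.decomposition import PCA

# Synthetic stand-in for the (n_patches, 768) Clay embedding matrix
rng = np.random.default_rng(42)
emb = rng.normal(size=(500, 768))

# Project to 2-D; the first two principal components capture the most variance
pca = PCA(n_components=2)
proj = pca.fit_transform(emb)

print(proj.shape)  # (500, 2)
print(pca.explained_variance_ratio_)  # fraction of variance per component
```

The 2-D `proj` array is what gets scattered in plots like the one above; each point is one patch.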
4. Cluster Embeddings¶
Use unsupervised clustering to discover patterns in the embeddings without any labels.
# Cluster the embeddings into groups
result = geoai.cluster_embeddings(embeddings, n_clusters=8, method="kmeans")
cluster_labels = result["labels"]
print(f"Number of clusters: {result['n_clusters']}")
print(f"Cluster sizes: {np.bincount(cluster_labels)}")
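The same K-Means step can be reproduced with scikit-learn directly, which is useful if you want control over initialization or other parameters. A minimal sketch on synthetic stand-in data:

```python
import numpy as np
from sklearn.cluster import KMeans

# Synthetic stand-in for the (n_patches, 768) embedding matrix
rng = np.random.default_rng(0)
emb = rng.normal(size=(300, 768))

km = KMeans(n_clusters=8, n_init=10, random_state=0)
labels = km.fit_predict(emb)

print(labels.shape)  # one cluster id per patch
print(np.bincount(labels))  # cluster sizes
```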
# Visualize clusters in PCA space
fig = geoai.visualize_embeddings(
    embeddings,
    labels=cluster_labels,
    method="pca",
    figsize=(10, 8),
    s=5,
    alpha=0.5,
    title="K-Means Clusters of Clay Embeddings",
)
plt.show()
# Map clusters geographically
fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(
    coords_x,
    coords_y,
    c=cluster_labels,
    cmap="tab10",
    s=3,
    alpha=0.6,
)
plt.colorbar(scatter, ax=ax, label="Cluster")
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
ax.set_title("Geographic Distribution of Embedding Clusters")
ax.set_aspect("equal")
plt.tight_layout()
plt.show()
5. Similarity Search¶
Find the most similar locations to a query embedding using cosine similarity.
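Cosine similarity scores each candidate by the angle between vectors: sim(q, e) = q·e / (‖q‖‖e‖). A minimal NumPy sketch of the kind of top-k search the `embedding_similarity` helper performs:

```python
import numpy as np

def top_k_cosine(query, embeddings, k=10):
    """Return indices and cosine scores of the k rows most similar to query."""
    # Normalize so that the dot product equals cosine similarity
    q = query / np.linalg.norm(query)
    e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    scores = e @ q
    idx = np.argsort(scores)[::-1][:k]  # highest scores first
    return idx, scores[idx]

rng = np.random.default_rng(1)
emb = rng.normal(size=(100, 768))
idx, scores = top_k_cosine(emb[0], emb, k=5)
print(idx[0], scores[0])  # the query itself ranks first with similarity ~1.0
```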
# Pick a query embedding (first patch)
query_idx = 0
query = embeddings[query_idx]
print(f"Query location: ({coords_y[query_idx]:.4f}°N, {abs(coords_x[query_idx]):.4f}°W)")
# Find top-10 most similar locations
results = geoai.embedding_similarity(
    query=query, embeddings=embeddings, metric="cosine", top_k=10
)
print("\nTop 10 most similar locations:")
for rank, (idx, score) in enumerate(
    zip(results["indices"], results["scores"]), start=1
):
    print(
        f"  {rank}. Index {idx}: similarity={score:.4f}, "
        f"location=({coords_y[idx]:.4f}°N, {abs(coords_x[idx]):.4f}°W)"
    )
# Visualize the query and its nearest neighbors on a map
fig, ax = plt.subplots(figsize=(10, 8))
# Background: all embeddings in gray
ax.scatter(coords_x, coords_y, c="lightgray", s=1, alpha=0.3)
# Highlight nearest neighbors
nn_indices = results["indices"]
ax.scatter(
    coords_x[nn_indices],
    coords_y[nn_indices],
    c="blue",
    s=50,
    marker="o",
    label="Nearest Neighbors",
    edgecolors="black",
    linewidths=0.5,
)
# Highlight the query point
ax.scatter(
    coords_x[query_idx],
    coords_y[query_idx],
    c="red",
    s=100,
    marker="*",
    label="Query",
    edgecolors="black",
    linewidths=0.5,
    zorder=5,
)
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
ax.set_title("Similarity Search: Query and Nearest Neighbors")
ax.legend()
ax.set_aspect("equal")
plt.tight_layout()
plt.show()
6. Classification with Embeddings¶
Train a lightweight k-NN classifier on the Clay embeddings using labeled data. The dataset includes labeled locations for baseball fields (class 0) and marinas (class 1) in the San Francisco Bay Area.
Prepare training data¶
We match each labeled point to the embedding patch that contains it using a spatial join.
# Ensure both GeoDataFrames use the same CRS
if labels_gdf.crs != embeddings_gdf.crs:
    labels_gdf = labels_gdf.to_crs(embeddings_gdf.crs)
# Spatial join: find which embedding patch each labeled point falls within
joined = gpd.sjoin(labels_gdf, embeddings_gdf, how="inner", predicate="within")
print(f"Matched {len(joined)} labeled points to embedding patches")
print(f"Class distribution: {joined['class'].value_counts().to_dict()}")
# Extract embeddings and labels for matched points
labeled_embeddings = np.stack(
    [embeddings_gdf.iloc[idx]["embeddings"] for idx in joined["index_right"]]
)
class_labels = joined["class"].values
print(f"Labeled embeddings shape: {labeled_embeddings.shape}")
print(f"Labels shape: {class_labels.shape}")
# Split into train/validation sets
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(
    labeled_embeddings,
    class_labels,
    test_size=0.3,
    random_state=42,
    stratify=class_labels,
)
print(f"Train: {X_train.shape[0]} samples")
print(f"Val: {X_val.shape[0]} samples")
Train a k-NN classifier¶
label_names = ["Baseball Field", "Marina"]
# Train using geoai's convenience function
result = geoai.train_embedding_classifier(
    train_embeddings=X_train,
    train_labels=y_train,
    val_embeddings=X_val,
    val_labels=y_val,
    method="knn",
    n_neighbors=5,
    label_names=label_names,
)
print(f"\nTrain accuracy: {result['train_accuracy']:.2%}")
print(f"Val accuracy: {result['val_accuracy']:.2%}")
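For reference, the equivalent bare scikit-learn workflow looks like this (a minimal sketch on synthetic, well-separated classes — geoai's convenience function adds reporting on top of a classifier of this kind):

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Synthetic stand-in: two separable classes in 768-D
rng = np.random.default_rng(7)
X_tr = np.vstack([rng.normal(0, 1, (40, 768)), rng.normal(3, 1, (40, 768))])
y_tr = np.array([0] * 40 + [1] * 40)
X_va = np.vstack([rng.normal(0, 1, (10, 768)), rng.normal(3, 1, (10, 768))])
y_va = np.array([0] * 10 + [1] * 10)

# k-NN needs no training beyond storing the data; prediction is a
# majority vote among the 5 nearest training embeddings
clf = KNeighborsClassifier(n_neighbors=5)
clf.fit(X_tr, y_tr)
print(f"Val accuracy: {clf.score(X_va, y_va):.2%}")
```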
Compare different classifiers¶
# Try different classifiers
methods = ["knn", "random_forest", "logistic_regression"]
results_summary = []
for method in methods:
    res = geoai.train_embedding_classifier(
        train_embeddings=X_train,
        train_labels=y_train,
        val_embeddings=X_val,
        val_labels=y_val,
        method=method,
        label_names=label_names,
        verbose=False,
    )
    results_summary.append(
        {
            "Method": method,
            "Train Acc": f"{res['train_accuracy']:.2%}",
            "Val Acc": f"{res['val_accuracy']:.2%}",
        }
    )
pd.DataFrame(results_summary)
Visualize classified embeddings¶
# Visualize labeled embeddings in PCA space
fig = geoai.visualize_embeddings(
    labeled_embeddings,
    labels=class_labels,
    label_names=label_names,
    method="pca",
    figsize=(8, 8),
    s=30,
    alpha=0.8,
    title="PCA of Labeled Embeddings (Baseball vs Marina)",
)
plt.show()
7. Comparing Embeddings for Change Detection¶
Embedding vectors from different time periods can be compared to detect change. The compare_embeddings function computes element-wise similarity between two sets of embeddings.
Here we demonstrate the concept by comparing embeddings from different spatial patches.
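The element-wise (row-wise) comparison described above can be sketched in plain NumPy: each row of one set is compared only to the corresponding row of the other, and low scores flag candidate change. Synthetic data stands in for the two time periods:

```python
import numpy as np

def rowwise_cosine(a, b):
    """Cosine similarity between corresponding rows of a and b."""
    a_n = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_n = b / np.linalg.norm(b, axis=1, keepdims=True)
    return np.sum(a_n * b_n, axis=1)  # one score per row pair

rng = np.random.default_rng(3)
a = rng.normal(size=(50, 768))
sim_same = rowwise_cosine(a, a)  # identical rows -> similarity 1.0 ("no change")
sim_rand = rowwise_cosine(a, rng.normal(size=(50, 768)))  # unrelated rows -> near 0
print(sim_same.mean(), sim_rand.mean())
```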
# Compare first half vs second half of patches to simulate temporal comparison
n = len(embeddings)
half = n // 2
emb_a = embeddings[:half]
emb_b = embeddings[half : half + half] # same number of samples
similarity = geoai.compare_embeddings(emb_a, emb_b, metric="cosine")
fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(similarity, bins=50, edgecolor="black", alpha=0.7)
ax.axvline(
    similarity.mean(),
    color="red",
    linestyle="--",
    label=f"Mean: {similarity.mean():.3f}",
)
ax.set_xlabel("Cosine Similarity")
ax.set_ylabel("Count")
ax.set_title("Embedding Similarity Distribution")
ax.legend()
plt.tight_layout()
plt.show()
8. Using TorchGeo Dataset Classes Directly¶
For more advanced usage, you can use the TorchGeo dataset classes directly through geoai.load_embedding_dataset(). This gives you access to all the TorchGeo features like transforms, sampling, and plotting.
Note: The TorchGeo ClayEmbeddings class expects a date or datetime column in the GeoParquet file. Some community-contributed embedding files may not include this column. In such cases, load the data with geopandas directly (as shown above).
# Load using geoai's unified interface
single_file = hf_hub_download(repo_id, embedding_files[0], repo_type="dataset")
ds = geoai.load_embedding_dataset("clay", root=single_file)
print(f"Dataset length: {len(ds)}")
print(f"Dataset type: {type(ds).__name__}")
# Access a sample - may fail if the file lacks a 'datetime' column
try:
    sample = ds[0]
    print(f"Sample keys: {list(sample.keys())}")
    print(f"Embedding shape: {sample['embedding'].shape}")
    print(f"Location: ({sample['y'].item():.4f}°N, {abs(sample['x'].item()):.4f}°W)")
    fig = ds.plot(sample)
    plt.show()
except KeyError as e:
    print(
        f"Note: This parquet file is missing the '{e.args[0]}' column "
        f"expected by TorchGeo's ClayEmbeddings class."
    )
    print("For such files, use geopandas directly (as shown above).")
    print("The TorchGeo class works best with official Clay data products.")
Summary¶
In this notebook, we demonstrated the geoai embeddings module which provides a unified interface to TorchGeo v0.9.0's embedding datasets. Key takeaways:
- 9 embedding datasets are available, spanning patch-based and pixel-based formats
- No GPU required for analysis — embeddings are pre-computed
- Lightweight classifiers (k-NN, Random Forest) work well on embeddings
- Unsupervised clustering reveals spatial patterns without labels
- Similarity search enables content-based spatial retrieval
- Change detection is possible by comparing embeddings across time periods
For pixel-based datasets (Google Satellite Embedding, Tessera, etc.), download GeoTIFF files and use geoai.load_embedding_dataset() with the paths parameter.