Google Satellite Embedding with TorchGeo
Overview
The AlphaEarth Foundations Satellite Embedding dataset, produced by Google and Google DeepMind, provides pre-computed 64-dimensional embedding vectors at 10-meter resolution. Each pixel encodes information from optical, radar, LiDAR, and other Earth observation sources into a unit-length vector.
Key characteristics:
- Resolution: 10 m per pixel
- Dimensions: 64 (bands A00-A63)
- Temporal: Annual composites from 2018-2024
- Coverage: Global
- License: CC-BY-4.0
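Because each vector is unit-length, the cosine similarity used later in this notebook reduces to a plain dot product: for vectors a and b, cos(a, b) = a · b / (‖a‖ ‖b‖), and the denominator is 1 when both norms are 1. A minimal NumPy check, using random vectors as stand-ins for real embeddings (no AlphaEarth data is loaded here):

```python
import numpy as np

# Random stand-ins for two 64-D embedding vectors (synthetic, not real AlphaEarth data)
rng = np.random.default_rng(0)
a = rng.normal(size=64)
b = rng.normal(size=64)

# General cosine similarity: dot product divided by the product of the norms
cosine = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

# Normalize both vectors to unit length, as the dataset's embeddings are described to be
a_unit = a / np.linalg.norm(a)
b_unit = b / np.linalg.norm(b)

# For unit vectors the denominator is 1, so the dot product alone gives the same value
dot_of_units = np.dot(a_unit, b_unit)

print(f"cosine = {cosine:.6f}, dot of unit vectors = {dot_of_units:.6f}")
```

If stored values deviate slightly from unit norm (for example because of rounding), normalizing explicitly before taking a dot product is a cheap safeguard.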
This notebook demonstrates how to:
- Download embedding data from Source Cooperative
- Load with TorchGeo's GoogleSatelliteEmbedding via geoai
- Visualize embeddings as PCA-based RGB images
- Cluster embeddings to discover land cover patterns
- Search for similar pixels using cosine similarity
- Detect change by comparing embeddings across years
Install Package
Uncomment the command below if needed.
# %pip install geoai-py scikit-learn leafmap
Import Libraries
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import geoai
Dataset Info
View metadata about the Google Satellite Embedding dataset from the geoai registry.
info = geoai.get_embedding_info("google_satellite")
for key, value in info.items():
    print(f"{key}: {value}")
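Download Embedding Data
Download embeddings for a small region near Paradise, CA, for two years (2018 and 2024). The Camp Fire, which started on 8 November 2018, destroyed much of this area. The 2018 annual embedding therefore draws mostly on pre-fire observations, while the 2024 embedding reflects several years of recovery and rebuilding, so comparing the two should reveal significant change.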
The download function fetches data from Source Cooperative using windowed reads from Cloud-Optimized GeoTIFFs, so only the requested region is transferred.
# Bounding box around Paradise, CA: (min lon, min lat, max lon, max lat)
bbox = (-121.65, 39.73, -121.55, 39.80)
output_dir = "aef_data"
files = geoai.download_google_satellite_embedding(
    bbox=bbox,
    output_dir=output_dir,
    years=[2018, 2024],
    crs=None,
)
print(f"Downloaded {len(files)} file(s): {files}")
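The windowed read mentioned above works because a COG's geotransform maps map coordinates to pixel rows and columns, so a bounding box can be converted into a pixel window and only the internal tiles overlapping that window need to be fetched. rasterio exposes this conversion as `rasterio.windows.from_bounds`; the sketch below spells out the arithmetic with a hypothetical helper, `bbox_to_window`, that is not part of geoai. It assumes a north-up raster and a bbox already expressed in the raster's CRS. The bbox above is in longitude/latitude, so a real implementation would first reproject it if the COG uses a projected CRS. The origin and coordinates in the example are made up.

```python
import math


def bbox_to_window(bbox, transform):
    """Hypothetical helper: convert (xmin, ymin, xmax, ymax) to a pixel window.

    `transform` is (x_origin, pixel_width, y_origin, pixel_height) for a north-up
    raster, where pixel_height is negative. Returns (row_off, col_off, height, width).
    """
    x0, px_w, y0, px_h = transform
    xmin, ymin, xmax, ymax = bbox
    # Columns grow eastward from the left edge; floor/ceil so partial pixels are included
    col_off = math.floor((xmin - x0) / px_w)
    col_end = math.ceil((xmax - x0) / px_w)
    # Rows grow southward from the top edge (px_h < 0), so the first row comes from ymax
    row_off = math.floor((ymax - y0) / px_h)
    row_end = math.ceil((ymin - y0) / px_h)
    return row_off, col_off, row_end - row_off, col_end - col_off


# Made-up 10 m raster whose top-left corner is at (600000, 4410000) in a metric CRS
transform = (600000.0, 10.0, 4410000.0, -10.0)
window = bbox_to_window((601000.0, 4400000.0, 609000.0, 4407500.0), transform)
print(window)  # (250, 100, 750, 800)
```

Only the 750 × 800 block starting at row 250, column 100 would be read; with a tiled COG, that translates into HTTP range requests for the overlapping tiles rather than a download of the whole file.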
Inspect Downloaded Data
Check the properties of the downloaded GeoTIFF files.
for f in files:
    with rasterio.open(f) as src:
        print(f"File: {f}")
        print(f"  Shape: {src.count} bands x {src.height}H x {src.width}W")
        print(f"  CRS: {src.crs}")
        print(f"  Bounds: {src.bounds}")
        print(f"  Resolution: {src.res}")
        print(f"  Dtype: {src.dtypes[0]}")
    print()
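To sanity-check the printed shape, the approximate raster size can be worked out from the bounding box and the 10 m resolution. This sketch uses a simple spherical approximation (about 111.32 km per degree of latitude, scaled by the cosine of the latitude for longitude); the actual dimensions will differ somewhat depending on the files' CRS and on how the window is snapped to whole pixels.

```python
import math

# Bounding box from the download cell: (min lon, min lat, max lon, max lat) in degrees
west, south, east, north = (-121.65, 39.73, -121.55, 39.80)

KM_PER_DEG_LAT = 111.32  # approximate length of one degree of latitude
mid_lat = math.radians((south + north) / 2)

# A degree of longitude shrinks with the cosine of the latitude
width_km = (east - west) * KM_PER_DEG_LAT * math.cos(mid_lat)
height_km = (north - south) * KM_PER_DEG_LAT

# 10 m pixels -> 100 pixels per km
approx_w = round(width_km * 100)
approx_h = round(height_km * 100)
print(f"~{width_km:.2f} km x {height_km:.2f} km, roughly {approx_w} x {approx_h} pixels (W x H)")
```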
Load with TorchGeo
Load the downloaded data using TorchGeo's GoogleSatelliteEmbedding dataset class via geoai.load_embedding_dataset().
ds = geoai.load_embedding_dataset("google_satellite", paths=output_dir)
print(f"Dataset type: {type(ds).__name__}")
print(f"CRS: {ds.crs}")
print(f"Resolution: {ds.res}")
print(f"Bounds: {ds.bounds}")
Extract Pixel Embeddings
Use extract_pixel_embeddings() to sample patches from the dataset and flatten the pixels into an (N, 64) array suitable for analysis.
data = geoai.extract_pixel_embeddings(ds, num_samples=20, size=256, flatten=True)
embeddings = data["embeddings"]
print(f"Embeddings shape: {embeddings.shape}")
print(f"Value range: [{embeddings.min():.4f}, {embeddings.max():.4f}]")
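Flattening turns a (C, H, W) raster into an (H*W, C) table with one row per pixel, and per-pixel results can be scattered back onto the grid with the same boolean mask. The cells below reuse this pattern several times. A small NumPy sketch on a synthetic patch; it illustrates the idea and is not the internal implementation of `extract_pixel_embeddings`:

```python
import numpy as np

# Synthetic 64-band patch of 4 x 5 pixels standing in for one embedding sample
C, H, W = 64, 4, 5
rng = np.random.default_rng(42)
patch = rng.normal(size=(C, H, W)).astype(np.float32)
patch[:, 0, 0] = np.nan  # mark the top-left pixel as nodata

# (C, H, W) -> (H*W, C): row k is the pixel at (k // W, k % W)
pixels = patch.reshape(C, -1).T

# Keep only pixels with no NaN in any band
valid = ~np.isnan(pixels).any(axis=1)
valid_pixels = pixels[valid]

# Scatter a per-pixel result (here simply band 0) back onto the H x W grid
values = np.full(H * W, np.nan, dtype=np.float32)
values[valid] = valid_pixels[:, 0]
grid = values.reshape(H, W)

print(pixels.shape, valid_pixels.shape, grid.shape)
```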
Visualize Embeddings as RGB
Use PCA to project the 64 embedding bands onto 3 principal components for RGB visualization. Here a single randomly sampled 512 × 512 pixel patch is shown.
# Get a single sample for visualization
from torchgeo.samplers import RandomGeoSampler
sampler = RandomGeoSampler(ds, size=512, length=1)
query = next(iter(sampler))
sample = ds[query]
print(f"Sample image shape: {sample['image'].shape}")
fig = geoai.plot_embedding_raster(
    sample["image"],
    title="Google Satellite Embedding (PCA RGB)",
)
plt.show()
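The PCA-to-RGB step can be sketched in NumPy: center the pixel table, take the top three principal axes from an SVD, project onto them, and min-max scale each component to [0, 1]. `pca_rgb` is a hypothetical helper; it is a generic illustration, not necessarily what `geoai.plot_embedding_raster` does internally, and it assumes NaN pixels have already been removed.

```python
import numpy as np


def pca_rgb(image):
    """Hypothetical helper: project a (C, H, W) array onto its first 3 principal
    components and min-max scale each component to [0, 1] for display as RGB."""
    c, h, w = image.shape
    pixels = image.reshape(c, -1).T.astype(np.float64)  # (H*W, C)
    centered = pixels - pixels.mean(axis=0)
    # Rows of vt are the principal axes, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    scores = centered @ vt[:3].T  # (H*W, 3)
    lo, hi = scores.min(axis=0), scores.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)  # guard against a constant component
    return ((scores - lo) / span).reshape(h, w, 3)


rng = np.random.default_rng(0)
rgb = pca_rgb(rng.normal(size=(64, 32, 32)))  # synthetic 64-band image
print(rgb.shape, float(rgb.min()), float(rgb.max()))
```

The sign of each principal axis is arbitrary, so the same data can come out with inverted colors from a different PCA implementation or fit.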
Interactive Map: PCA RGB
Save the PCA-projected embeddings for each year as a 3-band GeoTIFF and display them on an interactive map with a satellite basemap.
import os
import leafmap
from sklearn.decomposition import PCA
# Save derived products to a separate directory outside of output_dir
# to avoid interfering with TorchGeo's recursive directory scanning
viz_dir = "aef_viz"
os.makedirs(viz_dir, exist_ok=True)
# Create PCA RGB GeoTIFFs for both years
# (PCA is fit separately per year, so colors are not directly comparable between years)
pca = PCA(n_components=3)
pca_files = []
for f in files:
    with rasterio.open(f) as src:
        data = src.read()  # (64, H, W)
        crs, transform = src.crs, src.transform
    h, w = data.shape[1], data.shape[2]
    pixels = data.reshape(64, -1).T  # (H*W, 64)
    mask = ~np.isnan(pixels).any(axis=1)
    rgb = np.zeros((pixels.shape[0], 3))
    if mask.any():
        rgb[mask] = pca.fit_transform(pixels[mask])
        # Min-max scale using valid pixels only; nodata pixels stay 0 (black)
        rgb[mask] -= rgb[mask].min(axis=0)
        maxv = rgb[mask].max(axis=0)
        maxv[maxv == 0] = 1
        rgb[mask] /= maxv
    rgb = rgb.reshape(h, w, 3).clip(0, 1)
    # Save as 3-band uint8 GeoTIFF
    basename = os.path.basename(f).replace(".tif", "_pca_rgb.tif")
    pca_output = os.path.join(viz_dir, basename)
    with rasterio.open(
        pca_output,
        "w",
        driver="GTiff",
        height=h,
        width=w,
        count=3,
        dtype="uint8",
        crs=crs,
        transform=transform,
        compress="lzw",
    ) as dst:
        for b in range(3):
            dst.write((rgb[:, :, b] * 255).astype(np.uint8), b + 1)
    pca_files.append(pca_output)
print(f"PCA RGB files: {pca_files}")
# Display PCA RGB on an interactive map
m = leafmap.Map()
m.add_basemap("Esri.WorldImagery")
m.add_raster(pca_files[0], layer_name="Embeddings 2018 (PCA RGB)")
m.add_raster(pca_files[1], layer_name="Embeddings 2024 (PCA RGB)")
m
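The cell above fits a separate PCA for each year and rescales each result independently, so the same color need not mean the same embedding in both layers (principal axes can also flip sign between separate fits). One way to make colors comparable, sketched here with a NumPy SVD and a hypothetical helper `shared_pca_rgb`, is to fit a single projection on pixels pooled from both years and apply the same axes and the same scaling to each. With scikit-learn, the equivalent is to call `PCA(n_components=3).fit(...)` once on the stacked pixels and then `.transform(...)` each year. Synthetic arrays stand in for the two downloaded rasters, and NaN pixels are assumed to have been removed.

```python
import numpy as np


def shared_pca_rgb(images, n_components=3):
    """Hypothetical helper: fit one PCA on pixels pooled from several (C, H, W)
    images, then project every image with the same axes and min/max scaling."""
    flat = [img.reshape(img.shape[0], -1).T.astype(np.float64) for img in images]
    pooled = np.vstack(flat)
    mean = pooled.mean(axis=0)
    _, _, vt = np.linalg.svd(pooled - mean, full_matrices=False)
    axes = vt[:n_components].T  # (C, n_components), shared by all images
    scores = [(f - mean) @ axes for f in flat]
    # One global range per component, so equal embeddings get equal colors
    lo = np.min([s.min(axis=0) for s in scores], axis=0)
    hi = np.max([s.max(axis=0) for s in scores], axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)
    return [
        ((s - lo) / span).reshape(img.shape[1], img.shape[2], n_components)
        for s, img in zip(scores, images)
    ]


# Synthetic "2018" and "2024" rasters; only the top-left quadrant changes
rng = np.random.default_rng(1)
year1 = rng.normal(size=(64, 16, 16))
year2 = year1.copy()
year2[:, :8, :8] += 3.0

rgb1, rgb2 = shared_pca_rgb([year1, year2])
# Unchanged pixels receive matching colors because the projection is shared
print(np.allclose(rgb1[8:, 8:], rgb2[8:, 8:]))
```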
Cluster Embeddings
Use K-Means clustering to discover spatial patterns in the embeddings without any labels. Each cluster may correspond to a distinct land cover type.
# Remove NaN pixels before clustering
valid_mask = ~np.isnan(embeddings).any(axis=1)
valid_embeddings = embeddings[valid_mask]
print(f"Valid pixels: {valid_embeddings.shape[0]} / {embeddings.shape[0]}")
result = geoai.cluster_embeddings(valid_embeddings, n_clusters=8, method="kmeans")
cluster_labels = result["labels"]
print(f"Number of clusters: {result['n_clusters']}")
print(f"Cluster sizes: {np.bincount(cluster_labels)}")
# Visualize clusters in PCA space
fig = geoai.visualize_embeddings(
    valid_embeddings,
    labels=cluster_labels,
    method="pca",
    figsize=(10, 8),
    s=1,
    alpha=0.3,
    title="K-Means Clusters of Satellite Embeddings (PCA)",
)
plt.show()
# Visualize clusters as a spatial map from a single patch
sampler = RandomGeoSampler(ds, size=512, length=1)
query = next(iter(sampler))
sample = ds[query]
image = sample["image"].numpy()  # (64, H, W)
c, h, w = image.shape
pixels = image.reshape(c, -1).T  # (H*W, 64)
# Remove NaN pixels
pixel_valid = ~np.isnan(pixels).any(axis=1)
# Match dtype to the fitted model's cluster centers
valid_px = pixels[pixel_valid].astype(result["model"].cluster_centers_.dtype)
# Predict clusters using the fitted model
pred_labels = result["model"].predict(valid_px)
# Reconstruct spatial map (-1 marks nodata)
cluster_map = np.full(h * w, -1, dtype=int)
cluster_map[pixel_valid] = pred_labels
cluster_map = cluster_map.reshape(h, w)
# Mask nodata so it is not drawn with a cluster color
cluster_map = np.ma.masked_equal(cluster_map, -1)
fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(cluster_map, cmap="tab10", interpolation="nearest")
ax.set_title("Embedding Cluster Map")
ax.axis("off")
plt.colorbar(im, ax=ax, shrink=0.7, label="Cluster")
plt.tight_layout()
plt.show()
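For intuition about what K-Means clustering and the fitted model's `predict` do conceptually, here is a minimal NumPy version of Lloyd's algorithm: assign each point to its nearest center, move each center to the mean of its members, and repeat until the centers stop moving. The helpers `kmeans` and `assign` are hypothetical, not geoai functions. The sketch uses a deterministic farthest-point initialization rather than the k-means++ initialization that scikit-learn's `KMeans` uses by default, and it builds a full (N, k, 64) distance array, so it suits small synthetic inputs like the one below rather than the full embedding table.

```python
import numpy as np


def assign(X, centers):
    """Hypothetical helper: index of the nearest center (squared Euclidean) for each row."""
    d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)  # (N, k)
    return d2.argmin(axis=1)


def kmeans(X, k, n_iter=100):
    """Hypothetical helper: minimal Lloyd's K-Means with farthest-point initialization."""
    # Start from the first row, then repeatedly add the row that is farthest
    # from its nearest already-chosen center
    chosen = [X[0]]
    for _ in range(1, k):
        d2 = ((X[:, None, :] - np.array(chosen)[None, :, :]) ** 2).sum(axis=2)
        chosen.append(X[d2.min(axis=1).argmax()])
    centers = np.array(chosen, dtype=float)

    for _ in range(n_iter):
        labels = assign(X, centers)
        # Move each center to the mean of its members (keep it if the cluster is empty)
        new_centers = np.array(
            [
                X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
                for j in range(k)
            ]
        )
        converged = np.allclose(new_centers, centers)
        centers = new_centers
        if converged:
            break
    return centers, assign(X, centers)


# Synthetic 64-D data: three well-separated blobs of 100 points each
rng = np.random.default_rng(0)
true_centers = rng.normal(scale=10.0, size=(3, 64))
X = np.vstack([c + rng.normal(size=(100, 64)) for c in true_centers])

centers, labels = kmeans(X, k=3)
print("Cluster sizes:", np.bincount(labels))
```

Real embeddings are rarely this well separated, which is why the number of clusters (8 in the cell above) is a modeling choice rather than something the data dictates.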
Interactive Map: Cluster Map
Save the cluster map as a georeferenced GeoTIFF and display it on an interactive map.
# Save cluster map for the full 2024 embedding as a GeoTIFF
# (assumes files[1] is the 2024 file, i.e. files follow the order of `years`)
with rasterio.open(files[1]) as src:
    emb_data = src.read()  # (64, H, W)
    emb_crs, emb_transform = src.crs, src.transform
emb_h, emb_w = emb_data.shape[1], emb_data.shape[2]
# Match dtype to the KMeans model's cluster centers
target_dtype = result["model"].cluster_centers_.dtype
emb_pixels = emb_data.reshape(64, -1).T.astype(target_dtype)
emb_valid = ~np.isnan(emb_pixels).any(axis=1)
emb_valid_px = emb_pixels[emb_valid]
# Predict clusters
full_labels = result["model"].predict(emb_valid_px)
# Reconstruct spatial map (use 255 as nodata for uint8)
full_cluster_map = np.full(emb_h * emb_w, 255, dtype=np.uint8)
full_cluster_map[emb_valid] = full_labels.astype(np.uint8)
full_cluster_map = full_cluster_map.reshape(emb_h, emb_w)
cluster_output = os.path.join(viz_dir, "cluster_map_2024.tif")
with rasterio.open(
    cluster_output,
    "w",
    driver="GTiff",
    height=emb_h,
    width=emb_w,
    count=1,
    dtype="uint8",
    crs=emb_crs,
    transform=emb_transform,
    compress="lzw",
    nodata=255,
) as dst:
    dst.write(full_cluster_map, 1)
print(f"Cluster map saved to {cluster_output}")
# Display cluster map on an interactive map
m = leafmap.Map()
m.add_basemap("Esri.WorldImagery")
m.add_raster(
    cluster_output, cmap="tab10", nodata=255, opacity=0.7, layer_name="Clusters"
)
m
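Similarity Search
Find the pixels whose embedding vectors are most similar to a query pixel, using cosine similarity. Because the query is drawn from valid_embeddings itself, it is expected to appear as the top match with a similarity of about 1.0.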
# Use the pixel at the middle index of the flattened array as the query
# (this is not the spatial center of any patch)
center_idx = len(valid_embeddings) // 2
query_embedding = valid_embeddings[center_idx]
results = geoai.embedding_similarity(
    query=query_embedding,
    embeddings=valid_embeddings,
    metric="cosine",
    top_k=10,
)
# Indices refer to rows of valid_embeddings
print("Top 10 most similar pixels:")
for rank, (idx, score) in enumerate(
    zip(results["indices"], results["scores"]), start=1
):
    print(f"  {rank}. Index {idx}: similarity={score:.4f}")
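A brute-force cosine search is short to write directly: normalize the query and every embedding, take dot products, and keep the k largest scores. This NumPy sketch, built around a hypothetical helper `top_k_cosine`, illustrates the technique and is not the implementation of `geoai.embedding_similarity`; `np.argpartition` avoids sorting all N scores when only the top k are needed. Random vectors stand in for `valid_embeddings`.

```python
import numpy as np


def top_k_cosine(query, embeddings, k=10):
    """Hypothetical helper: (indices, scores) of the k rows of `embeddings` most
    similar to `query`, sorted from most to least similar."""
    # After L2-normalization, a dot product equals cosine similarity
    q = query / np.linalg.norm(query)
    E = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    scores = E @ q
    # argpartition selects the k largest scores without a full sort
    idx = np.argpartition(-scores, k - 1)[:k]
    # Sort just those k candidates by descending score
    idx = idx[np.argsort(-scores[idx])]
    return idx, scores[idx]


# Synthetic stand-in for valid_embeddings
rng = np.random.default_rng(0)
emb = rng.normal(size=(1000, 64))
query_idx = len(emb) // 2

indices, scores = top_k_cosine(emb[query_idx], emb, k=10)
for rank, (i, s) in enumerate(zip(indices, scores), start=1):
    print(f"{rank}. Index {i}: similarity={s:.4f}")
```

Because the query row is part of the searched array, it comes back first with a similarity of 1.0; dropping index `query_idx` from the results gives the nearest other pixels.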
Change Detection
Compare embeddings from two years to detect changes on the ground. We read the full downloaded raster for each year and compute the cosine similarity between corresponding pixels. Low similarity values indicate change.
# Read the two downloaded files directly
# (assumes files[0] is 2018 and files[1] is 2024, i.e. files follow the order of `years`)
with rasterio.open(files[0]) as src1, rasterio.open(files[1]) as src2:
    data1 = src1.read()  # (64, H, W)
    data2 = src2.read()
# Crop both arrays to the smaller common size. This assumes both rasters share
# the same top-left origin and pixel grid, so index (i, j) is the same location in both years
min_h = min(data1.shape[1], data2.shape[1])
min_w = min(data1.shape[2], data2.shape[2])
data1 = data1[:, :min_h, :min_w]
data2 = data2[:, :min_h, :min_w]
print(f"Year 1 shape: {data1.shape}")
print(f"Year 2 shape: {data2.shape}")
# Flatten to (N_pixels, 64)
emb1 = data1.reshape(64, -1).T
emb2 = data2.reshape(64, -1).T
# Remove pixels where either year has NaN
valid = ~(np.isnan(emb1).any(axis=1) | np.isnan(emb2).any(axis=1))
emb1_valid = emb1[valid]
emb2_valid = emb2[valid]
print(f"Valid pixel pairs: {emb1_valid.shape[0]}")
# Compute per-pixel cosine similarity between the two years
similarity = geoai.compare_embeddings(emb1_valid, emb2_valid, metric="cosine")
print(f"Mean similarity: {similarity.mean():.4f}")
print(f"Min similarity: {similarity.min():.4f}")
print(f"Max similarity: {similarity.max():.4f}")
# Visualize the similarity distribution
fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(similarity, bins=100, edgecolor="black", alpha=0.7, color="steelblue")
ax.axvline(
    similarity.mean(),
    color="red",
    linestyle="--",
    linewidth=2,
    label=f"Mean: {similarity.mean():.3f}",
)
ax.set_xlabel("Cosine Similarity")
ax.set_ylabel("Pixel Count")
ax.set_title("Embedding Similarity Between 2018 and 2024")
ax.legend()
plt.tight_layout()
plt.show()
# Create a spatial change map (NaN where either year is nodata)
change_map = np.full(min_h * min_w, np.nan)
change_map[valid] = similarity
change_map = change_map.reshape(min_h, min_w)
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
# PCA RGB for each year (PCA is fit separately per year, so colors are not directly comparable)
from sklearn.decomposition import PCA
pca = PCA(n_components=3)
for ax_idx, (data, year) in enumerate([(data1, "2018"), (data2, "2024")]):
    pixels = data.reshape(64, -1).T
    mask = ~np.isnan(pixels).any(axis=1)
    rgb = np.zeros((pixels.shape[0], 3))
    if mask.any():
        rgb[mask] = pca.fit_transform(pixels[mask])
        # Min-max scale using valid pixels only; nodata pixels stay 0 (black)
        rgb[mask] -= rgb[mask].min(axis=0)
        maxv = rgb[mask].max(axis=0)
        maxv[maxv == 0] = 1
        rgb[mask] /= maxv
    rgb = rgb.reshape(min_h, min_w, 3).clip(0, 1)
    axes[ax_idx].imshow(rgb)
    axes[ax_idx].set_title(f"Embeddings {year} (PCA RGB)")
    axes[ax_idx].axis("off")
# Change map: green = similar (little change), red = dissimilar (likely change)
im = axes[2].imshow(change_map, cmap="RdYlGn", vmin=0, vmax=1, interpolation="nearest")
axes[2].set_title("Cosine Similarity (Change Map)")
axes[2].axis("off")
plt.colorbar(im, ax=axes[2], shrink=0.7, label="Similarity")
plt.tight_layout()
plt.show()
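`geoai.compare_embeddings` returns one similarity per pixel pair. The same per-row cosine similarity can be written in NumPy with `np.einsum`, as in the hypothetical helper `rowwise_cosine` below; this is an illustration of the computation, not a copy of geoai's implementation. Turning the continuous similarity into a binary change mask requires a threshold. The value 0.5 here is an arbitrary assumption that happens to separate the synthetic cases, not a recommended setting for real AlphaEarth embeddings.

```python
import numpy as np


def rowwise_cosine(a, b):
    """Hypothetical helper: cosine similarity between corresponding rows of two (N, D) arrays."""
    num = np.einsum("ij,ij->i", a, b)  # per-row dot products
    den = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    return num / den


# Synthetic "2018" and "2024" pixel embeddings (N pixels x 64 dims)
rng = np.random.default_rng(0)
emb_2018 = rng.normal(size=(500, 64))
emb_2024 = emb_2018 + 0.05 * rng.normal(size=(500, 64))  # small drift: unchanged pixels
emb_2024[:50] = rng.normal(size=(50, 64))  # first 50 pixels replaced: simulated change

sim = rowwise_cosine(emb_2018, emb_2024)

# Illustrative threshold only; a real threshold would need to be chosen for the data
THRESHOLD = 0.5
changed = sim < THRESHOLD
print(f"Flagged {changed.sum()} of {changed.size} pixels as changed")
```

On the real change map, the histogram above is a reasonable place to start when choosing a threshold, for example by looking for a low-similarity tail separate from the main peak.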
Save Change Map as GeoTIFF
Export the change map as a georeferenced GeoTIFF for use in GIS software.
# Save the change map as a single-band GeoTIFF
# Cropping kept the top-left origin, so the 2018 file's transform still applies
with rasterio.open(files[0]) as src:
    transform = src.transform
    file_crs = src.crs
change_output = os.path.join(viz_dir, "change_map_2018_2024.tif")
with rasterio.open(
    change_output,
    "w",
    driver="GTiff",
    height=min_h,
    width=min_w,
    count=1,
    dtype="float64",
    crs=file_crs,
    transform=transform,
    compress="lzw",
    nodata=np.nan,
) as dst:
    dst.write(change_map, 1)
    dst.set_band_description(1, "cosine_similarity")
print(f"Change map saved to {change_output}")
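Interactive Map: Change Detection
Visualize the change detection results on interactive maps. Use a split map to compare the PCA-projected embeddings from 2018 and 2024 side by side, and overlay the change map on satellite imagery. Because PCA was fit separately for each year, the two layers' colors are not directly comparable; compare spatial patterns rather than exact hues.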
# Split map comparing 2018 vs 2024 PCA RGB embeddings
m = leafmap.Map()
m.split_map(left_layer=pca_files[0], right_layer=pca_files[1])
m
# Display change map overlaid on satellite imagery
m = leafmap.Map()
m.add_basemap("Esri.WorldImagery")
m.add_raster(change_output, cmap="RdYlGn", layer_name="Change Map (Cosine Similarity)")
m
Summary
This notebook demonstrated the end-to-end workflow for working with Google/AlphaEarth Satellite Embeddings using geoai and TorchGeo:
- 64-D embedding vectors at 10 m resolution encode multi-source satellite observations
- No GPU required — embeddings are pre-computed and ready for analysis
- Windowed download fetches only the region of interest from Cloud-Optimized GeoTIFFs
- Unsupervised clustering can reveal distinct land cover patterns
- Cosine similarity between years enables change detection
- Interactive maps via leafmap for exploring embeddings, clusters, and change maps
- Data is freely available from Source Cooperative under CC-BY-4.0