Sentinel-2 Super-Resolution with Latent Diffusion¶
This notebook demonstrates how to use the OpenSR latent diffusion model integrated into GeoAI to perform 4x super-resolution on Sentinel-2 imagery, enhancing the spatial resolution from 10 m to 2.5 m.
The method is based on the paper Trustworthy Super-Resolution of Multispectral Sentinel-2 Imagery with Latent Diffusion and operates on four bands (Red, Green, Blue, NIR).
GPU recommended: While the model can run on CPU, a CUDA-enabled GPU will significantly speed up inference.
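You can check which device is available before running inference (a minimal sketch, assuming a PyTorch backend; the model falls back to CPU when no GPU is found):

```python
# Detect the available compute device; fall back to CPU if torch
# is not installed or no CUDA device is present.
try:
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
    device = "cpu"

print(f"Running on: {device}")
```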
Installation¶
Uncomment the following line to install the required packages.
# %pip install -U "geoai-py[sr]"
Import Libraries¶
import geoai
import numpy as np
import rasterio as rio
from matplotlib import pyplot as plt
Download Sample Data¶
We use a Sentinel-2 L2A subset over Knoxville, TN. The image contains four 10m bands (Red, Green, Blue, NIR) stored as uint16 BOA reflectance values (range 0--10,000).
url = "https://data.source.coop/opengeos/geoai/S2C-MSIL2A-20250920T162001-Knoxville.tif"
s2_path = geoai.download_file(url)
Inspect Input Data¶
Let's check the image dimensions, coordinate reference system, and pixel resolution.
with rio.open(s2_path) as src:
    print(f"Bands: {src.count}")
    print(f"Size: {src.width} x {src.height}")
    print(f"CRS: {src.crs}")
    print(f"Resolution: {src.res[0]:.2f} m")
    print(f"Dtype: {src.dtypes[0]}")
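Since the bands store BOA reflectance as scaled uint16 digital numbers, dividing by 10,000 converts them to physical reflectance in roughly [0, 1]. A quick illustration on synthetic values (the model's own normalization is handled internally by geoai):

```python
import numpy as np

# Example digital numbers spanning the typical L2A range (0--10,000)
dn = np.array([0, 2500, 10000], dtype=np.uint16)
reflectance = dn.astype(np.float32) / 10000.0
print(reflectance)  # [0.   0.25 1.  ]
```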
Visualize Input RGB Composite¶
Display a true-color composite (Red/Green/Blue) with a percentile-based contrast stretch.
with rio.open(s2_path) as src:
    rgb = src.read([1, 2, 3]).astype(np.float32)

# Percentile-based contrast stretch
for i in range(3):
    band = rgb[i]
    p2, p98 = np.percentile(band, (2, 98))
    rgb[i] = (band - p2) / (p98 - p2)
rgb = np.clip(rgb, 0, 1)

fig, ax = plt.subplots(figsize=(12, 7))
ax.imshow(rgb.transpose(1, 2, 0))
ax.set_title("Sentinel-2 RGB Composite (10 m)")
ax.set_axis_off()
plt.tight_layout()
plt.show()
Run Super-Resolution on a 128x128 Crop¶
The LDSR-S2 model processes 128x128 pixel patches. We extract a small window from the scene for a fast demonstration. The window parameter specifies (row_offset, col_offset, height, width).
The model:
- Encodes the low-resolution patch into a latent space
- Runs a denoising diffusion process for the specified number of sampling steps
- Decodes the latent back to image space at 4x resolution (128x128 becomes 512x512)
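If you want to preview the exact patch with rasterio before super-resolving it, note that rasterio's `Window` constructor takes `(col_off, row_off, width, height)`, while the `window` argument above uses `(row_offset, col_offset, height, width)`, so the order must be swapped. A small helper (assuming that convention):

```python
def to_rasterio_window(window):
    """Swap a (row_off, col_off, height, width) tuple into rasterio's
    (col_off, row_off, width, height) argument order."""
    row_off, col_off, height, width = window
    return (col_off, row_off, width, height)


# The 128x128 demo patch at row 700, column 1300:
print(to_rasterio_window((700, 1300, 128, 128)))  # (1300, 700, 128, 128)
# Usage: src.read([1, 2, 3, 4], window=rasterio.windows.Window(*args))
```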
sr_output = "sr_output.tif"
sr_image, _ = geoai.super_resolution(
    input_lr_path=s2_path,
    output_sr_path=sr_output,
    rgb_nir_bands=[1, 2, 3, 4],
    window=(700, 1300, 128, 128),
    sampling_steps=100,
)

print("Input shape: (4, 128, 128) at 10 m")
print(f"Output shape: {sr_image.shape} at 2.5 m")
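The 4x factor applies to both the pixel count and the ground sampling distance, so the patch covers the same footprint on the ground before and after:

```python
scale = 4
lr_size, lr_res = 128, 10.0  # input patch size (pixels) and resolution (m)

sr_size = lr_size * scale  # 512 pixels per side
sr_res = lr_res / scale  # 2.5 m per pixel
ground_extent = lr_size * lr_res  # 1280 m, unchanged by super-resolution

print(sr_size, sr_res, ground_extent)  # 512 2.5 1280.0
```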
Compare Low-Resolution vs Super-Resolution¶
Use plot_sr_comparison to display a side-by-side RGB comparison. The low-resolution input is read from the original file (windowed to the same region) and the super-resolution output from the saved GeoTIFF.
geoai.plot_sr_comparison(s2_path, sr_output, bands=[1, 2, 3])
plt.show()
Verify that the output GeoTIFF has the correct spatial reference and 2.5 m pixel size.
with rio.open(sr_output) as src:
    print(f"SR Bands: {src.count}")
    print(f"SR Size: {src.width} x {src.height}")
    print(f"SR CRS: {src.crs}")
    print(f"SR Resolution: {src.res[0]:.2f} m")
Compute Uncertainty Map¶
The model can estimate per-pixel uncertainty by running multiple stochastic forward passes with different random seeds and computing the standard deviation across variations. Higher uncertainty indicates regions where the model is less confident about the super-resolved output.
sr_unc_output = "sr_with_uncertainty.tif"
unc_output = "uncertainty.tif"
sr_image2, uncertainty = geoai.super_resolution(
    input_lr_path=s2_path,
    output_sr_path=sr_unc_output,
    output_uncertainty_path=unc_output,
    rgb_nir_bands=[1, 2, 3, 4],
    window=(700, 1300, 128, 128),
    compute_uncertainty=True,
    n_variations=5,
    sampling_steps=100,
)
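The uncertainty computation can be illustrated generically with NumPy: stack the stochastic outputs and take the per-pixel standard deviation. Here synthetic arrays stand in for the diffusion model's outputs:

```python
import numpy as np

rng = np.random.default_rng(42)
n_variations = 5

# Hypothetical stand-ins for five stochastic super-resolution passes
# over the same 4-band, 512x512 patch.
variations = np.stack(
    [rng.normal(loc=0.25, scale=0.01, size=(4, 512, 512)) for _ in range(n_variations)]
)

sr_mean = variations.mean(axis=0)  # the reported SR image
uncertainty = variations.std(axis=0)  # per-pixel spread across passes

print(sr_mean.shape, uncertainty.shape)  # (4, 512, 512) (4, 512, 512)
```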
Visualize Uncertainty¶
Display the uncertainty map. Red/yellow regions indicate higher uncertainty, green regions indicate higher confidence.
geoai.plot_sr_uncertainty(unc_output)
plt.show()
Process a Larger Region with Patched Inference¶
For areas larger than 128x128 pixels, the module automatically tiles the input into overlapping patches, runs super-resolution on each, and stitches the results back together with linear blending to avoid tiling artifacts.
sr_large = "sr_large.tif"
sr_large_img, _ = geoai.super_resolution(
    input_lr_path=s2_path,
    output_sr_path=sr_large,
    rgb_nir_bands=[1, 2, 3, 4],
    window=(700, 1300, 256, 256),
    patch_size=128,
    overlap=16,
    sampling_steps=100,
)

print("Input shape: (4, 256, 256) at 10 m")
print(f"Output shape: {sr_large_img.shape} at 2.5 m")
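The linear blending can be sketched as a 1-D weight ramp: each patch's contribution tapers linearly across the overlap, and the falling ramp of one patch plus the rising ramp of its neighbor sum to one, so the mosaic has no seams or double-counted intensity. This is a simplified illustration, not geoai's exact implementation:

```python
import numpy as np


def blend_weights(patch_size, overlap):
    """Per-pixel weights for one patch: linear ramps at both edges."""
    w = np.ones(patch_size)
    ramp = np.linspace(0.0, 1.0, overlap + 2)[1:-1]  # strictly inside (0, 1)
    w[:overlap] = ramp  # rising ramp at the leading edge
    w[-overlap:] = ramp[::-1]  # falling ramp at the trailing edge
    return w


w = blend_weights(128, 16)
# With stride = patch_size - overlap, position i of the right patch
# aligns with position (patch_size - overlap + i) of the left patch,
# and the two weights sum to 1 everywhere in the overlap zone.
print(w[0] + w[-16], w[64])
```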
Visualize Patched Super-Resolution Result¶
Compare the low-resolution input with the stitched super-resolution output for the larger region.
geoai.plot_sr_comparison(s2_path, sr_large, bands=[1, 2, 3])
plt.show()
geoai.create_split_map(
    left_layer=sr_large, right_layer="Esri.WorldImagery", left_args={"vmax": 0.3}
)