Skip to content

sr module

Super-resolution utilities using OpenSR latent diffusion models.

This module provides functions to perform super-resolution on multispectral GeoTIFF images using the latent diffusion models from the ESA OpenSR project:

GitHub: https://github.com/ESAOpenSR/opensr-model.git

load_image_tensor(image_path, device, bands, window=None, scale_factor=10000.0)

Load specified bands of a multispectral GeoTIFF as a PyTorch tensor.

The pixel values are divided by scale_factor to normalize them to the [0, 1] range expected by the OpenSR model.

Parameters:

Name Type Description Default
image_path str

Path to input GeoTIFF.

required
device str

Device to move the tensor to ('cpu' or 'cuda').

required
bands list[int]

List of 1-based band indices to read.

required
window tuple

Region of interest as (row_off, col_off, height, width). If None, the full image is read.

None
scale_factor float

Divisor to normalize pixel values to [0, 1]. Default is 10000.0 for Sentinel-2 L2A BOA reflectance.

10000.0

Returns:

Type Description
Tensor

Tensor of shape (1, C, H, W) holding the normalized band values.

dict

Rasterio profile adjusted for the window (if provided).

Raises:

Type Description
FileNotFoundError

If input image does not exist.

ValueError

If any band index is out of range.

Source code in geoai/tools/sr.py (lines 294–367)
def load_image_tensor(
    image_path: str,
    device: str,
    bands: list[int],
    window: Optional[tuple] = None,
    scale_factor: float = 10000.0,
) -> Tuple[torch.Tensor, dict]:
    """Load specified bands of a multispectral GeoTIFF as a PyTorch tensor.

    The pixel values are divided by ``scale_factor`` to normalize them to the
    [0, 1] range expected by the OpenSR model. NaN values are replaced by 0.

    Args:
        image_path (str): Path to input GeoTIFF.
        device (str): Device to move the tensor to ('cpu' or 'cuda').
        bands (list[int]): Non-empty list of 1-based band indices to read.
        window (tuple, optional): Region of interest as
            ``(row_off, col_off, height, width)``. If None, the full image
            is read.
        scale_factor (float): Positive divisor to normalize pixel values to
            [0, 1]. Default is 10000.0 for Sentinel-2 L2A BOA reflectance.

    Returns:
        Tuple[torch.Tensor, dict]: Float32 tensor of shape (1, C, H, W) and a
        copy of the rasterio profile whose ``transform``, ``height`` and
        ``width`` are adjusted for the window (if provided).

    Raises:
        FileNotFoundError: If input image does not exist.
        ValueError: If ``bands`` is empty, any band index is out of range,
            ``scale_factor`` is not positive, or ``window`` is malformed or
            extends beyond the image.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Input image does not exist: {image_path}")

    # Validate cheap arguments before opening the dataset. An empty list would
    # otherwise crash inside min() with an unhelpful message, and a
    # non-positive divisor would silently produce inf/NaN pixel values.
    if len(bands) == 0:
        raise ValueError("bands must contain at least one 1-based band index.")
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive. Received: {scale_factor}")

    with rasterio.open(image_path) as src:
        n_bands = src.count
        if min(bands) < 1 or max(bands) > n_bands:
            raise ValueError(
                f"Input image has {n_bands} bands, requested bands {bands} out of range."
            )

        rio_window = None
        if window is not None:
            if len(window) != 4:
                raise ValueError(
                    f"window must be a 4-tuple (row_off, col_off, height, width). "
                    f"Received {len(window)} elements."
                )
            row_off, col_off, win_h, win_w = window
            if row_off < 0 or col_off < 0 or win_h <= 0 or win_w <= 0:
                raise ValueError(
                    f"Window offsets must be >= 0 and dimensions must be > 0. "
                    f"Received row_off={row_off}, col_off={col_off}, "
                    f"height={win_h}, width={win_w}."
                )
            if row_off + win_h > src.height or col_off + win_w > src.width:
                raise ValueError(
                    f"Window (row_off={row_off}, col_off={col_off}, "
                    f"height={win_h}, width={win_w}) exceeds image "
                    f"dimensions ({src.height} x {src.width})."
                )
            # rasterio's Window takes (col_off, row_off, width, height).
            rio_window = Window(col_off, row_off, win_w, win_h)

        image = src.read(bands, window=rio_window)  # shape: (C, H, W)
        profile = src.profile.copy()

        if rio_window is not None:
            # Re-anchor the georeference to the window's top-left pixel.
            profile["transform"] = src.window_transform(rio_window)
            profile["height"] = image.shape[1]
            profile["width"] = image.shape[2]

    image = image.astype(np.float32) / scale_factor
    image = np.nan_to_num(image, nan=0.0)
    # Add the batch dimension: (C, H, W) -> (1, C, H, W).
    tensor = torch.from_numpy(image).unsqueeze(0).to(device)
    return tensor, profile

plot_sr_comparison(lr_path, sr_path, bands=[1, 2, 3], lr_vmax=None, sr_vmax=None, figsize=(14, 7), **kwargs)

Plot a side-by-side comparison of low-resolution and super-resolution images.

Displays RGB composites of the LR input and SR output.

Parameters:

Name Type Description Default
lr_path str

Path to the low-resolution GeoTIFF.

required
sr_path str

Path to the super-resolution GeoTIFF.

required
bands list[int]

Three 1-based band indices for the RGB composite. Default is [1, 2, 3].

[1, 2, 3]
lr_vmax float

Maximum value for LR image contrast stretch. If None, the 98th percentile is used.

None
sr_vmax float

Maximum value for SR image contrast stretch. If None, the 98th percentile is used.

None
figsize tuple

Figure size. Default is (14, 7).

(14, 7)
**kwargs

Additional keyword arguments passed to matplotlib.pyplot.subplots.

{}

Returns:

Type Description
Figure

The matplotlib figure object.

Source code in geoai/tools/sr.py (lines 563–636)
def plot_sr_comparison(
    lr_path: str,
    sr_path: str,
    bands: list[int] = [1, 2, 3],
    lr_vmax: Optional[float] = None,
    sr_vmax: Optional[float] = None,
    figsize: tuple = (14, 7),
    **kwargs,
):
    """Plot a side-by-side comparison of low-resolution and super-resolution
    images.

    Both panels are RGB composites built from ``bands``. The LR image is
    cropped to the SR image's bounds so both panels show the same area. Each
    band is linearly stretched from its 2nd percentile up to either its 98th
    percentile or the supplied ``*_vmax``. The percentiles are computed over
    pixels > 0 only.

    Args:
        lr_path (str): Path to the low-resolution GeoTIFF.
        sr_path (str): Path to the super-resolution GeoTIFF.
        bands (list[int]): Three 1-based band indices for the RGB composite.
            Default is ``[1, 2, 3]``.
        lr_vmax (float, optional): Maximum value for LR image contrast
            stretch. If None, the 98th percentile is used.
        sr_vmax (float, optional): Maximum value for SR image contrast
            stretch. If None, the 98th percentile is used.
        figsize (tuple): Figure size. Default is ``(14, 7)``.
        **kwargs: Additional keyword arguments passed to
            ``matplotlib.pyplot.subplots``.

    Returns:
        matplotlib.figure.Figure: The matplotlib figure object.
    """
    import matplotlib.pyplot as plt

    def _to_display_rgb(stack, upper=None):
        # Stretch each band of a (C, H, W) stack to [0, 1] and return it
        # channel-last, (H, W, C), as imshow expects.
        stretched = np.zeros_like(stack)
        for idx, channel in enumerate(stack):
            valid = channel[channel > 0]
            has_valid = valid.size > 0
            low = np.percentile(valid, 2) if has_valid else 0
            if upper is not None:
                high = upper
            elif has_valid:
                high = np.percentile(valid, 98)
            else:
                high = 1
            # The epsilon guards against a zero-width stretch range.
            stretched[idx] = (channel - low) / (high - low + 1e-10)
        return np.clip(stretched, 0, 1).transpose(1, 2, 0)

    with rasterio.open(sr_path) as sr_ds:
        sr_stack = sr_ds.read(bands).astype(np.float32)
        sr_pixel = abs(sr_ds.transform.a)
        footprint = sr_ds.bounds

    with rasterio.open(lr_path) as lr_ds:
        lr_pixel = abs(lr_ds.transform.a)
        # Crop the LR read to the SR footprint, snapped to whole pixels and
        # clamped to the LR raster extent.
        crop = (
            lr_ds.window(*footprint)
            .round_offsets()
            .round_lengths()
            .intersection(Window(0, 0, lr_ds.width, lr_ds.height))
        )
        lr_stack = lr_ds.read(bands, window=crop).astype(np.float32)

    lr_rgb = _to_display_rgb(lr_stack, lr_vmax)
    sr_rgb = _to_display_rgb(sr_stack, sr_vmax)

    fig, axes = plt.subplots(1, 2, figsize=figsize, **kwargs)

    panels = (
        (lr_rgb, f"Low Resolution ({lr_pixel:.1f} m)"),
        (sr_rgb, f"Super Resolution ({sr_pixel:.2f} m)"),
    )
    for ax, (rgb, title) in zip(axes, panels):
        ax.imshow(rgb)
        ax.set_title(title)
        ax.set_xlabel("Column")
        ax.set_ylabel("Row")

    plt.tight_layout()
    return fig

plot_sr_uncertainty(uncertainty_path, cmap='RdYlGn_r', normalize=True, figsize=(8, 8), **kwargs)

Plot the uncertainty map from super-resolution inference.

Parameters:

Name Type Description Default
uncertainty_path str

Path to the uncertainty GeoTIFF.

required
cmap str

Matplotlib colormap name. Default is 'RdYlGn_r'.

'RdYlGn_r'
normalize bool

Whether to normalize values to [0, 1]. Default is True.

True
figsize tuple

Figure size. Default is (8, 8).

(8, 8)
**kwargs

Additional keyword arguments passed to matplotlib.pyplot.subplots.

{}

Returns:

Type Description
Figure

The matplotlib figure object.

Source code in geoai/tools/sr.py (lines 639–679)
def plot_sr_uncertainty(
    uncertainty_path: str,
    cmap: str = "RdYlGn_r",
    normalize: bool = True,
    figsize: tuple = (8, 8),
    **kwargs,
):
    """Plot the uncertainty map from super-resolution inference.

    Only the first band of the GeoTIFF is plotted.

    Args:
        uncertainty_path (str): Path to the uncertainty GeoTIFF.
        cmap (str): Matplotlib colormap name. Default is ``'RdYlGn_r'``.
        normalize (bool): Whether to min-max normalize values to [0, 1].
            The range is taken from finite pixels only, so NaN pixels do not
            disable normalization. Normalization is skipped when the finite
            values are constant or absent. Default is True.
        figsize (tuple): Figure size. Default is ``(8, 8)``.
        **kwargs: Additional keyword arguments passed to
            ``matplotlib.pyplot.subplots``.

    Returns:
        matplotlib.figure.Figure: The matplotlib figure object.
    """
    import matplotlib.pyplot as plt

    with rasterio.open(uncertainty_path) as src:
        unc = src.read(1).astype(np.float32)
        res = abs(src.transform.a)

    label = "Uncertainty"
    if normalize:
        # A single NaN makes unc.min()/unc.max() NaN, and comparisons with NaN
        # are always False, which would silently skip normalization. Derive
        # the range from finite pixels only; NaN pixels stay NaN.
        finite = unc[np.isfinite(unc)]
        if finite.size:
            lo, hi = finite.min(), finite.max()
            if hi > lo:
                unc = (unc - lo) / (hi - lo)
                label = "Uncertainty (Normalized)"

    fig, ax = plt.subplots(1, 1, figsize=figsize, **kwargs)
    im = ax.imshow(unc, cmap=cmap)
    ax.set_title(f"{label} ({res:.2f} m)")
    ax.set_xlabel("Column")
    ax.set_ylabel("Row")
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    # Lay out the figure created above rather than pyplot's "current" figure.
    fig.tight_layout()
    return fig

save_geotiff(data, reference_profile, output_path, scale=4)

Save a 2D or 3D NumPy array as a GeoTIFF with super-resolution scaling and corrected georeference.

Parameters:

Name Type Description Default
data ndarray

Image array to save: a 2D array (H, W) for a single-band image, or a 3D array (C, H, W) for multi-band images (e.g., RGB+NIR).

required
reference_profile dict

Rasterio metadata from a reference GeoTIFF. Used to preserve CRS, transform, and other metadata.

required
output_path str

Path to save the output GeoTIFF.

required
scale int

Super-resolution scale factor. Default is 4. The pixel-size terms of the reference affine transform are divided by this factor, and the origin is kept, so the output stays aligned with the reference image.

4
Source code in geoai/tools/sr.py (lines 248–291)
def save_geotiff(
    data: np.ndarray, reference_profile: dict, output_path: str, scale: int = 4
):
    """Save a 2D or 3D NumPy array as a GeoTIFF with super-resolution scaling
    and corrected georeference.

    The output is written as LZW-compressed float32. Its transform is the
    reference transform with every linear (pixel-size, rotation and shear)
    term divided by ``scale`` and the origin unchanged. Each output pixel
    therefore covers ``1 / scale`` of a reference pixel along each axis.

    Args:
        data (np.ndarray): Image array to save. Can be:
            - 2D array (H, W) for a single-band image
            - 3D array (C, H, W) for multi-band images (e.g., RGB+NIR)
        reference_profile (dict): Rasterio metadata from a reference GeoTIFF.
            Used to preserve CRS, transform, and other metadata.
        output_path (str): Path to save the output GeoTIFF. Missing parent
            directories are created.
        scale (int): Positive super-resolution scale factor. Default is 4.
            This adjusts the affine transform to ensure georeference matches
            the original image.

    Raises:
        ValueError: If ``data`` is not 2D or 3D, or ``scale`` is not positive.
    """
    # Validate before touching the filesystem so bad input leaves no side effects.
    if data.ndim not in (2, 3):
        raise ValueError(
            f"data must be a 2D (H, W) or 3D (C, H, W) array. Received ndim={data.ndim}."
        )
    if scale <= 0:
        raise ValueError(f"scale must be positive. Received: {scale}")

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    if data.ndim == 2:
        data = data[np.newaxis, ...]

    # Update profile and transform
    profile = reference_profile.copy()
    old_transform = profile["transform"]
    # Refining the pixel grid by `scale` divides all four linear terms:
    # x = a*col + b*row + c, y = d*col + e*row + f. Scaling only a and e would
    # mis-georeference rotated/sheared rasters (non-zero b or d). This equals
    # old_transform * Affine.scale(1 / scale), but divides exactly.
    new_transform = Affine(
        old_transform.a / scale,
        old_transform.b / scale,
        old_transform.c,
        old_transform.d / scale,
        old_transform.e / scale,
        old_transform.f,
    )
    profile.update(
        dtype=rasterio.float32,
        count=data.shape[0],
        height=data.shape[1],
        width=data.shape[2],
        compress="lzw",
        transform=new_transform,
    )

    with rasterio.open(output_path, "w", **profile) as dst:
        dst.write(data.astype(np.float32))

super_resolution(input_lr_path, output_sr_path, output_uncertainty_path=None, rgb_nir_bands=[1, 2, 3, 4], sampling_steps=100, n_variations=25, scale=4, compute_uncertainty=False, window=None, scale_factor=10000.0, patch_size=128, overlap=16)

Perform super-resolution on RGB+NIR bands of a multispectral GeoTIFF using OpenSR latent diffusion.

The model enhances Sentinel-2 imagery from 10m to 2.5m spatial resolution (4x upsampling) using the LDSR-S2 latent diffusion model from the ESA OpenSR project. For images larger than patch_size, the input is automatically tiled into overlapping patches, each patch is super-resolved, and the results are stitched back together with linear blending.

Parameters:

Name Type Description Default
input_lr_path str

Path to the input low-resolution GeoTIFF.

required
output_sr_path str

Path to save the super-resolution GeoTIFF.

required
output_uncertainty_path str

Path to save the uncertainty map GeoTIFF. Required when compute_uncertainty is True; providing it also enables compute_uncertainty automatically.

None
rgb_nir_bands list[int]

List of 4 one-based band indices corresponding to [R, G, B, NIR] in the input file. Default is [1, 2, 3, 4].

[1, 2, 3, 4]
sampling_steps int

Number of diffusion sampling steps. Higher values produce better results but are slower. Default is 100.

100
n_variations int

Number of stochastic forward passes used to estimate uncertainty. Must be greater than 3 when uncertainty is computed. Default is 25.

25
scale int

Super-resolution scale factor. Default is 4.

4
compute_uncertainty bool

Whether to compute an uncertainty map via multiple stochastic forward passes. Default is False.

False
window tuple

Region of interest as (row_off, col_off, height, width) to read a subset of the input image. If None, the entire image is processed.

None
scale_factor float

Divisor to normalize pixel values to the [0, 1] range. For Sentinel-2 L2A BOA reflectance, use 10000.0 (the default). Set to 1.0 if the data is already normalized.

10000.0
patch_size int

Tile size for patched inference. The model expects 128x128 input patches. Default is 128.

128
overlap int

Number of overlapping pixels between adjacent patches to reduce tiling artifacts. Default is 16.

16

Returns:

Type Description
Tuple[ndarray, Optional[ndarray]]

Tuple containing the super-resolution image as a NumPy array (4, H, W) and the uncertainty map as a NumPy array (H, W), or None if compute_uncertainty is False.

Raises:

Type Description
ValueError

If rgb_nir_bands does not contain exactly 4 integers, if compute_uncertainty is True but no output_uncertainty_path is given, or if n_variations, scale_factor, patch_size or overlap is invalid.

ImportError

If the opensr-model package is not installed.

Example

>>> import geoai
>>> sr_image, uncertainty = geoai.super_resolution(
...     input_lr_path="sentinel2.tif",
...     output_sr_path="sr_output.tif",
...     rgb_nir_bands=[1, 2, 3, 4],
...     window=(500, 500, 128, 128),
...     sampling_steps=50,
... )

Source code in geoai/tools/sr.py (lines 67–245)
def super_resolution(
    input_lr_path: str,
    output_sr_path: str,
    output_uncertainty_path: Optional[str] = None,
    rgb_nir_bands: list[int] = [1, 2, 3, 4],
    sampling_steps: int = 100,
    n_variations: int = 25,
    scale: int = 4,
    compute_uncertainty: bool = False,
    window: Optional[tuple] = None,
    scale_factor: float = 10000.0,
    patch_size: int = 128,
    overlap: int = 16,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Perform super-resolution on RGB+NIR bands of a multispectral GeoTIFF
    using OpenSR latent diffusion.

    The model enhances Sentinel-2 imagery from 10m to 2.5m spatial resolution
    (4x upsampling) using the LDSR-S2 latent diffusion model from the ESA
    OpenSR project. For images larger than ``patch_size``, the input is
    automatically tiled into overlapping patches, each patch is super-resolved,
    and the results are stitched back together with linear blending.

    Arguments are validated and the input image is read before the model
    configuration and checkpoint are fetched. Invalid inputs therefore fail
    before any download.

    Args:
        input_lr_path (str): Path to the input low-resolution GeoTIFF.
        output_sr_path (str): Path to save the super-resolution GeoTIFF.
        output_uncertainty_path (str, optional): Path to save the uncertainty
            map GeoTIFF. Required when ``compute_uncertainty`` is True;
            providing it also enables ``compute_uncertainty``.
        rgb_nir_bands (list[int]): List of 4 one-based band indices
            corresponding to [R, G, B, NIR] in the input file. Python and
            NumPy integers are accepted (bools are not). Default is
            ``[1, 2, 3, 4]``.
        sampling_steps (int): Number of diffusion sampling steps (>= 1).
            Higher values produce better results but are slower. Default is
            100.
        n_variations (int): Number of stochastic forward passes used to
            estimate uncertainty. Must be greater than 3 when uncertainty is
            computed. Default is 25.
        scale (int): Super-resolution scale factor (>= 1), used for tiling
            and for the output georeference. Default is 4.
        compute_uncertainty (bool): Whether to compute an uncertainty map
            via multiple stochastic forward passes. Default is False.
        window (tuple, optional): Region of interest as
            ``(row_off, col_off, height, width)`` to read a subset of the
            input image. If None, the entire image is processed.
        scale_factor (float): Divisor to normalize pixel values to the
            [0, 1] range. For Sentinel-2 L2A BOA reflectance, use 10000.0
            (the default). Set to 1.0 if the data is already normalized.
        patch_size (int): Tile size for patched inference. The model expects
            128x128 input patches. Default is 128.
        overlap (int): Number of overlapping pixels between adjacent patches
            to reduce tiling artifacts. Default is 16.

    Returns:
        Tuple[np.ndarray, Optional[np.ndarray]]: Tuple containing:
            - sr_image: Super-resolution image as a NumPy array (4, H, W).
            - uncertainty: Uncertainty map as a NumPy array (H, W), or None
              if ``compute_uncertainty`` is False.

    Raises:
        ValueError: If ``rgb_nir_bands`` does not contain exactly 4 integers,
            if ``compute_uncertainty`` is True but no output path is given,
            or if ``n_variations``, ``sampling_steps``, ``scale``,
            ``scale_factor``, ``patch_size`` or ``overlap`` is invalid. Also
            raised by the image loader for out-of-range bands or bad windows.
        FileNotFoundError: If ``input_lr_path`` does not exist.
        ImportError: If the ``opensr-model`` package is not installed.
        requests.RequestException: If the model configuration cannot be
            downloaded.

    Example:
        >>> import geoai
        >>> sr_image, uncertainty = geoai.super_resolution(
        ...     input_lr_path="sentinel2.tif",
        ...     output_sr_path="sr_output.tif",
        ...     rgb_nir_bands=[1, 2, 3, 4],
        ...     window=(500, 500, 128, 128),
        ...     sampling_steps=50,
        ... )
    """
    if len(rgb_nir_bands) != 4:
        raise ValueError("rgb_nir_bands must be a list of 4 integers: [R, G, B, NIR]")
    # Accept NumPy integers (e.g. from np.arange). Exclude bool, which is an
    # int subclass but never a meaningful band index.
    if not all(
        isinstance(b, (int, np.integer)) and not isinstance(b, bool)
        for b in rgb_nir_bands
    ):
        raise ValueError(
            "All elements of rgb_nir_bands must be integers. Received: {}".format(
                rgb_nir_bands
            )
        )
    # Fresh list of plain ints: the shared default list is never handed out.
    bands = [int(b) for b in rgb_nir_bands]

    if output_uncertainty_path is not None and not compute_uncertainty:
        compute_uncertainty = True
    if compute_uncertainty and output_uncertainty_path is None:
        raise ValueError(
            "output_uncertainty_path must be provided when compute_uncertainty is True."
        )
    if compute_uncertainty and n_variations <= 3:
        raise ValueError(
            "n_variations must be greater than 3 to compute uncertainty. "
            f"Received: {n_variations}"
        )
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive. Received: {scale_factor}")
    if patch_size <= 0 or overlap < 0 or overlap >= patch_size:
        raise ValueError(
            f"Requires patch_size > 0 and 0 <= overlap < patch_size. "
            f"Received patch_size={patch_size}, overlap={overlap}."
        )
    if sampling_steps < 1:
        raise ValueError(
            f"sampling_steps must be a positive integer. Received: {sampling_steps}"
        )
    if scale < 1:
        raise ValueError(f"scale must be >= 1. Received: {scale}")
    if not OPENSR_MODEL_AVAILABLE:
        raise ImportError(
            "The 'opensr-model' package is required for super-resolution. "
            "Please install it using: pip install opensr-model\n"
            "Or install GeoAI with the sr optional dependency: pip install geoai-py[sr]"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Read the requested bands first. A missing file, an out-of-range band or
    # a bad window is then reported before the slow config/checkpoint downloads.
    lr_tensor, profile = load_image_tensor(
        image_path=input_lr_path,
        device=device,
        bands=bands,
        window=window,
        scale_factor=scale_factor,
    )

    # Download configuration YAML from GitHub.
    # NOTE(review): the URL tracks the ``main`` branch, so upstream changes can
    # alter behavior between runs; consider pinning a tag or commit.
    config_url = "https://raw.githubusercontent.com/ESAOpenSR/opensr-model/refs/heads/main/opensr_model/configs/config_10m.yaml"
    print("Downloading model configuration from:", config_url)
    try:
        # Without a timeout, requests may block indefinitely on a stalled server.
        response = requests.get(config_url, timeout=60)
        response.raise_for_status()
    except requests.RequestException as e:
        print(f"Error downloading model configuration: {e}")
        raise
    config = OmegaConf.load(StringIO(response.text))

    # Initialize latent diffusion model and load pretrained weights.
    # Download checkpoint to the torch hub cache directory instead of cwd.
    model = opensr_model.SRLatentDiffusion(config, device=device)
    ckpt_name = os.path.basename(config.ckpt_version)
    # The config comes from the network; accept only a plain filename.
    if not ckpt_name or ckpt_name != config.ckpt_version:
        raise ValueError(
            f"Invalid checkpoint name in config: {config.ckpt_version!r}. "
            "Expected a plain filename without path separators."
        )
    ckpt_path = _get_cached_checkpoint(ckpt_name)
    model.load_pretrained(ckpt_path)

    # Tile the input whenever it exceeds a single model patch.
    _, _, h, w = lr_tensor.shape
    patched = h > patch_size or w > patch_size
    if patched:
        sr_image = _process_patched(
            model=model,
            lr_tensor=lr_tensor,
            patch_size=patch_size,
            overlap=overlap,
            scale=scale,
            sampling_steps=sampling_steps,
        )
    else:
        sr_tensor = model.forward(lr_tensor, sampling_steps=sampling_steps)
        sr_image = sr_tensor.squeeze(0).cpu().numpy().astype(np.float32)

    save_geotiff(sr_image, profile, output_sr_path, scale)
    print("Saved super-resolution image to:", output_sr_path)

    # Compute uncertainty map if requested
    uncertainty = None
    if compute_uncertainty:
        if patched:
            uncertainty = _process_patched_uncertainty(
                model=model,
                lr_tensor=lr_tensor,
                patch_size=patch_size,
                overlap=overlap,
                scale=scale,
                n_variations=n_variations,
                sampling_steps=sampling_steps,
            )
        else:
            unc_tensor = model.uncertainty_map(
                lr_tensor,
                n_variations=n_variations,
                sampling_steps=sampling_steps,
            )
            uncertainty = unc_tensor.squeeze().cpu().numpy().astype(np.float32)
        save_geotiff(uncertainty, profile, output_uncertainty_path, scale)
        print("Saved uncertainty map to:", output_uncertainty_path)

    return sr_image, uncertainty