Skip to content

inference module

Memory-efficient tiled inference with blending and test-time augmentation.

Provides a generic sliding-window inference pipeline that works with any PyTorch segmentation or regression model on GeoTIFF rasters. Key features:

  • Windowed I/O -- reads tiles directly via rasterio windows, avoiding full-image input memory allocation.
  • Multiple blending strategies -- linear ramp, raised cosine, and spline windows for seamless tile stitching.
  • D4 test-time augmentation -- optional 8-fold augmentation using the dihedral group (identity, 3 rotations, horizontal flip, vertical flip, 2 diagonal flips).
References
  • Spline blending: https://github.com/Vooban/Smoothly-Blend-Image-Patches
  • GitHub issue: https://github.com/opengeos/geoai/issues/87

BlendMode

Bases: str, Enum

Blending strategy for overlapping tile predictions.

Attributes:

Name Type Description
NONE

Uniform averaging -- all pixels are weighted equally (1.0), so overlapping tiles are simply averaged without tapering.

LINEAR

Linear ramp from 0 at edges to 1 at center.

COSINE

Raised-cosine (Hann) taper in the overlap region.

SPLINE

Powered raised-cosine taper for smooth transitions in the overlap zones. Requires overlap <= tile_size // 2.

Source code in geoai/inference.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class BlendMode(str, Enum):
    """Supported strategies for weighting overlapping tile predictions.

    Attributes:
        NONE: Constant weight of 1.0 everywhere; overlapping tiles are
            averaged with no edge tapering.
        LINEAR: Weight ramps linearly from 0 at the tile edges up to 1
            at the tile center.
        COSINE: Raised-cosine (Hann) taper across the overlap region.
        SPLINE: Raised-cosine taper raised to a power, giving smoother
            transitions in the overlap zones.  Only valid when
            ``overlap <= tile_size // 2``.
    """

    NONE = "none"
    LINEAR = "linear"
    COSINE = "cosine"
    SPLINE = "spline"

create_weight_mask(tile_size, overlap, mode=BlendMode.SPLINE, power=2)

Create a 2D weight mask for blending overlapping tiles.

Parameters:

Name Type Description Default
tile_size int

Size of each square tile in pixels.

required
overlap int

Number of pixels of overlap between adjacent tiles.

required
mode Union[str, BlendMode]

Blending strategy. One of "none", "linear", "cosine", or "spline".

SPLINE
power int

Exponent for spline mode. Higher values concentrate weight toward the center. Ignored for other modes.

2

Returns:

Type Description
ndarray

numpy.ndarray: Float32 array of shape (tile_size, tile_size) with values in [0, 1].

Raises:

Type Description
ValueError

If mode is not a recognized blending strategy, or if overlap is negative or >= tile_size.

Example

>>> from geoai.inference import create_weight_mask
>>> mask = create_weight_mask(256, 64, mode="spline")
>>> mask.shape
(256, 256)

Source code in geoai/inference.py
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def create_weight_mask(
    tile_size: int,
    overlap: int,
    mode: Union[str, BlendMode] = BlendMode.SPLINE,
    power: int = 2,
) -> np.ndarray:
    """Build a 2D weight mask used to blend overlapping tile predictions.

    Args:
        tile_size: Edge length of the square tile in pixels.
        overlap: Number of pixels shared between neighbouring tiles.
        mode: Blending strategy; one of ``"none"``, ``"linear"``,
            ``"cosine"``, or ``"spline"`` (see :class:`BlendMode`).
        power: Spline exponent; larger values concentrate weight toward
            the tile center.  Unused by the other modes.

    Returns:
        numpy.ndarray: Float32 mask of shape ``(tile_size, tile_size)``
            with values in ``[0, 1]``.

    Raises:
        ValueError: If *mode* is not a recognized blending strategy, or
            if *overlap* is negative or >= *tile_size*.

    Example:
        >>> from geoai.inference import create_weight_mask
        >>> mask = create_weight_mask(256, 64, mode="spline")
        >>> mask.shape
        (256, 256)
    """
    if overlap < 0 or overlap >= tile_size:
        raise ValueError(
            f"overlap must be >= 0 and < tile_size ({tile_size}), got {overlap}"
        )

    # Raises ValueError for unrecognized mode strings.
    mode = BlendMode(mode)

    # No tapering needed: uniform weights mean plain averaging.
    if mode == BlendMode.NONE or overlap == 0:
        return np.ones((tile_size, tile_size), dtype=np.float32)

    # NOTE(review): tapered masks are 0 on the outermost tile rows and
    # columns, so pixels on the full-image border can accumulate zero
    # total weight downstream — confirm this is intended.
    if mode in (BlendMode.LINEAR, BlendMode.COSINE):
        # Per-pixel distance to the nearest tile edge, capped at `overlap`
        # and normalized to [0, 1].  Using np.minimum (rather than writing
        # ramps into slices) stays correct when overlap > tile_size // 2.
        idx = np.arange(tile_size, dtype=np.float32)
        edge_dist = np.minimum(np.minimum(idx, idx[::-1]), overlap) / overlap
        if mode == BlendMode.LINEAR:
            return np.outer(edge_dist, edge_dist)
        # Raised-cosine (Hann) taper over the normalized edge distance.
        taper = (0.5 * (1.0 - np.cos(np.pi * edge_dist))).astype(np.float32)
        return np.outer(taper, taper)

    if mode == BlendMode.SPLINE:
        if overlap > tile_size // 2:
            raise ValueError(
                f"For spline blending, overlap must be <= tile_size // 2 "
                f"({tile_size // 2}), got {overlap}. Use 'linear' or "
                f"'cosine' blending for larger overlaps."
            )
        profile_1d = _spline_window_1d(tile_size, overlap, power=power).astype(
            np.float32
        )
        return np.outer(profile_1d, profile_1d)

    # Defensive: unreachable in practice since BlendMode(mode) validates.
    raise ValueError(f"Unknown blend mode: {mode!r}")

d4_forward(tensor)

Apply all 8 D4 dihedral group transforms to a batch of images.

The D4 group consists of the identity, three 90-degree rotations, horizontal flip, vertical flip, and two diagonal flips.

Parameters:

Name Type Description Default
tensor 'torch.Tensor'

Input tensor of shape (B, C, H, W).

required

Returns:

Name Type Description
list List['torch.Tensor']

List of 8 tensors, each of shape (B, C, H, W).

Source code in geoai/inference.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
def d4_forward(tensor: "torch.Tensor") -> List["torch.Tensor"]:
    """Produce the 8 D4 dihedral views of a batch of images.

    The views are, in order: identity, rotations by 90/180/270 degrees,
    horizontal flip, vertical flip, and the two diagonal flips (a
    90-degree rotation followed by each axis flip).

    Args:
        tensor: Input tensor of shape ``(B, C, H, W)``.

    Returns:
        list: Eight tensors, each of shape ``(B, C, H, W)``.
    """
    import torch  # noqa: F811

    spatial = [-2, -1]
    quarter_turn = torch.rot90(tensor, k=1, dims=spatial)

    views = [tensor, quarter_turn]
    views.append(torch.rot90(tensor, k=2, dims=spatial))
    views.append(torch.rot90(tensor, k=3, dims=spatial))
    views.append(torch.flip(tensor, dims=[-1]))  # horizontal flip
    views.append(torch.flip(tensor, dims=[-2]))  # vertical flip
    # Diagonal flips: quarter turn followed by each axis flip.
    views.append(torch.flip(quarter_turn, dims=[-1]))
    views.append(torch.flip(quarter_turn, dims=[-2]))
    return views

d4_inverse(tensors)

Apply the inverse D4 transforms to undo :func:d4_forward.

Parameters:

Name Type Description Default
tensors List['torch.Tensor']

List of 8 tensors from :func:d4_forward, each of shape (B, C, H, W).

required

Returns:

Name Type Description
list List['torch.Tensor']

List of 8 tensors, each transformed back to the original orientation.

Source code in geoai/inference.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
def d4_inverse(tensors: List["torch.Tensor"]) -> List["torch.Tensor"]:
    """Undo the transforms produced by :func:`d4_forward`.

    Each entry is mapped back by applying its inverse transform:
    rotations are undone with the complementary rotation, flips undo
    themselves, and the diagonal entries undo the flip first and then
    the quarter turn.

    Args:
        tensors: The 8 tensors returned by :func:`d4_forward`, each of
            shape ``(B, C, H, W)``.

    Returns:
        list: Eight tensors, each restored to the original orientation.
    """
    import torch  # noqa: F811

    spatial = [-2, -1]
    identity, r90, r180, r270, hflip, vflip, diag_h, diag_v = tensors

    return [
        identity,
        torch.rot90(r90, k=3, dims=spatial),
        torch.rot90(r180, k=2, dims=spatial),
        torch.rot90(r270, k=1, dims=spatial),
        torch.flip(hflip, dims=[-1]),
        torch.flip(vflip, dims=[-2]),
        # Diagonal entries: undo the flip, then undo the quarter turn.
        torch.rot90(torch.flip(diag_h, dims=[-1]), k=3, dims=spatial),
        torch.rot90(torch.flip(diag_v, dims=[-2]), k=3, dims=spatial),
    ]

d4_tta_forward(model, batch)

Run inference with D4 test-time augmentation and average results.

Applies all 8 D4 transforms, runs the model on each, inverts the transforms, and averages the predictions. This can improve prediction quality at the cost of 8x compute.

Parameters:

Name Type Description Default
model 'torch.nn.Module'

PyTorch model that accepts (B, C, H, W) input and returns (B, num_classes, H, W) output.

required
batch 'torch.Tensor'

Input tensor of shape (B, C, H, W).

required

Returns:

Type Description
'torch.Tensor'

torch.Tensor: Averaged prediction tensor of shape (B, num_classes, H, W).

Source code in geoai/inference.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
def d4_tta_forward(
    model: "torch.nn.Module",
    batch: "torch.Tensor",
) -> "torch.Tensor":
    """Average model predictions over all 8 D4 test-time augmentations.

    Each D4 view of the batch is pushed through *model*, the resulting
    predictions are mapped back to the original orientation, and the 8
    aligned predictions are averaged.  This can improve prediction
    quality at the cost of 8x compute.

    Args:
        model: PyTorch model that accepts ``(B, C, H, W)`` input and
            returns ``(B, num_classes, H, W)`` output.
        batch: Input tensor of shape ``(B, C, H, W)``.

    Returns:
        torch.Tensor: Averaged prediction tensor of shape
            ``(B, num_classes, H, W)``.
    """
    import torch  # noqa: F811

    views = d4_forward(batch)
    aligned = d4_inverse([model(view) for view in views])
    return torch.stack(aligned).mean(dim=0)

predict_geotiff(model, input_raster, output_raster, tile_size=256, overlap=64, batch_size=4, input_bands=None, num_classes=1, output_dtype='float32', output_nodata=-9999.0, blend_mode='spline', blend_power=2, tta=False, preprocess_fn=None, postprocess_fn=None, device=None, compress='lzw', verbose=True)

Run tiled inference on a GeoTIFF with blending and optional TTA.

Reads tiles from input_raster using rasterio windowed I/O, runs each batch through model, blends overlapping predictions with the chosen weight mask, and writes results to output_raster. Memory usage is proportional to batch_size * tile_size**2 for input reads rather than the full image.

Parameters:

Name Type Description Default
model 'torch.nn.Module'

PyTorch model accepting (B, C, H, W) float tensors and returning (B, num_classes, H, W) predictions.

required
input_raster str

Path to the input GeoTIFF file.

required
output_raster str

Path to save the output GeoTIFF.

required
tile_size int

Size of square tiles in pixels.

256
overlap int

Overlap between adjacent tiles in pixels. Using overlap with blending weights eliminates tile-boundary artefacts. Higher values give smoother results at the cost of more computation.

64
batch_size int

Number of tiles per forward pass.

4
input_bands Optional[List[int]]

1-based band indices to read. If None, reads all bands.

None
num_classes int

Number of output channels/classes from the model. Use 1 for regression or binary segmentation.

1
output_dtype str

NumPy dtype string for the output raster (e.g., "float32", "uint8").

'float32'
output_nodata float

NoData value for the output raster.

-9999.0
blend_mode Union[str, BlendMode]

Blending strategy: "none", "linear", "cosine", or "spline".

'spline'
blend_power int

Exponent for spline blending (ignored for other modes).

2
tta bool

If True, apply D4 test-time augmentation. Increases compute by 8x but can improve prediction quality.

False
preprocess_fn Optional[Callable[..., ndarray]]

Optional callable (np.ndarray) -> np.ndarray applied to each tile after reading (e.g., normalization). Input shape is (C, H, W). If None, tiles are cast to float32 and divided by 255 when values exceed 1.5.

None
postprocess_fn Optional[Callable[..., ndarray]]

Optional callable (np.ndarray) -> np.ndarray applied to the final blended output array of shape (num_classes, H, W) before writing (e.g., argmax for classification). If None, no post-processing is applied.

None
device Optional[str]

PyTorch device string (e.g., "cuda", "cpu"). Auto-detected if None.

None
compress str

Compression for the output GeoTIFF.

'lzw'
verbose bool

Print progress information.

True

Returns:

Name Type Description
str str

Path to the output raster.

Raises:

Type Description
FileNotFoundError

If input_raster does not exist.

ValueError

If overlap >= tile_size or overlap < 0.

Example

>>> from geoai.inference import predict_geotiff
>>> predict_geotiff(
...     model=my_model,
...     input_raster="input.tif",
...     output_raster="output.tif",
...     tile_size=256,
...     overlap=64,
...     blend_mode="spline",
...     tta=False,
... )
'output.tif'

Source code in geoai/inference.py
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
def predict_geotiff(
    model: "torch.nn.Module",
    input_raster: str,
    output_raster: str,
    tile_size: int = 256,
    overlap: int = 64,
    batch_size: int = 4,
    input_bands: Optional[List[int]] = None,
    num_classes: int = 1,
    output_dtype: str = "float32",
    output_nodata: float = -9999.0,
    blend_mode: Union[str, BlendMode] = "spline",
    blend_power: int = 2,
    tta: bool = False,
    preprocess_fn: Optional[Callable[..., np.ndarray]] = None,
    postprocess_fn: Optional[Callable[..., np.ndarray]] = None,
    device: Optional[str] = None,
    compress: str = "lzw",
    verbose: bool = True,
) -> str:
    """Run tiled inference on a GeoTIFF with blending and optional TTA.

    Reads tiles from *input_raster* using rasterio windowed I/O, runs
    each batch through *model*, blends overlapping predictions with the
    chosen weight mask, and writes results to *output_raster*.  Memory
    usage is proportional to ``batch_size * tile_size**2`` for input
    reads rather than the full image.  Note that the output accumulators
    still cover the full image (see inline comment below).

    Args:
        model: PyTorch model accepting ``(B, C, H, W)`` float tensors
            and returning ``(B, num_classes, H, W)`` predictions.
        input_raster: Path to the input GeoTIFF file.
        output_raster: Path to save the output GeoTIFF.
        tile_size: Size of square tiles in pixels.
        overlap: Overlap between adjacent tiles in pixels. Using overlap
            with blending weights eliminates tile-boundary artefacts.
            Higher values give smoother results at the cost of more
            computation.
        batch_size: Number of tiles per forward pass.
        input_bands: 1-based band indices to read. If None, reads all
            bands.
        num_classes: Number of output channels/classes from the model.
            Use 1 for regression or binary segmentation.
        output_dtype: NumPy dtype string for the output raster (e.g.,
            ``"float32"``, ``"uint8"``).
        output_nodata: NoData value for the output raster.
        blend_mode: Blending strategy: ``"none"``, ``"linear"``,
            ``"cosine"``, or ``"spline"``.
        blend_power: Exponent for spline blending (ignored for other
            modes).
        tta: If True, apply D4 test-time augmentation. Increases
            compute by 8x but can improve prediction quality.
        preprocess_fn: Optional callable ``(np.ndarray) -> np.ndarray``
            applied to each tile after reading (e.g., normalization).
            Input shape is ``(C, H, W)``.  If None, tiles are cast to
            float32 and divided by 255 when values exceed 1.5.
        postprocess_fn: Optional callable ``(np.ndarray) -> np.ndarray``
            applied to the final blended output array of shape
            ``(num_classes, H, W)`` before writing (e.g., argmax for
            classification).  If None, no post-processing is applied.
        device: PyTorch device string (e.g., ``"cuda"``, ``"cpu"``).
            Auto-detected if None.
        compress: Compression for the output GeoTIFF.
        verbose: Print progress information.

    Returns:
        str: Path to the output raster.

    Raises:
        FileNotFoundError: If *input_raster* does not exist.
        ValueError: If *overlap* >= *tile_size* or *overlap* < 0; if
            *output_nodata* does not fit an integer *output_dtype*; or
            if the model's channel count does not match *num_classes*.

    Example:
        >>> from geoai.inference import predict_geotiff
        >>> predict_geotiff(
        ...     model=my_model,
        ...     input_raster="input.tif",
        ...     output_raster="output.tif",
        ...     tile_size=256,
        ...     overlap=64,
        ...     blend_mode="spline",
        ...     tta=False,
        ... )
        'output.tif'
    """
    # Heavy dependencies are imported at call time rather than at module
    # import time.
    import torch
    import rasterio
    from rasterio.windows import Window
    from tqdm.auto import tqdm

    from geoai.utils import get_device

    # ---- validation ----
    if not os.path.exists(input_raster):
        raise FileNotFoundError(f"Input raster not found: {input_raster}")

    if overlap < 0 or overlap >= tile_size:
        raise ValueError(
            f"overlap must be >= 0 and < tile_size ({tile_size}), got {overlap}"
        )

    # Validate output_nodata fits the chosen output_dtype
    out_dt = np.dtype(output_dtype)
    if np.issubdtype(out_dt, np.integer):
        info = np.iinfo(out_dt)
        if not (info.min <= output_nodata <= info.max):
            raise ValueError(
                f"output_nodata={output_nodata} is outside the valid range "
                f"[{info.min}, {info.max}] for output_dtype='{output_dtype}'. "
                f"Choose a nodata value that fits the dtype (e.g., 0 or 255 "
                f"for uint8)."
            )

    # NOTE(review): assumes get_device() returns a torch.device (or a
    # value model.to() accepts) — confirm against geoai.utils.get_device.
    if device is None:
        device = get_device()
    else:
        device = torch.device(device)

    preprocess = preprocess_fn if preprocess_fn is not None else _default_preprocess
    stride = tile_size - overlap

    # ---- compute weight mask once ----
    # May raise ValueError (e.g., spline mode with overlap > tile_size // 2).
    weight_mask = create_weight_mask(
        tile_size, overlap, mode=blend_mode, power=blend_power
    )

    model.to(device)
    model.eval()

    with rasterio.open(input_raster) as src:
        height = src.height
        width = src.width
        profile = src.profile.copy()

        if input_bands is None:
            input_bands = list(range(1, src.count + 1))

        if verbose:
            print(f"Input raster: {width}x{height}, {len(input_bands)} bands")
            print(f"Tile size: {tile_size}, overlap: {overlap}, stride: {stride}")

        # ---- allocate output accumulators ----
        # float64 avoids precision loss when summing many weighted tiles,
        # at the cost of two full-image allocations
        # (~(num_classes + 1) * H * W * 8 bytes).
        output_sum = np.zeros((num_classes, height, width), dtype=np.float64)
        weight_sum = np.zeros((1, height, width), dtype=np.float64)

        # ---- build tile grid ----
        tiles: List[Tuple[int, int, int, int]] = []
        for row in range(0, height, stride):
            for col in range(0, width, stride):
                row_end = min(row + tile_size, height)
                col_end = min(col + tile_size, width)
                # Pull start back so the tile is tile_size when possible
                row_start = max(0, row_end - tile_size)
                col_start = max(0, col_end - tile_size)
                tiles.append((row_start, col_start, row_end, col_end))

        # Deduplicate tiles that map to the same position at boundaries
        # (dict.fromkeys preserves insertion order).
        tiles = list(dict.fromkeys(tiles))

        if verbose:
            print(f"Total tiles: {len(tiles)}")

        # ---- process in batches ----
        iterator = range(0, len(tiles), batch_size)
        if verbose:
            iterator = tqdm(iterator, desc="Running inference")

        for batch_start in iterator:
            batch_end = min(batch_start + batch_size, len(tiles))
            batch_tiles = tiles[batch_start:batch_end]

            # Read tiles via windowed I/O
            batch_images = []
            batch_actual_sizes: List[Tuple[int, int]] = []

            for row_start, col_start, row_end, col_end in batch_tiles:
                actual_h = row_end - row_start
                actual_w = col_end - col_start
                batch_actual_sizes.append((actual_h, actual_w))

                # rasterio Window takes (col_off, row_off, width, height).
                window = Window(col_start, row_start, actual_w, actual_h)
                tile_data = src.read(input_bands, window=window).astype(np.float32)

                # Pad undersized edge tiles.  Due to the pull-back above,
                # this only triggers when the image itself is smaller than
                # tile_size in a dimension.
                if actual_h != tile_size or actual_w != tile_size:
                    padded = np.zeros(
                        (len(input_bands), tile_size, tile_size),
                        dtype=np.float32,
                    )
                    padded[:, :actual_h, :actual_w] = tile_data
                    tile_data = padded

                tile_data = preprocess(tile_data)
                batch_images.append(tile_data)

            batch_tensor = torch.from_numpy(np.stack(batch_images)).to(device)

            # Inference (TTA also runs under no_grad since it is called here)
            with torch.no_grad():
                if tta:
                    preds = d4_tta_forward(model, batch_tensor)
                else:
                    preds = model(batch_tensor)

            preds = preds.cpu().numpy()

            # Validate model output shape on first batch
            if batch_start == 0:
                pred_channels = preds.shape[1] if preds.ndim == 4 else 1
                if pred_channels != num_classes:
                    raise ValueError(
                        f"Model output has {pred_channels} channels but "
                        f"num_classes={num_classes}. Set num_classes to "
                        f"match the model output."
                    )

            # Handle models that return (B, H, W) instead of (B, 1, H, W)
            if preds.ndim == 3:
                preds = preds[:, np.newaxis, :, :]

            # Accumulate with blending
            for i, (row_start, col_start, row_end, col_end) in enumerate(batch_tiles):
                actual_h, actual_w = batch_actual_sizes[i]
                pred = preds[i]  # (num_classes, tile_size, tile_size)

                # Crop to actual tile dimensions (drops padding, if any)
                pred_crop = pred[:, :actual_h, :actual_w]
                weight_crop = weight_mask[:actual_h, :actual_w]

                # Accumulate weighted prediction and weights
                output_sum[:, row_start:row_end, col_start:col_end] += (
                    pred_crop * weight_crop[np.newaxis, :, :]
                )
                weight_sum[:, row_start:row_end, col_start:col_end] += weight_crop[
                    np.newaxis, :, :
                ]

    # ---- normalize by weights ----
    # valid mask is (1, H, W); NumPy broadcasts it against (num_classes, H, W)
    # NOTE(review): tapered masks are 0 at tile edges, so outermost image
    # pixels may accumulate zero weight and be written as nodata — confirm
    # this is the intended border behavior.
    valid = weight_sum > 0  # (1, H, W)
    output_array = np.where(
        valid,
        output_sum / (weight_sum + 1e-8),
        output_nodata,
    ).astype(np.float32)

    # ---- postprocess ----
    if postprocess_fn is not None:
        output_array = postprocess_fn(output_array)

    # ---- write output ----
    output_dir = os.path.dirname(os.path.abspath(output_raster))
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # postprocess_fn may have collapsed the class axis (e.g., argmax),
    # so re-derive the band count from the array shape.
    out_count = output_array.shape[0] if output_array.ndim == 3 else 1
    profile.update(
        count=out_count,
        dtype=output_dtype,
        nodata=output_nodata,
        compress=compress,
    )

    with rasterio.open(output_raster, "w", **profile) as dst:
        if output_array.ndim == 3:
            for band_idx in range(out_count):
                dst.write(output_array[band_idx].astype(output_dtype), band_idx + 1)
        else:
            dst.write(output_array.astype(output_dtype), 1)

    if verbose:
        print(f"Output saved to: {output_raster}")
        print(f"Output dimensions: {width}x{height}")

    return output_raster