Smooth Tiled Inference on GeoTIFF¶
Overview¶
This notebook demonstrates the geoai.inference module, which provides a memory-efficient, artifact-free pipeline for running any PyTorch model on large GeoTIFF rasters.
Standard tiled inference divides a raster into fixed-size windows, runs the model on each tile, and stitches the results back together. This often produces visible seam artifacts at tile boundaries because predictions near tile edges are less reliable. The predict_geotiff function solves this with:
- Overlapping tiles — adjacent tiles share a margin of pixels
- Smooth blending — a 2-D weight mask (linear, cosine, or spline) gives lower weight to predictions near tile edges and higher weight to the centre, producing seamless output
- Windowed I/O — tiles are read from disk one batch at a time, so memory usage scales with batch_size × tile_size², not with the full image
- D4 test-time augmentation (TTA) — optional 8-fold augmentation (rotations + flips) that averages predictions over the D4 dihedral group for improved accuracy
Reference: Smoothly-Blend-Image-Patches — https://github.com/Vooban/Smoothly-Blend-Image-Patches
Install packages¶
Uncomment the following line to install the required packages.
# %pip install geoai-py
Import libraries¶
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import geoai
from geoai.inference import (
BlendMode,
create_weight_mask,
predict_geotiff,
d4_forward,
d4_inverse,
d4_tta_forward,
)
Visualise blending strategies¶
The weight mask controls how overlapping tile predictions are blended. A value of 1 at a pixel means that tile's prediction is used with full weight; 0 means it contributes nothing. Adjacent tiles then produce a weighted average with seamless transitions.
# Tile geometry shared by all examples below.
tile_size = 256
overlap = 64

# One panel per blending strategy, all drawn on a common [0, 1] colour scale.
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, mode in zip(axes, ["none", "linear", "cosine", "spline"]):
    weights = create_weight_mask(tile_size, overlap, mode=mode)
    image = ax.imshow(weights, cmap="viridis", vmin=0, vmax=1)
    ax.set_title(f"{mode!r} (overlap={overlap})", fontsize=11)
    ax.axis("off")
    plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
plt.suptitle(
    f"2-D blending weight masks (tile_size={tile_size})", fontsize=13, y=1.02
)
plt.tight_layout()
plt.show()
# Plot a 1-D slice through the centre row of each smooth mask so the
# edge-to-centre ramp of every strategy can be compared directly.
fig, ax = plt.subplots(figsize=(9, 3))
centre_row = tile_size // 2
for mode in ("linear", "cosine", "spline"):
    profile = create_weight_mask(tile_size, overlap, mode=mode)[centre_row, :]
    ax.plot(profile, label=mode)
# Shade both overlap margins; only the first span is labelled for the legend.
ax.axvspan(0, overlap, alpha=0.08, color="red", label="overlap zone")
ax.axvspan(tile_size - overlap, tile_size, alpha=0.08, color="red")
ax.set_xlabel("pixel index")
ax.set_ylabel("weight")
ax.set_title("Centre-row cross-section of weight mask")
ax.legend()
plt.tight_layout()
plt.show()
Download sample data¶
We use a small NAIP aerial image for demonstration.
import os

# Sample NAIP aerial image hosted on Hugging Face (presumably 4-band RGBN —
# the demo models below assume NIR in band 4; confirm with print_raster_info).
url = "https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_train.tif"
input_path = "naip.tif"
# Download only once; notebook reruns reuse the cached local file.
if not os.path.exists(input_path):
    geoai.download_file(url, input_path)
# Print band count, dtype, CRS, and extent of the downloaded raster.
geoai.print_raster_info(input_path)
Visualise input raster¶
# Preview the input raster before running inference.
geoai.view_raster(input_path)
Define a demo model¶
For this demonstration we create a lightweight model with global average pooling. The pooling creates a tile-level context vector, so the same pixel gets a slightly different prediction depending on which tile it belongs to. This is realistic — many segmentation models include global context modules — and it produces visible tile-boundary artifacts when blending is disabled.
class SimpleSegModel(nn.Module):
    """Deterministic model with global context for visible tile-boundary artifacts.

    Computes band-ratio features (NDVI-like, greenness, brightness) then
    subtracts each tile's global average. The global pooling makes
    predictions tile-dependent — different tiles have different means —
    so the same pixel receives different predictions depending on which
    tile contains it. This creates visible seam artifacts when
    blending is disabled, which is exactly what the demo is designed to show.
    """

    def __init__(self, in_channels: int = 4):
        super().__init__()
        self.features = nn.Conv2d(in_channels, 4, 3, padding=1, bias=True)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Conv2d(4, 1, 1, bias=True)
        self.sigmoid = nn.Sigmoid()
        # All parameters are assigned explicitly so the model is fully
        # deterministic — no random initialisation survives.
        with torch.no_grad():
            weight = self.features.weight
            bias = self.features.bias
            weight.zero_()
            # Feature 0: smoothed brightness — 3x3 box filter over R, G, B.
            for band in range(3):
                weight[0, band] = torch.ones(3, 3) / 9.0
            bias[0] = -1.5
            # Feature 1: NDVI-like contrast (NIR minus Red).
            weight[1, 3, 1, 1] = 2.0
            weight[1, 0, 1, 1] = -2.0
            bias[1] = 0.0
            # Feature 2: greenness (Green minus Red).
            weight[2, 1, 1, 1] = 2.0
            weight[2, 0, 1, 1] = -2.0
            bias[2] = 0.0
            # Feature 3: raw NIR intensity.
            weight[3, 3, 1, 1] = 3.0
            bias[3] = -1.0
            # Head: linear combination of the four features into one logit.
            head_weight = self.head.weight
            head_weight[0, 0, 0, 0] = -1.5  # dark areas -> high
            head_weight[0, 1, 0, 0] = 3.0  # vegetation -> high
            head_weight[0, 2, 0, 0] = 1.0  # green -> high
            head_weight[0, 3, 0, 0] = 0.5
            self.head.bias.fill_(0.0)

    def forward(self, x):
        responses = self.features(x)
        # Subtract the per-tile mean response: the same pixel now scores
        # differently depending on which tile it fell into, which is the
        # seam artifact this demo is designed to exhibit.
        centred = responses - self.gap(responses)
        return self.sigmoid(self.head(centred) * 5.0)
# Build the demo model and switch to inference mode.
model = SimpleSegModel(in_channels=4)
model.eval()
# The model is tiny — print the parameter count for reference.
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
Define custom preprocessing¶
NAIP imagery may be stored as uint8 (0–255) or uint16 (0–10,000). This helper normalises either format to float32 values in [0, 1].
def naip_preprocess(tile: np.ndarray) -> np.ndarray:
    """Normalise a NAIP tile (C, H, W) to float32 in [0, 1].

    Handles both uint8 (0-255) and uint16 (0-10000) NAIP imagery.

    Integer tiles are scaled by their dtype's expected range rather than by
    the per-tile maximum: the old ``tile.max()`` heuristic scaled a dark
    uint16 tile (max <= 255) by 255 instead of 10000, so adjacent tiles could
    be normalised inconsistently — itself a source of seam artifacts. For
    float inputs the original range heuristic is retained as a fallback.

    Args:
        tile: Raster tile with shape (channels, height, width), any numeric
            dtype. May contain NaN values.

    Returns:
        float32 array of the same shape with values in [0, 1]; NaNs become 0.
    """
    if np.issubdtype(tile.dtype, np.integer):
        if tile.dtype.itemsize > 1:
            # 16-bit (or wider) NAIP data: values typically up to ~10000.
            return np.clip(tile.astype(np.float32), 0, 10000) / 10000.0
        # 8-bit NAIP data (values 0-255).
        return tile.astype(np.float32) / 255.0
    scaled = tile.astype(np.float32)
    # Float input: fall back to a range heuristic (guard against empty tiles).
    max_val = scaled.max() if scaled.size else 0.0
    if max_val > 255:
        scaled = np.clip(scaled, 0, 10000) / 10000.0
    elif max_val > 1.5:
        scaled = scaled / 255.0
    return np.nan_to_num(scaled, nan=0.0)
Run inference — compare blending strategies¶
We run the same model three times with different blend modes so the effect on tile boundary smoothness can be compared.
# Output rasters for the two runs under comparison.
output_no_blend = "output_none.tif"
output_spline = "output_spline.tif"

# No blending — hard tile boundaries (last-write-wins).
# Settings are identical to the spline run below apart from blend_mode,
# so any visible difference comes purely from the stitching strategy.
predict_geotiff(
    model=model,
    input_raster=input_path,
    output_raster=output_no_blend,
    tile_size=256,
    overlap=64,
    batch_size=4,
    num_classes=1,
    blend_mode="none",
    preprocess_fn=naip_preprocess,
    verbose=True,
)

# Spline blending — smooth, artifact-free output.
predict_geotiff(
    model=model,
    input_raster=input_path,
    output_raster=output_spline,
    tile_size=256,
    overlap=64,
    batch_size=4,
    num_classes=1,
    blend_mode="spline",
    preprocess_fn=naip_preprocess,
    verbose=True,
)
Visualise predictions¶
import rasterio

# Load the two single-band predictions back from disk.
with rasterio.open(output_no_blend) as src:
    pred_none = src.read(1)
with rasterio.open(output_spline) as src:
    pred_spline = src.read(1)

# Side-by-side comparison on a shared colour scale.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
panels = [
    (pred_none, 'blend_mode="none" (hard tile edges visible)'),
    (pred_spline, 'blend_mode="spline" (seamless)'),
]
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image, cmap="RdYlGn", vmin=0, vmax=1)
    ax.set_title(title, fontsize=11)
    ax.axis("off")
plt.suptitle("Prediction comparison: no blending vs spline blending", fontsize=13)
plt.tight_layout()
plt.show()
Run inference with D4 test-time augmentation¶
D4 TTA applies all 8 transforms of the dihedral group (identity, 3 rotations, horizontal flip, vertical flip, and 2 diagonal flips), runs the model on each, inverts the transforms, then averages the 8 predictions. This typically improves accuracy but increases compute by 8×.
# Same configuration as the spline run, with D4 TTA enabled
# (8 forward passes per tile, averaged — ~8x the compute).
output_tta = "output_tta.tif"
predict_geotiff(
    model=model,
    input_raster=input_path,
    output_raster=output_tta,
    tile_size=256,
    overlap=64,
    batch_size=4,
    num_classes=1,
    blend_mode="spline",
    tta=True,  # enable D4 TTA
    preprocess_fn=naip_preprocess,
    verbose=True,
)
Multi-class segmentation output¶
predict_geotiff supports models that output multiple channels (e.g., class probability maps). Set num_classes to match the model's output channels, and use postprocess_fn to apply argmax.
class MulticlassModel(nn.Module):
    """Deterministic multi-class model with spatially varying output.

    Each class responds to a different spectral feature of the NAIP
    imagery (red surfaces, green vegetation, NIR-bright, shadows, edges).
    Global average pooling ensures tile-dependent predictions.
    All weights are fixed — no randomness.
    """

    def __init__(self, in_channels: int = 4, num_classes: int = 5):
        super().__init__()
        self.features = nn.Conv2d(in_channels, num_classes, 3, padding=1, bias=True)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.softmax = nn.Softmax(dim=1)
        # Hand-assign every parameter so the demo is reproducible.
        with torch.no_grad():
            weight = self.features.weight
            bias = self.features.bias
            weight.zero_()
            # Class 0: red-dominated, NIR-poor (impervious surfaces / roofs).
            weight[0, 0, 1, 1] = 4.0
            weight[0, 3, 1, 1] = -2.0
            bias[0] = 0.0
            # Class 1: green-dominated (lawns / grass).
            weight[1, 1, 1, 1] = 4.0
            weight[1, 0, 1, 1] = -2.0
            bias[1] = 0.0
            # Class 2: strong NIR with high NDVI (trees / dense vegetation).
            weight[2, 3, 1, 1] = 4.0
            weight[2, 0, 1, 1] = -3.0
            bias[2] = 0.0
            # Class 3: uniformly dark across R, G, B (shadows / dark features).
            weight[3, 0, 1, 1] = -2.0
            weight[3, 1, 1, 1] = -2.0
            weight[3, 2, 1, 1] = -2.0
            bias[3] = 3.0
            # Class 4: brightness edges — horizontal Sobel applied to R, G, B.
            sobel = torch.tensor(
                [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32
            )
            for band in range(3):
                weight[4, band] = sobel / 3.0
            bias[4] = 0.0

    def forward(self, x):
        logits = self.features(x)
        # Per-tile mean subtraction makes the class scores tile-dependent,
        # so seams appear when blending is disabled.
        logits = logits - self.gap(logits)
        return self.softmax(logits * 3.0)
# Build the 5-class demo model in inference mode.
mc_model = MulticlassModel(in_channels=4, num_classes=5)
mc_model.eval()

output_multiclass = "output_multiclass.tif"
predict_geotiff(
    model=mc_model,
    input_raster=input_path,
    output_raster=output_multiclass,
    tile_size=256,
    overlap=64,
    batch_size=4,
    num_classes=5,
    output_dtype="float32",
    blend_mode="spline",
    preprocess_fn=naip_preprocess,
    # Collapse the 5-channel probability map to a single-band class index;
    # the leading axis is re-added because the writer expects (C, H, W).
    postprocess_fn=lambda arr: np.argmax(arr, axis=0)[np.newaxis, :, :].astype(
        np.float32
    ),
    verbose=True,
)
# Read the single-band argmax class map written by predict_geotiff.
with rasterio.open(output_multiclass) as src:
    class_map = src.read(1)

# Categorical display: nearest-neighbour interpolation keeps class
# boundaries crisp, and the colourbar is labelled with integer indices.
plt.figure(figsize=(7, 6))
plt.imshow(class_map, cmap="tab10", vmin=0, vmax=4, interpolation="nearest")
colour_bar = plt.colorbar(label="class index", ticks=list(range(5)))
colour_bar.ax.set_yticklabels([str(i) for i in range(5)])
plt.title("Argmax class map (5-class model with spline blending)")
plt.axis("off")
plt.tight_layout()
plt.show()
Using D4 transforms directly¶
The D4 transform functions are also available as standalone utilities for custom inference loops.
# Build a random example batch to push through the transform utilities.
batch = torch.randn(1, 4, 64, 64)

# Forward pass: produce the eight views of the D4 dihedral group.
augmented = d4_forward(batch)
print(f"Number of augmented views: {len(augmented)}")

# Inverse pass: map every view back to the original orientation.
restored = d4_inverse(augmented)

# Each restored view should reproduce the input batch exactly.
for idx, view in enumerate(restored):
    max_diff = (view - batch).abs().max().item()
    print(f" Transform {idx}: max roundtrip error = {max_diff:.2e}")
Summary¶
| Feature | Parameter | Default |
|---|---|---|
| Tile size | `tile_size` | 256 |
| Overlap | `overlap` | 64 |
| Blending | `blend_mode` | `"spline"` |
| TTA | `tta` | `False` |
| Custom normalisation | `preprocess_fn` | `None` (auto) |
| Custom postprocessing | `postprocess_fn` | `None` (no-op) |
| Output dtype | `output_dtype` | `"float32"` |
Choose blend_mode="spline" for the smoothest results. Enable tta=True when prediction quality matters more than speed.