Loss Functions for Semantic Segmentation¶
Overview¶
Class imbalance is a common challenge in geospatial semantic segmentation — rare land cover types (e.g., wetlands, impervious surfaces) are often vastly outnumbered by dominant classes (e.g., forest, water). Standard cross-entropy loss treats all pixels equally, which can cause the model to underperform on minority classes.
GeoAI provides several loss functions designed to address class imbalance:
| Loss | Type | Best For |
|---|---|---|
| `CrossEntropyLoss` | Distribution-based | Balanced datasets |
| `FocalLoss` | Distribution-based | Moderate imbalance (down-weights easy examples) |
| `DiceLoss` | Region-based | Overlap-focused training |
| `TverskyLoss` | Region-based | Controlling FP/FN trade-off (improve recall) |
| `UnifiedFocalLoss` | Compound (distribution + region) | Severe imbalance (combines focal CE + focal Tversky) |
The Unified Focal Loss implements the framework from Yeung et al. (2021), inspired by the terrainseg package by Maxwell (2026).
Install packages¶
To use the geoai-py package, ensure it is installed in your environment. Uncomment the command below if needed.
```python
# %pip install geoai-py
```
Import libraries¶
```python
import torch

from geoai import (
    DiceLoss,
    FocalLoss,
    TverskyLoss,
    UnifiedFocalLoss,
    get_landcover_loss_function,
)
```
Create synthetic data¶
We'll create a small synthetic example to demonstrate the loss functions. The tensors simulate a 3-class segmentation task with a batch of 2 images of size 64x64.
```python
torch.manual_seed(42)

num_classes = 3
batch_size = 2
h, w = 64, 64

# Model predictions (logits): shape (N, C, H, W)
logits = torch.randn(batch_size, num_classes, h, w)

# Ground truth labels: shape (N, H, W) with class indices
# Simulate class imbalance: ~80% class 0, ~15% class 1, ~5% class 2
probs = torch.tensor([0.80, 0.15, 0.05])
targets = torch.multinomial(probs.expand(batch_size * h * w, -1), 1).squeeze()
targets = targets.reshape(batch_size, h, w)

# Check class distribution
for c in range(num_classes):
    count = (targets == c).sum().item()
    total = targets.numel()
    print(f"Class {c}: {count}/{total} pixels ({100 * count / total:.1f}%)")
```
DiceLoss¶
Dice loss measures the overlap between predictions and ground truth, so each class contributes through its overlap ratio rather than its raw pixel count:

```python
dice_loss = DiceLoss(smooth=1.0)
loss = dice_loss(logits, targets)
print(f"Dice Loss: {loss.item():.4f}")
```
You can also pass per-class weights to emphasise rare classes:
```python
# Give more weight to the rare classes
class_weights = torch.tensor([0.5, 2.0, 5.0])
weighted_dice = DiceLoss(weight=class_weights)
loss = weighted_dice(logits, targets)
print(f"Weighted Dice Loss: {loss.item():.4f}")
```
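FocalLoss, used in the comparison later in this notebook, down-weights easy pixels by scaling per-pixel cross-entropy with `(1 - p_t) ** gamma` (Lin et al., 2017). Here is a minimal pure-PyTorch sketch of that idea, independent of geoai's implementation (tensor names and shapes are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
demo_logits = torch.randn(2, 3, 4, 4)          # (N, C, H, W)
demo_targets = torch.randint(0, 3, (2, 4, 4))  # (N, H, W)

gamma = 2.0
ce = F.cross_entropy(demo_logits, demo_targets, reduction="none")  # per-pixel CE
p_t = torch.exp(-ce)                        # probability of the true class
focal = ((1 - p_t) ** gamma * ce).mean()    # down-weight easy (high p_t) pixels

# Focal loss is never larger than plain CE: easy pixels contribute less
print(f"CE:    {ce.mean().item():.4f}")
print(f"Focal: {focal.item():.4f}")
```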
TverskyLoss¶
Tversky loss generalises Dice by allowing asymmetric weighting of false positives (FP) and false negatives (FN):
- `alpha` controls the FP penalty
- `beta` controls the FN penalty
- `alpha = beta = 0.5` recovers the standard Dice loss
- Setting `beta > alpha` penalises missed detections more, improving recall on rare classes
Reference: Salehi et al. (2017), "Tversky loss function for image segmentation using 3D fully convolutional deep networks."
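The Tversky index itself is `TP / (TP + alpha * FP + beta * FN)`. The toy sketch below (made-up hard predictions, not geoai's implementation) shows how `beta > alpha` lowers the index, and hence raises the loss `1 - TI`, when false negatives outnumber false positives:

```python
import torch

# Toy hard predictions with more false negatives (2) than false positives (1)
pred = torch.tensor([1.0, 1.0, 0.0, 0.0, 0.0])
true = torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0])

tp = (pred * true).sum().item()        # true positives
fp = (pred * (1 - true)).sum().item()  # false positives
fn = ((1 - pred) * true).sum().item()  # false negatives

def tversky_index(tp, fp, fn, alpha, beta, smooth=1.0):
    return (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)

dice_equiv = tversky_index(tp, fp, fn, alpha=0.5, beta=0.5)     # Dice coefficient
recall_focused = tversky_index(tp, fp, fn, alpha=0.3, beta=0.7) # FN-heavy weighting
print(f"Dice-equivalent: {dice_equiv:.4f}, recall-focused: {recall_focused:.4f}")
```

Because the recall-focused index is lower here, the corresponding loss is higher, which is exactly the extra pressure on missed detections described above.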
```python
# Symmetric (equivalent to Dice)
tversky_sym = TverskyLoss(alpha=0.5, beta=0.5)
loss_sym = tversky_sym(logits, targets)
print(f"Tversky (symmetric, alpha=0.5, beta=0.5): {loss_sym.item():.4f}")

# Asymmetric — penalise false negatives more to boost recall
tversky_fn = TverskyLoss(alpha=0.3, beta=0.7)
loss_fn = tversky_fn(logits, targets)
print(f"Tversky (recall-focused, alpha=0.3, beta=0.7): {loss_fn.item():.4f}")
```
UnifiedFocalLoss¶
The Unified Focal Loss (Yeung et al., 2021) combines a distribution-based component (focal cross-entropy) with a region-based component (focal Tversky) into a single compound loss. This is particularly effective for severe class imbalance in geospatial segmentation.
The implementation was inspired by the terrainseg package by Maxwell (2026).
Key parameters:
- `lambda_` — balance between distribution and region losses (0.0 = pure focal Tversky, 1.0 = pure focal CE)
- `gamma` — focusing parameter that down-weights easy examples
- `delta` — Tversky FN weight (values > 0.5 emphasise recall)
- `use_log_cosh` — optional gradient smoothing via `log(cosh(loss))`
```python
# Default: balanced 50/50 mix of focal CE and focal Tversky
ufl = UnifiedFocalLoss(lambda_=0.5, gamma=0.75, delta=0.6)
loss = ufl(logits, targets)
print(f"Unified Focal Loss (default): {loss.item():.4f}")

# Region-heavy: emphasise the Tversky component
ufl_region = UnifiedFocalLoss(lambda_=0.25, gamma=0.75, delta=0.7)
loss_region = ufl_region(logits, targets)
print(f"Unified Focal Loss (region-heavy): {loss_region.item():.4f}")

# With log-cosh gradient smoothing
ufl_smooth = UnifiedFocalLoss(lambda_=0.5, gamma=0.75, delta=0.6, use_log_cosh=True)
loss_smooth = ufl_smooth(logits, targets)
print(f"Unified Focal Loss (log-cosh): {loss_smooth.item():.4f}")
```
<imports></imports>
<test>
</test>
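Schematically, the compound loss is `lambda_ * focal_CE + (1 - lambda_) * focal_Tversky`. The sketch below is a simplified pure-PyTorch illustration of that structure only; exponent and normalisation conventions vary between the paper and implementations, so it will not reproduce geoai's exact values:

```python
import torch
import torch.nn.functional as F

def unified_focal_sketch(logits, targets, lambda_=0.5, gamma=0.75, delta=0.6, smooth=1.0):
    """Schematic compound loss: lambda_ * focal CE + (1 - lambda_) * focal Tversky."""
    # Distribution component: focal cross-entropy
    ce = F.cross_entropy(logits, targets, reduction="none")
    p_t = torch.exp(-ce)
    focal_ce = ((1 - p_t) ** gamma * ce).mean()

    # Region component: focal Tversky (delta weights FN, 1 - delta weights FP)
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)
    onehot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
    tp = (probs * onehot).sum(dim=(0, 2, 3))
    fp = (probs * (1 - onehot)).sum(dim=(0, 2, 3))
    fn = ((1 - probs) * onehot).sum(dim=(0, 2, 3))
    ti = (tp + smooth) / (tp + (1 - delta) * fp + delta * fn + smooth)
    focal_tversky = ((1 - ti) ** gamma).mean()

    return lambda_ * focal_ce + (1 - lambda_) * focal_tversky

torch.manual_seed(0)
demo_logits = torch.randn(2, 3, 8, 8)
demo_targets = torch.randint(0, 3, (2, 8, 8))
val = unified_focal_sketch(demo_logits, demo_targets)
print(f"sketch UFL: {val.item():.4f}")
```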
Compare all loss functions¶
Let's compare the loss values across all available loss functions on the same synthetic data:
```python
loss_fns = {
    "CrossEntropy": torch.nn.CrossEntropyLoss(),
    "Focal (gamma=2)": FocalLoss(gamma=2.0),
    "Dice": DiceLoss(),
    "Tversky (a=0.3, b=0.7)": TverskyLoss(alpha=0.3, beta=0.7),
    "Unified Focal": UnifiedFocalLoss(lambda_=0.5, gamma=0.75, delta=0.6),
}

print(f"{'Loss Function':<30} {'Value':>10}")
print("-" * 42)
for name, fn in loss_fns.items():
    val = fn(logits, targets).item()
    print(f"{name:<30} {val:>10.4f}")
```
Using the loss factory¶
The get_landcover_loss_function factory creates configured loss functions by name. This is what train_segmentation_landcover uses internally.
```python
for name in ["crossentropy", "focal", "dice", "tversky", "unified_focal"]:
    fn = get_landcover_loss_function(name, device=torch.device("cpu"))
    val = fn(logits, targets).item()
    print(f"{name:<20} -> {type(fn).__name__:<30} loss={val:.4f}")
```
Using with train_segmentation_landcover¶
When training a land cover segmentation model, you can select the loss function via the loss_function parameter. Here is an example (not executed) showing the key parameters:
```python
import geoai

# Train with Unified Focal Loss for class-imbalanced land cover data
model = geoai.train_segmentation_landcover(
    images_dir="landcover/images",
    labels_dir="landcover/labels",
    output_dir="landcover/models",
    num_classes=13,
    num_epochs=50,
    # Loss function selection
    loss_function="unified_focal",  # or "dice", "tversky", "focal", "crossentropy"
    # Unified Focal Loss parameters
    ufl_lambda=0.5,   # balance: 0.0=pure region, 1.0=pure distribution
    ufl_gamma=0.75,   # focusing parameter
    ufl_delta=0.6,    # FN weight (>0.5 = recall-focused)
    # Class weighting
    use_class_weights=True,  # auto-compute inverse-frequency weights
    ignore_index=0,          # ignore background class
)
```
Using with other training pipelines¶
The core segmentation training functions in geoai accept a loss_fn parameter, so you can pass any of the loss classes directly:
train_segmentation_model (segmentation-models-pytorch based):
```python
from geoai import UnifiedFocalLoss

loss = UnifiedFocalLoss(lambda_=0.5, gamma=0.75, delta=0.6)

geoai.train_segmentation_model(
    images_dir="data/images",
    labels_dir="data/labels",
    output_dir="output",
    num_classes=5,
    loss_fn=loss,
)
```
train_timm_segmentation_model (timm-based):
```python
from geoai import DiceLoss

geoai.train_timm_segmentation_model(
    images_dir="data/images",
    labels_dir="data/labels",
    output_dir="output",
    num_classes=5,
    loss_fn=DiceLoss(ignore_index=0),
)
```
train_dinov3_segmentation (DINOv3 finetuning):
```python
from geoai import TverskyLoss

loss = TverskyLoss(alpha=0.3, beta=0.7, ignore_index=255)
# Pass to train_dinov3_segmentation via the loss_fn parameter
```
These functions also support class_weights for weighted CrossEntropyLoss when you don't need a custom loss function. Note that train_segmentation_landcover uses the loss_function string parameter as shown in the previous section.
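For reference, this is what per-class weighting looks like with plain PyTorch `CrossEntropyLoss`; the weight values, class count, and tensor shapes below are purely illustrative:

```python
import torch

# Illustrative weights for a hypothetical 5-class problem (rare classes weighted up)
class_weights = torch.tensor([0.2, 1.0, 1.0, 2.5, 4.0])
weighted_ce = torch.nn.CrossEntropyLoss(weight=class_weights)

torch.manual_seed(0)
demo_logits = torch.randn(2, 5, 8, 8)          # (N, C, H, W)
demo_targets = torch.randint(0, 5, (2, 8, 8))  # (N, H, W)
ce_val = weighted_ce(demo_logits, demo_targets)
print(f"Weighted CE: {ce_val.item():.4f}")
```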
Handling ignore_index¶
All loss functions support an ignore_index parameter to exclude specific pixels (e.g., background, no-data) from the loss computation:
```python
# Mark some pixels as "no data" (class 255)
targets_with_nodata = targets.clone()
targets_with_nodata[:, :5, :5] = 255

# These pixels are excluded from the loss
ufl_ignore = UnifiedFocalLoss(ignore_index=255)
loss = ufl_ignore(logits, targets_with_nodata)
print(f"Unified Focal Loss (with ignore_index=255): {loss.item():.4f}")
```
Choosing a loss function¶
Guidelines for selecting a loss function:
- Balanced classes — `crossentropy` is a solid default.
- Moderate imbalance — `focal` (gamma=2) down-weights easy, well-classified pixels.
- Severe imbalance where recall matters — `tversky` with `beta > alpha` (e.g., 0.7/0.3) directly penalises missed detections.
- Severe imbalance, general — `unified_focal` combines the strengths of both distribution- and region-based approaches. Start with the defaults (`lambda_=0.5`, `gamma=0.75`, `delta=0.6`) and tune from there.
- Overlap-focused training — `dice` optimises the Dice coefficient directly, useful when your evaluation metric is IoU or F1.
For all loss functions, enabling use_class_weights=True in train_segmentation_landcover will automatically compute inverse-frequency weights from your training labels, providing an additional layer of class imbalance handling.
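As a sketch of what inverse-frequency weighting computes (the exact normalisation inside geoai may differ), weights can be derived from label counts like this:

```python
import torch

torch.manual_seed(42)
num_classes = 3
# Hypothetical imbalanced label pixels (~80/15/5% split)
labels = torch.multinomial(torch.tensor([0.80, 0.15, 0.05]), 4096, replacement=True)

counts = torch.bincount(labels, minlength=num_classes).float()
freqs = counts / counts.sum()
weights = 1.0 / (freqs + 1e-6)                   # inverse frequency
weights = weights / weights.sum() * num_classes  # normalise to mean 1

for c in range(num_classes):
    print(f"class {c}: freq={freqs[c].item():.3f}, weight={weights[c].item():.3f}")
```

Rarer classes end up with larger weights, so their pixels contribute more to the loss.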
References¶
- Yeung, M., Sala, E., Schönlieb, C.-B., & Rundo, L. (2021). Unified Focal loss: Generalising Dice and cross entropy-based losses to handle class imbalanced medical image segmentation. Computerized Medical Imaging and Graphics, 95, 102026. https://doi.org/10.1016/j.compmedimag.2021.102026
- Salehi, S. S. M., Erdogmus, D., & Gholipour, A. (2017). Tversky loss function for image segmentation using 3D fully convolutional deep networks. MLMI Workshop, MICCAI.
- Lin, T. Y., Goyal, P., Girshick, R., He, K., & Dollar, P. (2017). Focal loss for dense object detection. ICCV.
- Maxwell, A. (2026). terrainseg. https://github.com/maxwell-geospatial/terrainseg