Skip to content

embeddings module

Module for working with geospatial embedding datasets from TorchGeo.

This module provides a unified, easy-to-use interface for loading, exploring, visualizing, and analyzing pre-computed Earth embedding datasets introduced in TorchGeo v0.9.0. These embeddings are pre-computed representations from geospatial foundation models that encode satellite imagery into compact vector representations.

Two types of embedding datasets are supported:

  • Patch-based (NonGeoDataset): Each sample is a single embedding vector for a geographic patch (e.g., ClayEmbeddings, MajorTOMEmbeddings). Data is stored in GeoParquet files.

  • Pixel-based (RasterDataset): Each pixel contains an embedding vector, stored as multi-band GeoTIFF rasters (e.g., GoogleSatelliteEmbedding, TesseraEmbeddings).

Example usage:

1
2
3
4
5
6
7
>>> import geoai
>>> # List available embedding datasets
>>> geoai.list_embedding_datasets()
>>> # Load a patch-based dataset
>>> ds = geoai.load_embedding_dataset("clay", root="path/to/data.parquet")
>>> # Load a pixel-based dataset
>>> ds = geoai.load_embedding_dataset("google_satellite", paths="path/to/data")

cluster_embeddings(embeddings, n_clusters=10, method='kmeans', random_state=42, **kwargs)

Cluster embedding vectors using unsupervised methods.

Parameters:

Name Type Description Default
embeddings ndarray

Array of shape (N, D).

required
n_clusters int

Number of clusters.

10
method str

Clustering method. One of "kmeans", "spectral", or "dbscan".

'kmeans'
random_state int

Random seed for reproducibility.

42
**kwargs Any

Additional keyword arguments passed to the clustering algorithm.

{}

Returns:

Type Description
Dict[str, Any]

Dictionary with keys:

  • "labels": ndarray of cluster assignments of shape (N,)
  • "model": the fitted clustering model
  • "n_clusters": effective number of clusters found

Raises:

Type Description
ValueError

If an unsupported method is specified.

Example

import geoai result = geoai.cluster_embeddings(embeddings, n_clusters=5) labels = result["labels"]

Source code in geoai/embeddings.py
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
def cluster_embeddings(
    embeddings: np.ndarray,
    n_clusters: int = 10,
    method: str = "kmeans",
    random_state: int = 42,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Cluster embedding vectors using unsupervised methods.

    Args:
        embeddings: Array of shape ``(N, D)``.
        n_clusters: Number of clusters.
        method: Clustering method. One of ``"kmeans"``, ``"spectral"``,
            or ``"dbscan"``.
        random_state: Random seed for reproducibility.
        **kwargs: Additional keyword arguments passed to the clustering
            algorithm.

    Returns:
        Dictionary with keys:

        - ``"labels"``: ndarray of cluster assignments of shape ``(N,)``
        - ``"model"``: the fitted clustering model
        - ``"n_clusters"``: effective number of clusters found

    Raises:
        ValueError: If an unsupported method is specified.

    Example:
        >>> import geoai
        >>> result = geoai.cluster_embeddings(embeddings, n_clusters=5)
        >>> labels = result["labels"]
    """
    method = method.lower()

    if method == "kmeans":
        from sklearn.cluster import KMeans

        model = KMeans(
            n_clusters=n_clusters, random_state=random_state, n_init=10, **kwargs
        )
        labels = model.fit_predict(embeddings)
    elif method == "spectral":
        from sklearn.cluster import SpectralClustering

        model = SpectralClustering(
            n_clusters=n_clusters, random_state=random_state, **kwargs
        )
        labels = model.fit_predict(embeddings)
    elif method == "dbscan":
        from sklearn.cluster import DBSCAN

        # DBSCAN ignores n_clusters: it discovers the cluster count itself,
        # labeling noise points as -1 (excluded from the effective count).
        model = DBSCAN(**kwargs)
        labels = model.fit_predict(embeddings)
        n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    else:
        raise ValueError(
            f"method must be 'kmeans', 'spectral', or 'dbscan', got '{method}'"
        )

    return {
        "labels": labels,
        "model": model,
        # n_clusters was already overwritten above in the dbscan branch, so a
        # plain reference suffices (the previous `x if cond else x` ternary
        # was a no-op).
        "n_clusters": n_clusters,
    }

compare_embeddings(embeddings_a, embeddings_b, metric='cosine')

Compute pairwise similarity between two sets of embeddings.

Useful for change detection between embeddings from different time periods or different sensors.

Parameters:

Name Type Description Default
embeddings_a ndarray

First set of embeddings of shape (N, D).

required
embeddings_b ndarray

Second set of embeddings of shape (N, D). Must have the same number of samples as embeddings_a.

required
metric str

Similarity metric. One of "cosine", "dot", or "euclidean".

'cosine'

Returns:

Type Description
ndarray

Array of shape (N,) with element-wise similarity scores.

Example

import geoai similarity = geoai.compare_embeddings(emb_2020, emb_2024)

Source code in geoai/embeddings.py
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
def compare_embeddings(
    embeddings_a: np.ndarray,
    embeddings_b: np.ndarray,
    metric: str = "cosine",
) -> np.ndarray:
    """Compute element-wise similarity between two aligned sets of embeddings.

    A common use is change detection: compare embeddings of the same
    locations captured at different times or by different sensors.

    Args:
        embeddings_a: First embedding array of shape ``(N, D)``.
        embeddings_b: Second embedding array of shape ``(N, D)``; its shape
            must match ``embeddings_a`` exactly.
        metric: One of ``"cosine"``, ``"dot"``, or ``"euclidean"``.

    Returns:
        Array of shape ``(N,)`` holding one similarity score per row pair.

    Raises:
        ValueError: If the input shapes differ or the metric is unknown.

    Example:
        >>> import geoai
        >>> similarity = geoai.compare_embeddings(emb_2020, emb_2024)
    """
    if embeddings_a.shape != embeddings_b.shape:
        raise ValueError(
            f"Shape mismatch: {embeddings_a.shape} vs {embeddings_b.shape}"
        )

    metric = metric.lower()
    if metric == "dot":
        return np.sum(embeddings_a * embeddings_b, axis=1)
    if metric == "euclidean":
        return np.linalg.norm(embeddings_a - embeddings_b, axis=1)
    if metric == "cosine":
        # Row-wise cosine: inner product scaled by the product of norms.
        inner = np.sum(embeddings_a * embeddings_b, axis=1)
        scale = np.linalg.norm(embeddings_a, axis=1) * np.linalg.norm(
            embeddings_b, axis=1
        )
        # Small epsilon guards against division by zero for all-zero rows.
        return inner / (scale + 1e-8)
    raise ValueError(
        f"metric must be 'cosine', 'dot', or 'euclidean', got '{metric}'"
    )

embedding_similarity(query, embeddings, metric='cosine', top_k=10)

Find the most similar embeddings to a query vector.

Parameters:

Name Type Description Default
query ndarray

Query embedding of shape (D,) or (1, D).

required
embeddings ndarray

Database of embeddings of shape (N, D).

required
metric str

Similarity metric. One of "cosine" or "euclidean".

'cosine'
top_k int

Number of most similar results to return.

10

Returns:

Type Description
Dict[str, ndarray]

Dictionary with keys:

  • "indices": indices of the top-k most similar embeddings
  • "scores": similarity scores (higher is more similar for cosine, lower for euclidean)
Example

import geoai results = geoai.embedding_similarity( ... query=embeddings[0], embeddings=embeddings, top_k=5 ... ) print(results["indices"])

Source code in geoai/embeddings.py
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
def embedding_similarity(
    query: np.ndarray,
    embeddings: np.ndarray,
    metric: str = "cosine",
    top_k: int = 10,
) -> Dict[str, np.ndarray]:
    """Retrieve the embeddings most similar to a query vector.

    Args:
        query: Query embedding of shape ``(D,)`` or ``(1, D)``.
        embeddings: Database of embeddings of shape ``(N, D)``.
        metric: Similarity metric, either ``"cosine"`` or ``"euclidean"``.
        top_k: Number of nearest results to return.

    Returns:
        Dictionary with keys:

        - ``"indices"``: indices of the top-k most similar embeddings
        - ``"scores"``: similarity scores (higher is more similar for
          cosine, lower for euclidean)

    Raises:
        ValueError: If ``metric`` is not ``"cosine"`` or ``"euclidean"``.

    Example:
        >>> import geoai
        >>> results = geoai.embedding_similarity(
        ...     query=embeddings[0], embeddings=embeddings, top_k=5
        ... )
        >>> print(results["indices"])
    """
    # Promote a 1-D query to a single-row matrix for the pairwise helpers.
    query = np.atleast_2d(np.asarray(query))

    metric = metric.lower()
    if metric == "cosine":
        from sklearn.metrics.pairwise import cosine_similarity

        all_scores = cosine_similarity(query, embeddings).ravel()
        # Negate so ascending argsort yields the highest similarities first.
        order = np.argsort(-all_scores)[:top_k]
    elif metric == "euclidean":
        from sklearn.metrics.pairwise import euclidean_distances

        all_scores = euclidean_distances(query, embeddings).ravel()
        # Distances: smallest first.
        order = np.argsort(all_scores)[:top_k]
    else:
        raise ValueError(f"metric must be 'cosine' or 'euclidean', got '{metric}'")

    return {"indices": order, "scores": all_scores[order]}

embedding_to_geotiff(embeddings, bounds, output_path, crs='EPSG:4326')

Save embedding vectors as a multi-band GeoTIFF.

Each embedding dimension becomes a separate band in the output raster.

Parameters:

Name Type Description Default
embeddings ndarray

Array of shape (H, W, D) or (D, H, W).

required
bounds Tuple[float, float, float, float]

Geographic bounds as (west, south, east, north).

required
output_path str

Path to save the GeoTIFF file.

required
crs str

Coordinate reference system string.

'EPSG:4326'

Returns:

Type Description
str

The output file path.

Example

import geoai geoai.embedding_to_geotiff( ... embeddings, bounds=(-122, 37, -121, 38), ... output_path="embeddings.tif" ... )

Source code in geoai/embeddings.py
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
def embedding_to_geotiff(
    embeddings: np.ndarray,
    bounds: Tuple[float, float, float, float],
    output_path: str,
    crs: str = "EPSG:4326",
) -> str:
    """Write an embedding cube to disk as a multi-band GeoTIFF.

    Each of the ``D`` embedding dimensions becomes one raster band.

    Args:
        embeddings: Array of shape ``(H, W, D)`` or ``(D, H, W)``.
        bounds: Geographic bounds as ``(west, south, east, north)``.
        output_path: Destination path for the GeoTIFF file.
        crs: Coordinate reference system string.

    Returns:
        The output file path.

    Raises:
        ValueError: If ``embeddings`` is not a 3-D array.

    Example:
        >>> import geoai
        >>> geoai.embedding_to_geotiff(
        ...     embeddings, bounds=(-122, 37, -121, 38),
        ...     output_path="embeddings.tif"
        ... )
    """
    import rasterio
    from rasterio.transform import from_bounds

    if embeddings.ndim != 3:
        raise ValueError(f"Expected 3D array, got shape {embeddings.shape}")

    # Heuristic layout detection: when the first axis is larger than the
    # last, assume (H, W, D) and move channels to the front.
    # NOTE(review): this misdetects rasters whose height is <= the embedding
    # depth — confirm callers always pass D-major or tall arrays.
    if embeddings.shape[0] > embeddings.shape[2]:
        embeddings = np.transpose(embeddings, (2, 0, 1))

    n_bands, n_rows, n_cols = embeddings.shape
    west, south, east, north = bounds
    geo_transform = from_bounds(west, south, east, north, n_cols, n_rows)

    # Create the parent directory if needed ("." when path has no directory).
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    profile = {
        "driver": "GTiff",
        "height": n_rows,
        "width": n_cols,
        "count": n_bands,
        "dtype": embeddings.dtype,
        "crs": crs,
        "transform": geo_transform,
    }
    with rasterio.open(output_path, "w", **profile) as dst:
        dst.write(embeddings)

    logger.info(f"Saved {n_bands}-band GeoTIFF to {output_path}")
    return output_path

extract_patch_embeddings(dataset, max_samples=None, device=None)

Extract embeddings, coordinates, and timestamps from a patch-based dataset.

Iterates over a patch-based (NonGeoDataset) embedding dataset and collects all embedding vectors along with their spatial and temporal metadata into NumPy arrays.

Parameters:

Name Type Description Default
dataset Any

A patch-based embedding dataset (e.g., ClayEmbeddings).

required
max_samples Optional[int]

Maximum number of samples to extract. If None, extract all samples.

None
device Optional[str]

Device for tensor operations (unused, kept for API symmetry).

None

Returns:

Type Description
Dict[str, ndarray]

Dictionary with keys:

  • "embeddings": ndarray of shape (N, D)
  • "x": ndarray of shape (N,) with longitudes
  • "y": ndarray of shape (N,) with latitudes
  • "t": ndarray of shape (N,) with timestamps (if available)
Example

import geoai ds = geoai.load_embedding_dataset("clay", root="data.parquet") data = geoai.extract_patch_embeddings(ds, max_samples=1000) print(data["embeddings"].shape)

Source code in geoai/embeddings.py
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
def extract_patch_embeddings(
    dataset: Any,
    max_samples: Optional[int] = None,
    device: Optional[str] = None,
) -> Dict[str, np.ndarray]:
    """Collect embeddings and metadata from a patch-based dataset into arrays.

    Walks a patch-based (NonGeoDataset) embedding dataset sample by sample
    and gathers the embedding vectors plus any spatial/temporal metadata
    present into NumPy arrays.

    Args:
        dataset: A patch-based embedding dataset (e.g., ClayEmbeddings).
        max_samples: Cap on the number of samples to read; ``None`` reads
            the entire dataset.
        device: Device for tensor operations (unused, kept for API symmetry).

    Returns:
        Dictionary with keys:

        - ``"embeddings"``: ndarray of shape ``(N, D)``
        - ``"x"``: ndarray of shape ``(N,)`` with longitudes
        - ``"y"``: ndarray of shape ``(N,)`` with latitudes
        - ``"t"``: ndarray of shape ``(N,)`` with timestamps (if available)

    Example:
        >>> import geoai
        >>> ds = geoai.load_embedding_dataset("clay", root="data.parquet")
        >>> data = geoai.extract_patch_embeddings(ds, max_samples=1000)
        >>> print(data["embeddings"].shape)
    """
    total = len(dataset)
    if max_samples is not None:
        total = min(total, max_samples)

    vectors = []
    # Metadata lists, keyed by the optional per-sample fields.
    meta = {"x": [], "y": [], "t": []}

    for idx in range(total):
        sample = dataset[idx]
        vectors.append(sample["embedding"].numpy())
        for key in ("x", "y", "t"):
            if key in sample:
                meta[key].append(sample[key].item())

    output: Dict[str, np.ndarray] = {"embeddings": np.stack(vectors)}
    # Only include metadata keys that were actually present in the samples.
    for key in ("x", "y", "t"):
        if meta[key]:
            output[key] = np.array(meta[key])

    return output

extract_pixel_embeddings(dataset, sampler=None, num_samples=100, size=256, flatten=True)

Extract embeddings from a pixel-based (RasterDataset) embedding dataset.

Uses a TorchGeo sampler to draw spatial patches and returns the embedding tensors. If flatten=True, pixels are reshaped to (N_total_pixels, D).

Parameters:

Name Type Description Default
dataset Any

A pixel-based embedding dataset (e.g., GoogleSatelliteEmbedding).

required
sampler Any

A torchgeo sampler instance. If None, a RandomGeoSampler is created with the given num_samples and size.

None
num_samples int

Number of random samples to draw (used only when sampler is None).

100
size float

Patch size in dataset CRS units (used only when sampler is None).

256
flatten bool

If True, flatten spatial dimensions so that the result has shape (N, D) where N is the total number of pixels.

True

Returns:

Type Description
Dict[str, Any]

Dictionary with keys:

  • "embeddings": ndarray of shape (N, D) if flattened, or list of arrays of shape (C, H, W)
  • "bounds": list of bounding box tensors for each sample
Example

import geoai ds = geoai.load_embedding_dataset( ... "google_satellite", paths="data/" ... ) data = geoai.extract_pixel_embeddings(ds, num_samples=50) print(data["embeddings"].shape)

Source code in geoai/embeddings.py
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
def extract_pixel_embeddings(
    dataset: Any,
    sampler: Any = None,
    num_samples: int = 100,
    size: float = 256,
    flatten: bool = True,
) -> Dict[str, Any]:
    """Extract embeddings from a pixel-based (RasterDataset) embedding dataset.

    Uses a TorchGeo sampler to draw spatial patches and returns the
    embedding tensors. If ``flatten=True``, pixels are reshaped to
    ``(N_total_pixels, D)``.

    Args:
        dataset: A pixel-based embedding dataset (e.g., GoogleSatelliteEmbedding).
        sampler: A torchgeo sampler instance. If None, a RandomGeoSampler is
            created with the given ``num_samples`` and ``size``.
        num_samples: Number of random samples to draw (used only when
            ``sampler`` is None).
        size: Patch size in dataset CRS units (used only when ``sampler``
            is None).
        flatten: If True, flatten spatial dimensions so that the result
            has shape ``(N, D)`` where N is the total number of pixels.

    Returns:
        Dictionary with keys:

        - ``"embeddings"``: ndarray of shape ``(N, D)`` if flattened, or
          list of arrays of shape ``(C, H, W)``
        - ``"bounds"``: list of bounding box tensors for each sample

    Example:
        >>> import geoai
        >>> ds = geoai.load_embedding_dataset(
        ...     "google_satellite", paths="data/"
        ... )
        >>> data = geoai.extract_pixel_embeddings(ds, num_samples=50)
        >>> print(data["embeddings"].shape)
    """
    if sampler is None:
        # Import lazily so that callers who supply their own sampler do not
        # need torchgeo installed (previously this import ran unconditionally).
        from torchgeo.samplers import RandomGeoSampler

        sampler = RandomGeoSampler(dataset, size=size, length=num_samples)

    all_embeddings = []
    all_bounds = []

    for query in sampler:
        sample = dataset[query]
        image = sample["image"]  # (C, H, W)
        all_bounds.append(sample.get("bounds"))

        if flatten:
            # (C, H, W) -> (H*W, C): one row per pixel, one column per band.
            c, h, w = image.shape
            pixels = image.permute(1, 2, 0).reshape(-1, c)
            all_embeddings.append(pixels.numpy())
        else:
            all_embeddings.append(image.numpy())

    result: Dict[str, Any] = {"bounds": all_bounds}
    if flatten:
        result["embeddings"] = np.concatenate(all_embeddings, axis=0)
    else:
        result["embeddings"] = all_embeddings

    return result

get_embedding_info(name)

Get detailed information about an embedding dataset.

Parameters:

Name Type Description Default
name str

Registry name of the dataset.

required

Returns:

Type Description
Dict[str, Any]

Dictionary with dataset metadata.

Raises:

Type Description
ValueError

If name is not found in the registry.

Example

import geoai info = geoai.get_embedding_info("google_satellite") print(info["description"])

Source code in geoai/embeddings.py
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
def get_embedding_info(name: str) -> Dict[str, Any]:
    """Return the registry metadata for a named embedding dataset.

    Args:
        name: Registry name of the dataset.

    Returns:
        A copy of the dataset's metadata dictionary.

    Raises:
        ValueError: If ``name`` is not found in the registry.

    Example:
        >>> import geoai
        >>> info = geoai.get_embedding_info("google_satellite")
        >>> print(info["description"])
    """
    # Return a copy so callers cannot mutate the shared registry entry.
    if name in EMBEDDING_DATASETS:
        return EMBEDDING_DATASETS[name].copy()
    available = ", ".join(sorted(EMBEDDING_DATASETS.keys()))
    raise ValueError(f"Unknown embedding dataset '{name}'. Available: {available}")

list_embedding_datasets(kind=None, as_dataframe=True, verbose=True)

List all available embedding datasets from TorchGeo v0.9.0.

Parameters:

Name Type Description Default
kind Optional[str]

Filter by dataset kind. One of "patch" (NonGeoDataset) or "pixel" (RasterDataset). If None, list all datasets.

None
as_dataframe bool

If True, return a pandas DataFrame. Otherwise return a dictionary.

True
verbose bool

If True, print a summary table to the console.

True

Returns:

Type Description
Union[DataFrame, Dict[str, Dict[str, Any]]]

A DataFrame or dictionary describing the available datasets.

Example

import geoai df = geoai.list_embedding_datasets() df = geoai.list_embedding_datasets(kind="patch")

Source code in geoai/embeddings.py
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
def list_embedding_datasets(
    kind: Optional[str] = None,
    as_dataframe: bool = True,
    verbose: bool = True,
) -> Union[pd.DataFrame, Dict[str, Dict[str, Any]]]:
    """List all available embedding datasets from TorchGeo v0.9.0.

    Args:
        kind: Optional filter: ``"patch"`` (NonGeoDataset) or ``"pixel"``
            (RasterDataset). ``None`` lists everything.
        as_dataframe: Return a pandas DataFrame when True, otherwise the
            raw registry dictionary.
        verbose: Print a summary table to the console when True.

    Returns:
        A DataFrame or dictionary describing the available datasets.

    Raises:
        ValueError: If ``kind`` is neither ``"patch"`` nor ``"pixel"``.

    Example:
        >>> import geoai
        >>> df = geoai.list_embedding_datasets()
        >>> df = geoai.list_embedding_datasets(kind="patch")
    """
    selected = EMBEDDING_DATASETS
    if kind is not None:
        kind = kind.lower()
        if kind not in ("patch", "pixel"):
            raise ValueError(f"kind must be 'patch' or 'pixel', got '{kind}'")
        selected = {
            name: meta for name, meta in selected.items() if meta["kind"] == kind
        }

    if not as_dataframe:
        return selected

    # Map output column name -> registry metadata key, in display order.
    columns = (
        ("class", "class_name"),
        ("kind", "kind"),
        ("spatial_extent", "spatial_extent"),
        ("resolution", "spatial_resolution"),
        ("temporal_extent", "temporal_extent"),
        ("dimensions", "dimensions"),
        ("dtype", "dtype"),
        ("license", "license"),
    )
    rows = [
        {"name": name, **{col: meta[key] for col, key in columns}}
        for name, meta in selected.items()
    ]
    df = pd.DataFrame(rows)
    if verbose:
        print(df.to_string(index=False))
    return df

load_embedding_dataset(name, root=None, paths=None, transforms=None, **kwargs)

Load an embedding dataset by name.

This is a unified factory function that instantiates the correct torchgeo embedding dataset class based on the name.

For patch-based datasets (kind="patch"), pass the root parameter pointing to a GeoParquet file. For pixel-based datasets (kind="pixel"), pass the paths parameter pointing to a directory of GeoTIFF files.

Parameters:

Name Type Description Default
name str

Registry name of the dataset (e.g., "clay", "google_satellite"). Use :func:list_embedding_datasets to see all available names.

required
root Optional[str]

Root directory or file path (used by patch-based datasets).

None
paths Optional[Union[str, List[str]]]

One or more directories containing GeoTIFF files (used by pixel-based datasets).

None
transforms Optional[Callable]

Optional transform function applied to each sample.

None
**kwargs Any

Additional keyword arguments passed to the dataset constructor (e.g., crs, res, bands, cache, download, time_series).

{}

Returns:

Type Description
Any

A torchgeo dataset instance.

Raises:

Type Description
ValueError

If required arguments are missing or name is unknown.

Example

import geoai

Patch-based dataset

ds = geoai.load_embedding_dataset( ... "clay", root="path/to/clay_embeddings.parquet" ... ) sample = ds[0] print(sample["embedding"].shape)

Pixel-based dataset

ds = geoai.load_embedding_dataset( ... "google_satellite", paths="path/to/geotiffs/" ... )

Source code in geoai/embeddings.py
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
def load_embedding_dataset(
    name: str,
    root: Optional[str] = None,
    paths: Optional[Union[str, List[str]]] = None,
    transforms: Optional[Callable] = None,
    **kwargs: Any,
) -> Any:
    """Load an embedding dataset by name.

    A unified factory that instantiates the correct torchgeo embedding
    dataset class based on the registry name.

    For **patch-based** datasets (``kind="patch"``), pass ``root`` pointing
    to a GeoParquet file. For **pixel-based** datasets (``kind="pixel"``),
    pass ``paths`` pointing to a directory of GeoTIFF files (``root`` is
    accepted as a convenience alias).

    Args:
        name: Registry name of the dataset (e.g., ``"clay"``,
            ``"google_satellite"``). Use :func:`list_embedding_datasets` to
            see all available names.
        root: Root directory or file path (used by patch-based datasets).
        paths: One or more directories containing GeoTIFF files (used by
            pixel-based datasets).
        transforms: Optional transform function applied to each sample.
        **kwargs: Additional keyword arguments passed to the dataset
            constructor (e.g., ``crs``, ``res``, ``bands``, ``cache``,
            ``download``, ``time_series``).

    Returns:
        A torchgeo dataset instance.

    Raises:
        ValueError: If required arguments are missing or name is unknown.

    Example:
        >>> import geoai
        >>> # Patch-based dataset
        >>> ds = geoai.load_embedding_dataset(
        ...     "clay", root="path/to/clay_embeddings.parquet"
        ... )
        >>> sample = ds[0]
        >>> print(sample["embedding"].shape)
        >>> # Pixel-based dataset
        >>> ds = geoai.load_embedding_dataset(
        ...     "google_satellite", paths="path/to/geotiffs/"
        ... )
    """
    if name not in EMBEDDING_DATASETS:
        available = ", ".join(sorted(EMBEDDING_DATASETS.keys()))
        raise ValueError(f"Unknown embedding dataset '{name}'. Available: {available}")

    entry = EMBEDDING_DATASETS[name]
    dataset_cls = _get_dataset_class(name)

    if entry["kind"] == "patch":
        if root is None:
            raise ValueError(
                f"Patch-based dataset '{name}' requires the 'root' parameter "
                "pointing to a GeoParquet file."
            )
        return cls_patch_instance(dataset_cls, root, transforms, kwargs)

    # Pixel-based: fall back to 'root' when 'paths' was not given.
    target = paths if paths is not None else root
    if target is None:
        raise ValueError(
            f"Pixel-based dataset '{name}' requires the 'paths' parameter "
            "pointing to a directory of GeoTIFF files."
        )
    return dataset_cls(paths=target, transforms=transforms, **kwargs)

plot_embedding_raster(image, method='pca', figsize=(8, 8), title='Embedding Visualization', save_path=None)

Visualize a pixel-based embedding raster using PCA to create an RGB image.

Projects high-dimensional embedding bands into 3 principal components for RGB visualization.

Parameters:

Name Type Description Default
image Union[ndarray, Tensor]

Embedding tensor of shape (C, H, W) or (H, W, C).

required
method str

Projection method. Currently only "pca" is supported.

'pca'
figsize Tuple[int, int]

Figure size.

(8, 8)
title Optional[str]

Plot title.

'Embedding Visualization'
save_path Optional[str]

If provided, save the figure to this path.

None

Returns:

Type Description
Figure

A matplotlib Figure.

Example

import geoai ds = geoai.load_embedding_dataset( ... "google_satellite", paths="data/" ... ) from torchgeo.samplers import RandomGeoSampler sampler = RandomGeoSampler(ds, size=256, length=1) sample = ds[next(iter(sampler))] fig = geoai.plot_embedding_raster(sample["image"])

Source code in geoai/embeddings.py
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
def plot_embedding_raster(
    image: Union[np.ndarray, torch.Tensor],
    method: str = "pca",
    figsize: Tuple[int, int] = (8, 8),
    title: Optional[str] = "Embedding Visualization",
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Visualize a pixel-based embedding raster using PCA to create an RGB image.

    Projects high-dimensional embedding bands into 3 principal components
    for RGB visualization.

    Args:
        image: Embedding tensor of shape ``(C, H, W)`` or ``(H, W, C)``.
        method: Projection method. Currently only ``"pca"`` is supported.
        figsize: Figure size.
        title: Plot title.
        save_path: If provided, save the figure to this path.

    Returns:
        A matplotlib Figure.

    Raises:
        ValueError: If ``method`` is not ``"pca"``.

    Example:
        >>> import geoai
        >>> ds = geoai.load_embedding_dataset(
        ...     "google_satellite", paths="data/"
        ... )
        >>> from torchgeo.samplers import RandomGeoSampler
        >>> sampler = RandomGeoSampler(ds, size=256, length=1)
        >>> sample = ds[next(iter(sampler))]
        >>> fig = geoai.plot_embedding_raster(sample["image"])
    """
    # Previously `method` was accepted but silently ignored; reject
    # unsupported values explicitly, matching the module's other validators.
    if method.lower() != "pca":
        raise ValueError(f"method must be 'pca', got '{method}'")

    if isinstance(image, torch.Tensor):
        image = image.numpy()

    # Ensure (C, H, W) format. Heuristic: if the last axis is smaller than
    # the first, assume (H, W, C) — valid when the channel count exceeds
    # the spatial height.
    if image.ndim == 3 and image.shape[2] < image.shape[0]:
        image = np.transpose(image, (2, 0, 1))

    c, h, w = image.shape
    # Reshape to (H*W, C): one PCA sample per pixel.
    pixels = image.reshape(c, -1).T

    from sklearn.decomposition import PCA

    pca = PCA(n_components=3)
    rgb = pca.fit_transform(pixels)

    # Min-max normalize each component independently to [0, 1].
    rgb -= rgb.min(axis=0, keepdims=True)
    maxvals = rgb.max(axis=0, keepdims=True)
    maxvals[maxvals == 0] = 1  # guard against constant (zero-range) channels
    rgb /= maxvals
    rgb = rgb.reshape(h, w, 3)

    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(rgb)
    ax.axis("off")
    if title:
        ax.set_title(title)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig

plot_embedding_vector(embedding, title='Embedding Vector', figsize=(10, 3), save_path=None)

Plot a single embedding vector as a line chart.

Parameters:

Name Type Description Default
embedding Union[ndarray, Tensor]

1D array or tensor of shape (D,).

required
title Optional[str]

Plot title.

'Embedding Vector'
figsize Tuple[int, int]

Figure size.

(10, 3)
save_path Optional[str]

If provided, save the figure to this path.

None

Returns:

Type Description
Figure

A matplotlib Figure.

Example

>>> import geoai
>>> ds = geoai.load_embedding_dataset("clay", root="data.parquet")
>>> sample = ds[0]
>>> fig = geoai.plot_embedding_vector(sample["embedding"])

Source code in geoai/embeddings.py
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
def plot_embedding_vector(
    embedding: Union[np.ndarray, torch.Tensor],
    title: Optional[str] = "Embedding Vector",
    figsize: Tuple[int, int] = (10, 3),
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Plot a single embedding vector as a line chart.

    Args:
        embedding: 1D array or tensor of shape ``(D,)``.
        title: Plot title. Pass ``None`` or ``""`` to omit the title.
        figsize: Figure size in inches as ``(width, height)``.
        save_path: If provided, save the figure to this path.

    Returns:
        A matplotlib Figure.

    Example:
        >>> import geoai
        >>> ds = geoai.load_embedding_dataset("clay", root="data.parquet")
        >>> sample = ds[0]
        >>> fig = geoai.plot_embedding_vector(sample["embedding"])
    """
    if isinstance(embedding, torch.Tensor):
        # detach() + cpu() so tensors that live on a GPU or carry gradients
        # convert cleanly; plain .numpy() raises for both cases.
        embedding = embedding.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(embedding)
    ax.set_xlabel("Dimension")
    ax.set_ylabel("Value")
    if title:
        ax.set_title(title)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig

train_embedding_classifier(train_embeddings, train_labels, val_embeddings=None, val_labels=None, method='knn', label_names=None, verbose=True, **kwargs)

Train a lightweight classifier on pre-computed embeddings.

This function trains a simple classifier (k-NN, Random Forest, or Logistic Regression) on embedding vectors, providing a quick baseline without requiring GPU fine-tuning.

Parameters:

Name Type Description Default
train_embeddings ndarray

Training embeddings of shape (N_train, D).

required
train_labels ndarray

Training labels of shape (N_train,).

required
val_embeddings Optional[ndarray]

Optional validation embeddings of shape (N_val, D).

None
val_labels Optional[ndarray]

Optional validation labels of shape (N_val,).

None
method str

Classifier type. One of "knn", "random_forest", or "logistic_regression".

'knn'
label_names Optional[List[str]]

Optional list of human-readable class names.

None
verbose bool

If True, print classification report for validation set.

True
**kwargs Any

Additional keyword arguments passed to the classifier (e.g., n_neighbors for k-NN).

{}

Returns:

Type Description
Dict[str, Any]

Dictionary with keys:

Dict[str, Any]
  • "model": the fitted classifier
Dict[str, Any]
  • "train_accuracy": training accuracy
Dict[str, Any]
  • "val_accuracy": validation accuracy (if val data provided)
Dict[str, Any]
  • "val_predictions": predictions on validation set
Dict[str, Any]
  • "classification_report": string report (if val data provided)
Example

>>> import geoai
>>> result = geoai.train_embedding_classifier(
...     train_embeddings, train_labels,
...     val_embeddings, val_labels,
...     method="knn", n_neighbors=5,
... )
>>> print(f"Validation accuracy: {result['val_accuracy']:.2%}")

Source code in geoai/embeddings.py
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
def train_embedding_classifier(
    train_embeddings: np.ndarray,
    train_labels: np.ndarray,
    val_embeddings: Optional[np.ndarray] = None,
    val_labels: Optional[np.ndarray] = None,
    method: str = "knn",
    label_names: Optional[List[str]] = None,
    verbose: bool = True,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Train a lightweight classifier on pre-computed embeddings.

    Fits a simple scikit-learn classifier (k-NN, Random Forest, or
    Logistic Regression) directly on embedding vectors — a quick
    baseline that needs no GPU fine-tuning.

    Args:
        train_embeddings: Training embeddings of shape ``(N_train, D)``.
        train_labels: Training labels of shape ``(N_train,)``.
        val_embeddings: Optional validation embeddings of shape ``(N_val, D)``.
        val_labels: Optional validation labels of shape ``(N_val,)``.
        method: Classifier type. One of ``"knn"``, ``"random_forest"``,
            or ``"logistic_regression"``.
        label_names: Optional list of human-readable class names.
        verbose: If True, print classification report for validation set.
        **kwargs: Additional keyword arguments passed to the classifier
            (e.g., ``n_neighbors`` for k-NN).

    Returns:
        Dictionary with keys:

        - ``"model"``: the fitted classifier
        - ``"train_accuracy"``: training accuracy
        - ``"val_accuracy"``: validation accuracy (if val data provided)
        - ``"val_predictions"``: predictions on validation set
        - ``"classification_report"``: string report (if val data provided)

    Raises:
        ValueError: If ``method`` is not one of the supported names.

    Example:
        >>> import geoai
        >>> result = geoai.train_embedding_classifier(
        ...     train_embeddings, train_labels,
        ...     val_embeddings, val_labels,
        ...     method="knn", n_neighbors=5,
        ... )
        >>> print(f"Validation accuracy: {result['val_accuracy']:.2%}")
    """
    method = method.lower()

    # Build the requested classifier. sklearn imports are deferred so the
    # module stays importable without scikit-learn installed.
    if method == "knn":
        from sklearn.neighbors import KNeighborsClassifier

        clf = KNeighborsClassifier(
            n_neighbors=kwargs.pop("n_neighbors", 5), **kwargs
        )
    elif method == "random_forest":
        from sklearn.ensemble import RandomForestClassifier

        clf = RandomForestClassifier(
            n_estimators=kwargs.pop("n_estimators", 100),
            random_state=42,
            **kwargs,
        )
    elif method == "logistic_regression":
        from sklearn.linear_model import LogisticRegression

        clf = LogisticRegression(max_iter=1000, random_state=42, **kwargs)
    else:
        raise ValueError(
            f"method must be 'knn', 'random_forest', or "
            f"'logistic_regression', got '{method}'"
        )

    clf.fit(train_embeddings, train_labels)

    output: Dict[str, Any] = {
        "model": clf,
        "train_accuracy": clf.score(train_embeddings, train_labels),
    }

    if val_embeddings is not None and val_labels is not None:
        from sklearn.metrics import classification_report

        predictions = clf.predict(val_embeddings)
        output["val_predictions"] = predictions
        output["val_accuracy"] = clf.score(val_embeddings, val_labels)
        output["classification_report"] = classification_report(
            val_labels,
            predictions,
            target_names=label_names,
            digits=3,
            zero_division=0,
        )

        if verbose:
            print(output["classification_report"])

    return output

visualize_embeddings(embeddings, labels=None, label_names=None, method='pca', n_components=2, figsize=(8, 8), cmap='tab10', alpha=0.6, s=5, title=None, save_path=None, **kwargs)

Visualize high-dimensional embeddings in 2D using dimensionality reduction.

Supports PCA, t-SNE, and UMAP for projecting embedding vectors into a 2D scatter plot.

Parameters:

Name Type Description Default
embeddings ndarray

Array of shape (N, D) containing embedding vectors.

required
labels Optional[ndarray]

Optional integer labels of shape (N,) for coloring points.

None
label_names Optional[List[str]]

Optional list of label names for the legend.

None
method str

Dimensionality reduction method. One of "pca", "tsne", or "umap".

'pca'
n_components int

Number of components for the projection (2 or 3).

2
figsize Tuple[int, int]

Figure size in inches.

(8, 8)
cmap str

Matplotlib colormap name.

'tab10'
alpha float

Point transparency.

0.6
s int

Point size.

5
title Optional[str]

Plot title. If None, an automatic title is generated.

None
save_path Optional[str]

If provided, save the figure to this path.

None
**kwargs Any

Additional keyword arguments passed to the reducer (e.g., perplexity for t-SNE, n_neighbors for UMAP).

{}

Returns:

Type Description
Figure

A matplotlib Figure.

Raises:

Type Description
ValueError

If an unsupported method is specified.

ImportError

If required libraries are not installed.

Example

>>> import geoai
>>> fig = geoai.visualize_embeddings(
...     embeddings, labels=labels, method="pca"
... )

Source code in geoai/embeddings.py
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
def visualize_embeddings(
    embeddings: np.ndarray,
    labels: Optional[np.ndarray] = None,
    label_names: Optional[List[str]] = None,
    method: str = "pca",
    n_components: int = 2,
    figsize: Tuple[int, int] = (8, 8),
    cmap: str = "tab10",
    alpha: float = 0.6,
    s: int = 5,
    title: Optional[str] = None,
    save_path: Optional[str] = None,
    **kwargs: Any,
) -> plt.Figure:
    """Visualize high-dimensional embeddings in 2D using dimensionality reduction.

    Supports PCA, t-SNE, and UMAP for projecting embedding vectors into a
    2D scatter plot.

    Args:
        embeddings: Array of shape ``(N, D)`` containing embedding vectors.
        labels: Optional integer labels of shape ``(N,)`` for coloring points.
        label_names: Optional list of label names for the legend.
        method: Dimensionality reduction method. One of ``"pca"``, ``"tsne"``,
            or ``"umap"``.
        n_components: Number of components for the projection (2 or 3).
            Only the first two components are plotted.
        figsize: Figure size in inches.
        cmap: Matplotlib colormap name.
        alpha: Point transparency.
        s: Point size.
        title: Plot title. If None, an automatic title is generated.
        save_path: If provided, save the figure to this path.
        **kwargs: Additional keyword arguments passed to the reducer
            (e.g., ``perplexity`` for t-SNE, ``n_neighbors`` for UMAP).

    Returns:
        A matplotlib Figure.

    Raises:
        ValueError: If an unsupported method is specified.
        ImportError: If required libraries are not installed.

    Example:
        >>> import geoai
        >>> fig = geoai.visualize_embeddings(
        ...     embeddings, labels=labels, method="pca"
        ... )
    """
    method = method.lower()
    if method not in ("pca", "tsne", "umap"):
        raise ValueError(f"method must be 'pca', 'tsne', or 'umap', got '{method}'")

    # Dimensionality reduction
    if method == "pca":
        from sklearn.decomposition import PCA

        reducer = PCA(n_components=n_components, whiten=True, **kwargs)
        reduced = reducer.fit_transform(embeddings)
        explained = reducer.explained_variance_ratio_.sum()
        default_title = f"PCA of Embeddings (explained variance: {explained:.1%})"
    elif method == "tsne":
        from sklearn.manifold import TSNE

        tsne_kwargs = {"n_components": n_components, "random_state": 42}
        tsne_kwargs.update(kwargs)
        reducer = TSNE(**tsne_kwargs)
        reduced = reducer.fit_transform(embeddings)
        default_title = "t-SNE of Embeddings"
    else:  # umap
        # Keep the try narrow: only a missing umap-learn package should be
        # reported as "install umap-learn". An ImportError raised deeper
        # inside umap's own fit_transform must propagate unchanged.
        try:
            import umap
        except ImportError as err:
            raise ImportError(
                "umap-learn is required for UMAP visualization. "
                "Install with: pip install umap-learn"
            ) from err

        umap_kwargs = {"n_components": n_components, "random_state": 42}
        umap_kwargs.update(kwargs)
        reducer = umap.UMAP(**umap_kwargs)
        reduced = reducer.fit_transform(embeddings)
        default_title = "UMAP of Embeddings"

    title = title or default_title

    fig, ax = plt.subplots(figsize=figsize)
    scatter_kwargs = dict(s=s, alpha=alpha)
    if labels is not None:
        scatter_kwargs["c"] = labels
        scatter_kwargs["cmap"] = cmap
    scatter = ax.scatter(
        reduced[:, 0],
        reduced[:, 1],
        **scatter_kwargs,
    )

    if labels is not None:
        # legend_elements() is called once; the previous code invoked it a
        # second time on the default path and discarded the first result.
        handles, auto_labels = scatter.legend_elements()
        legend_labels = label_names if label_names is not None else auto_labels
        ax.legend(handles, legend_labels, title="Classes", loc="best")

    ax.set_title(title)
    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info(f"Figure saved to {save_path}")

    return fig