Skip to content

map_widgets module

Interactive widget for GeoAI.

DINOv3GUI

Bases: VBox

Interactive widget for DINOv3.

Source code in geoai/map_widgets.py
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
class DINOv3GUI(widgets.VBox):
    """Interactive widget for DINOv3.

    Builds a control panel (interpolation toggle, threshold and opacity
    sliders, colormap picker, layer-name box, Save/Reset buttons and an
    output area). When a host map is supplied, the panel is attached to the
    map and each map click runs ``processor.compute_similarity`` for the
    clicked location, adding the result to the map as raster layers.
    """

    def __init__(
        self,
        raster: str,
        processor=None,
        features=None,
        host_map=None,
        position="topright",
        colormap_options=None,
        raster_args=None,
    ):
        """Initialize the DINOv3 GUI.

        Args:
            raster (str): The path to the raster image.
            processor (DINOv3GeoProcessor): The DINOv3 processor. Required
                for click queries; without it a click only reports an error
                in the widget's output area.
            features (torch.Tensor): The features of the raster image.
            host_map (Map): The host map. When None, the controls are built
                but no map interaction is registered.
            position (str): The position of the widget on the host map.
            colormap_options (list): The colormap options. Defaults to
                jet, viridis, plasma, inferno, magma and cividis.
            raster_args (dict): Extra keyword arguments forwarded to
                ``host_map.add_raster`` for the input raster. A
                ``layer_name`` of "Raster" is used when none is given. The
                dict passed in is not modified.

        Returns:
            None

        Example:
            >>> processor = DINOv3GeoProcessor()
            >>> features, h_patches, w_patches = processor.extract_features(raster)
            >>> gui = DINOv3GUI(raster, processor, features, host_map=m)
        """
        # Initialize the underlying VBox; without this the instance is not a
        # fully constructed widget.
        super().__init__()

        # Copy so the default layer name is never written into the caller's
        # dictionary.
        raster_args = dict(raster_args) if raster_args is not None else {}
        raster_args.setdefault("layer_name", "Raster")

        if colormap_options is None:
            colormap_options = [
                "jet",
                "viridis",
                "plasma",
                "inferno",
                "magma",
                "cividis",
            ]

        main_widget = widgets.VBox(layout=widgets.Layout(width="230px"))
        style = {"description_width": "initial"}
        layout = widgets.Layout(width="95%", padding="0px 5px 0px 5px")

        interpolation_checkbox = widgets.Checkbox(
            value=True,
            description="Use interpolation",
            style=style,
            layout=layout,
        )

        # Similarity values strictly above this threshold form the
        # "Foreground" layer.
        threshold_slider = widgets.FloatSlider(
            value=0.7,
            min=0,
            max=1,
            step=0.01,
            description="Threshold",
            style=style,
            layout=layout,
        )

        opacity_slider = widgets.FloatSlider(
            value=0.5,
            min=0,
            max=1,
            step=0.01,
            description="Opacity",
            style=style,
            layout=layout,
        )
        colormap_dropdown = widgets.Dropdown(
            options=colormap_options,
            value="jet",
            description="Colormap",
            style=style,
            layout=layout,
        )
        layer_name_input = widgets.Text(
            value="Similarity",
            description="Layer name",
            style=style,
            layout=layout,
        )

        # NOTE(review): no click handlers are attached to the Save and Reset
        # buttons in this block, so they are currently inert.
        save_button = widgets.Button(
            description="Save",
        )

        reset_button = widgets.Button(
            description="Reset",
        )

        # Receives output printed while handling a click, including errors.
        output = widgets.Output()

        main_widget.children = [
            interpolation_checkbox,
            threshold_slider,
            opacity_slider,
            colormap_dropdown,
            layer_name_input,
            widgets.HBox([save_button, reset_button]),
            output,
        ]

        # Make the GUI object itself hold the control panel so that
        # displaying the instance shows the controls.
        self.children = [main_widget]

        if host_map is not None:

            host_map.add_widget(main_widget, add_header=True, position=position)

            if raster is not None:
                host_map.add_raster(raster, **raster_args)

            def handle_map_interaction(**kwargs):
                """Run a similarity query for a map click and show the result."""
                if kwargs.get("type") != "click":
                    return
                try:
                    latlon = kwargs.get("coordinates")
                    with output:
                        output.clear_output()

                        if processor is None:
                            raise ValueError(
                                "A DINOv3 processor is required to compute "
                                "similarity."
                            )

                        results = processor.compute_similarity(
                            source=raster,
                            features=features,
                            # The click coordinates are passed in reversed
                            # order (presumably (lat, lon) -> (lon, lat) --
                            # confirm against the processor's contract).
                            query_coords=latlon[::-1],
                            output_dir="dinov3_results",
                            use_interpolation=interpolation_checkbox.value,
                            coord_crs="EPSG:4326",
                        )
                        array = results["image_dict"]["image"]
                        binary_array = array > threshold_slider.value
                        image = dict_to_image(results["image_dict"])
                        # The mask reuses the CRS and bounds of the
                        # similarity image.
                        binary_image = dict_to_image(
                            {
                                "image": binary_array,
                                "crs": results["image_dict"]["crs"],
                                "bounds": results["image_dict"]["bounds"],
                            }
                        )
                        host_map.add_raster(
                            image,
                            colormap=colormap_dropdown.value,
                            opacity=opacity_slider.value,
                            layer_name=layer_name_input.value,
                            zoom_to_layer=False,
                            overwrite=True,
                        )
                        # The thresholded mask is added as a hidden layer.
                        host_map.add_raster(
                            binary_image,
                            colormap="jet",
                            nodata=0,
                            opacity=opacity_slider.value,
                            layer_name="Foreground",
                            zoom_to_layer=False,
                            overwrite=True,
                            visible=False,
                        )
                except Exception as e:
                    # UI callback boundary: report the failure in the widget
                    # instead of raising out of the map event handler.
                    with output:
                        print(e)

            host_map.on_interaction(handle_map_interaction)
            host_map.default_style = {"cursor": "crosshair"}

__init__(raster, processor=None, features=None, host_map=None, position='topright', colormap_options=None, raster_args=None)

Initialize the DINOv3 GUI.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| raster | str | The path to the raster image. | required |
| processor | DINOv3GeoProcessor | The DINOv3 processor. | None |
| features | Tensor | The features of the raster image. | None |
| host_map | Map | The host map. | None |
| position | str | The position of the widget. | 'topright' |
| colormap_options | list | The colormap options. | None |
| raster_args | dict | The raster arguments. | None |

Returns:

Type Description

None

Example

>>> processor = DINOv3GeoProcessor()
>>> features, h_patches, w_patches = processor.extract_features(raster)
>>> gui = DINOv3GUI(raster, processor, features, host_map=m)

Source code in geoai/map_widgets.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def __init__(
    self,
    raster: str,
    processor=None,
    features=None,
    host_map=None,
    position="topright",
    colormap_options=None,
    raster_args=None,
):
    """Initialize the DINOv3 GUI.

    Builds the control panel and, when a host map is supplied, attaches it to
    the map and registers a click handler that runs
    ``processor.compute_similarity`` for the clicked location.

    Args:
        raster (str): The path to the raster image.
        processor (DINOv3GeoProcessor): The DINOv3 processor. Required for
            click queries; without it a click only reports an error in the
            widget's output area.
        features (torch.Tensor): The features of the raster image.
        host_map (Map): The host map. When None, the controls are built but
            no map interaction is registered.
        position (str): The position of the widget on the host map.
        colormap_options (list): The colormap options. Defaults to jet,
            viridis, plasma, inferno, magma and cividis.
        raster_args (dict): Extra keyword arguments forwarded to
            ``host_map.add_raster`` for the input raster. A ``layer_name``
            of "Raster" is used when none is given. The dict passed in is
            not modified.

    Returns:
        None

    Example:
        >>> processor = DINOv3GeoProcessor()
        >>> features, h_patches, w_patches = processor.extract_features(raster)
        >>> gui = DINOv3GUI(raster, processor, features, host_map=m)
    """
    # Initialize the underlying VBox; without this the instance is not a
    # fully constructed widget.
    super().__init__()

    # Copy so the default layer name is never written into the caller's
    # dictionary.
    raster_args = dict(raster_args) if raster_args is not None else {}
    raster_args.setdefault("layer_name", "Raster")

    if colormap_options is None:
        colormap_options = [
            "jet",
            "viridis",
            "plasma",
            "inferno",
            "magma",
            "cividis",
        ]

    main_widget = widgets.VBox(layout=widgets.Layout(width="230px"))
    style = {"description_width": "initial"}
    layout = widgets.Layout(width="95%", padding="0px 5px 0px 5px")

    interpolation_checkbox = widgets.Checkbox(
        value=True,
        description="Use interpolation",
        style=style,
        layout=layout,
    )

    # Similarity values strictly above this threshold form the "Foreground"
    # layer.
    threshold_slider = widgets.FloatSlider(
        value=0.7,
        min=0,
        max=1,
        step=0.01,
        description="Threshold",
        style=style,
        layout=layout,
    )

    opacity_slider = widgets.FloatSlider(
        value=0.5,
        min=0,
        max=1,
        step=0.01,
        description="Opacity",
        style=style,
        layout=layout,
    )
    colormap_dropdown = widgets.Dropdown(
        options=colormap_options,
        value="jet",
        description="Colormap",
        style=style,
        layout=layout,
    )
    layer_name_input = widgets.Text(
        value="Similarity",
        description="Layer name",
        style=style,
        layout=layout,
    )

    # NOTE(review): no click handlers are attached to the Save and Reset
    # buttons in this block, so they are currently inert.
    save_button = widgets.Button(
        description="Save",
    )

    reset_button = widgets.Button(
        description="Reset",
    )

    # Receives output printed while handling a click, including errors.
    output = widgets.Output()

    main_widget.children = [
        interpolation_checkbox,
        threshold_slider,
        opacity_slider,
        colormap_dropdown,
        layer_name_input,
        widgets.HBox([save_button, reset_button]),
        output,
    ]

    # Make the GUI object itself hold the control panel so that displaying
    # the instance shows the controls.
    self.children = [main_widget]

    if host_map is not None:

        host_map.add_widget(main_widget, add_header=True, position=position)

        if raster is not None:
            host_map.add_raster(raster, **raster_args)

        def handle_map_interaction(**kwargs):
            """Run a similarity query for a map click and show the result."""
            if kwargs.get("type") != "click":
                return
            try:
                latlon = kwargs.get("coordinates")
                with output:
                    output.clear_output()

                    if processor is None:
                        raise ValueError(
                            "A DINOv3 processor is required to compute "
                            "similarity."
                        )

                    results = processor.compute_similarity(
                        source=raster,
                        features=features,
                        # The click coordinates are passed in reversed order
                        # (presumably (lat, lon) -> (lon, lat) -- confirm
                        # against the processor's contract).
                        query_coords=latlon[::-1],
                        output_dir="dinov3_results",
                        use_interpolation=interpolation_checkbox.value,
                        coord_crs="EPSG:4326",
                    )
                    array = results["image_dict"]["image"]
                    binary_array = array > threshold_slider.value
                    image = dict_to_image(results["image_dict"])
                    # The mask reuses the CRS and bounds of the similarity
                    # image.
                    binary_image = dict_to_image(
                        {
                            "image": binary_array,
                            "crs": results["image_dict"]["crs"],
                            "bounds": results["image_dict"]["bounds"],
                        }
                    )
                    host_map.add_raster(
                        image,
                        colormap=colormap_dropdown.value,
                        opacity=opacity_slider.value,
                        layer_name=layer_name_input.value,
                        zoom_to_layer=False,
                        overwrite=True,
                    )
                    # The thresholded mask is added as a hidden layer.
                    host_map.add_raster(
                        binary_image,
                        colormap="jet",
                        nodata=0,
                        opacity=opacity_slider.value,
                        layer_name="Foreground",
                        zoom_to_layer=False,
                        overwrite=True,
                        visible=False,
                    )
            except Exception as e:
                # UI callback boundary: report the failure in the widget
                # instead of raising out of the map event handler.
                with output:
                    print(e)

        host_map.on_interaction(handle_map_interaction)
        host_map.default_style = {"cursor": "crosshair"}