Skip to content

metrics

sleap.gui.dialogs.metrics

Dialog/widgets for showing metrics on trained models.

Classes:

Name Description
DetailedMetricsDialog

Dialog to show detailed metrics for a trained model.

MetricsTableDialog

Dialog for showing table with multiple models.

MetricsTableModel

Model (i.e. Qt model/view) for table in MetricsTableDialog.

DetailedMetricsDialog

Bases: QWidget

Dialog to show detailed metrics for a trained model.

Parameters:

Name Type Description Default
cfg_info ConfigFileInfo

The ConfigFileInfo object (from TrainingConfigsGetter) for the model we want to show.

required
Source code in sleap/gui/dialogs/metrics.py
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
class DetailedMetricsDialog(QtWidgets.QWidget):
    """
    Dialog to show detailed metrics for a trained model.

    When the model has metrics, the left side lists every scalar metric
    (NumPy float64/int64 scalars or 0-d arrays) and the right side shows a
    per-node plot of the ``dist.dists`` errors. Otherwise a short message
    saying that metrics have not been generated is shown.

    Args:
        cfg_info: The `ConfigFileInfo` object (from `TrainingConfigsGetter`)
            for the model we want to show.
    """

    def __init__(self, cfg_info: ConfigFileInfo):
        super().__init__()

        self.setWindowTitle(cfg_info.path_dir)
        self.setMinimumWidth(800)

        self.cfg_info = cfg_info
        self.skeleton = cfg_info.skeleton

        self.metrics = self.cfg_info.metrics

        layout = QtWidgets.QHBoxLayout()

        if self.metrics:
            metrics_widget = self._make_metrics_widget()

            self.canvas = MplCanvas(dpi=50)

            layout.addWidget(metrics_widget)
            layout.addWidget(self.canvas)

            self._plot_distances()
        else:
            text_widget = QtWidgets.QLabel(
                "Metrics have not been generated for this model."
            )
            layout.addWidget(text_widget)

        self.setLayout(layout)

    @staticmethod
    def _is_scalar_metric(val) -> bool:
        """Returns True if `val` should be listed as a single metric value.

        Accepted values are NumPy float64/int64 scalars and 0-d NumPy arrays;
        arrays with one or more dimensions (e.g. per-node distances) are
        skipped.
        """
        if isinstance(val, (np.float64, np.int64)):
            return True
        return isinstance(val, np.ndarray) and val.ndim == 0

    def _make_metrics_widget(self):
        """Builds a form widget with one selectable label per scalar metric.

        Row labels come from `METRICS_KEY_LABELS` when the metric key has an
        entry there; otherwise the raw metric key is used.
        """
        metrics_layout = QtWidgets.QFormLayout()

        for key, val in self.metrics.items():
            if not self._is_scalar_metric(val):
                continue

            key_str = METRICS_KEY_LABELS[key] if key in METRICS_KEY_LABELS else key

            text_widget = QtWidgets.QLabel(str(val))
            text_widget.setTextInteractionFlags(QtCore.Qt.TextSelectableByMouse)
            metrics_layout.addRow(f"<b>{key_str}</b>:", text_widget)

        metrics_widget = QtWidgets.QWidget()
        metrics_widget.setLayout(metrics_layout)
        return metrics_widget

    def _plot_distances(self):
        """Plots node distances (using matplotlib widget).

        Draws a box plot (outlier markers hidden) overlaid with a jittered
        strip plot of ``self.metrics["dist.dists"]``, one row per node, on
        ``self.canvas.axes``. The x-axis upper limit is the 95th-percentile
        error rounded up to a multiple of 5 px, plus one extra 5 px step.
        When that percentile is not finite (e.g. every distance is NaN), the
        limit is left to matplotlib's autoscaling instead of raising.
        """
        if "dist.dists" not in self.metrics:
            # Nothing to plot; leave the canvas empty rather than crash the dialog.
            return

        ax = self.canvas.axes

        # Without a skeleton, pandas falls back to integer column labels.
        node_names = self.skeleton.node_names if self.skeleton else None

        dists = pd.DataFrame(self.metrics["dist.dists"], columns=node_names).melt(
            var_name="Part", value_name="Error"
        )

        sns.boxplot(data=dists, x="Error", y="Part", fliersize=0, ax=ax)

        sns.stripplot(
            data=dists, x="Error", y="Part", alpha=0.25, linewidth=1, jitter=0.2, ax=ax
        )

        ax.set_title("Node distances (ground truth vs prediction)")
        dist_1d = np.asarray(self.metrics["dist.dists"]).flatten()

        p95 = np.nanpercentile(dist_1d, 95)
        if np.isfinite(p95):
            # Matplotlib rejects NaN/inf axis limits, so only clip when defined.
            xmax = (np.ceil(p95 / 5) + 1) * 5
            ax.set_xlim([0, xmax])
        ax.set_xlabel("Error (px)")

    def _plot_oks(self):
        """Plots OKS precision/recall curves -- not currently used.

        Draws one precision-vs-recall line per OKS match-score threshold
        on ``self.canvas.axes``.
        """
        ax = self.canvas.axes
        metrics = self.metrics

        for match_threshold, precision in zip(
            metrics["oks_voc.match_score_thresholds"], metrics["oks_voc.precisions"]
        ):
            ax.plot(
                metrics["oks_voc.recall_thresholds"],
                precision,
                "-",
                label=f"OKS @ {match_threshold:.2f}",
            )
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")

MetricsTableDialog

Bases: QWidget

Dialog for showing table with multiple models.

The dialog can show multiple models, including those that don't have metrics yet (ideally you would be able to generate evaluations from here, but this isn't currently supported).

You can then view details on the models (hyperparameters or more detailed metrics).

The typical use case is to initialize the dialog with the path to a labels file; it will then show all trained models found within its subdirectories.

Source code in sleap/gui/dialogs/metrics.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
class MetricsTableDialog(QtWidgets.QWidget):
    """Dialog for showing table with multiple models.

    The dialog can show multiple models, including those which don't
    already have metrics (ideally you'd be able to generate evals but this isn't
    currently supported).

    You can then view details on the models (hyperparameters or more detailed
    metrics).

    The typical use-case is to init dialog with path to labels file, and it
    will then show all trained models found within subdirectories.
    """

    # Detail-window caches keyed by `ConfigFileInfo.path`. They are shared by
    # every instance of this dialog, so viewing the same model twice re-raises
    # the existing window instead of creating a new one.
    _model_detail_widgets = {}
    _metric_detail_widgets = {}

    def __init__(self, labels_filename: Text = ""):
        super().__init__()

        labels_filename = labels_filename or ""

        self._cfg_getter = TrainingConfigsGetter.make_from_labels_filename(
            labels_filename,
        )
        self._cfg_getter.search_depth = 4

        self.table_model = MetricsTableModel(items=[])
        self.table_view = GenericTableView(
            model=self.table_model, is_activatable=True, row_name="trained_model"
        )
        self.table_view.state.connect("trained_model", self._show_metric_details)
        self.table_view.state.connect("selected_trained_model", self._update_gui)

        button_layout = QtWidgets.QHBoxLayout()
        buttons = QtWidgets.QWidget()
        buttons.setLayout(button_layout)

        btn = QtWidgets.QPushButton("Add Trained Model(s)")
        btn.clicked.connect(self._add_model_action)
        button_layout.addWidget(btn)

        # The lambdas swallow the `checked` argument that `clicked` emits so it
        # isn't mistaken for `cfg_info`.
        btn = QtWidgets.QPushButton("View Hyperparameters")
        btn.clicked.connect(lambda: self._show_model_params())
        button_layout.addWidget(btn)
        self._view_model_btn = btn

        btn = QtWidgets.QPushButton("View Metrics")
        btn.clicked.connect(lambda: self._show_metric_details())
        button_layout.addWidget(btn)
        self._view_metrics_btn = btn

        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(self.table_view)
        layout.addWidget(buttons)
        self.setLayout(layout)

        self.setWindowTitle("Metrics for Trained Models")

        self._update_cfgs()
        self._update_gui()

        self.setMinimumWidth(1200)

    def _update_gui(self, *args):
        """Enables/disables buttons as appropriate for table row selection.

        Extra positional args (e.g. the value passed by the state callback)
        are ignored.
        """
        is_selected = self.table_view.state["selected_trained_model"] is not None
        self._view_model_btn.setEnabled(is_selected)
        self._view_metrics_btn.setEnabled(is_selected)

    def _update_cfgs(self):
        """Re-scans the search directories and shows only trained models in the table."""
        self._cfg_getter.update()
        cfgs = self._cfg_getter.get_filtered_configs(only_trained=True)
        self.table_model.items = cfgs
        self.table_view.resizeColumnsToContents()

    def _add_model_action(self):
        """Method called when user clicks 'add models' button.

        Asks the user for a directory, adds it to the search paths (once), and
        refreshes the table. Does nothing if the user cancels.
        """
        dir_path = FileDialog.openDir(None, dir=None, caption="")

        if not dir_path:
            return

        if dir_path not in self._cfg_getter.dir_paths:
            self._cfg_getter.dir_paths.append(dir_path)
        self._update_cfgs()

    def _show_model(self, cfg_info: Optional[ConfigFileInfo] = None):
        """Method to show both hyperparam and metrics windows."""
        self._show_model_params(cfg_info)
        self._show_metric_details(cfg_info)

    def _resolve_cfg_info(
        self, cfg_info: Optional[ConfigFileInfo]
    ) -> Optional[ConfigFileInfo]:
        """Returns `cfg_info`, or the selected table row when it is None.

        May still return None if no row is selected.
        """
        if cfg_info is None:
            cfg_info = self.table_view.getSelectedRowItem()
        return cfg_info

    def _show_model_params(
        self, cfg_info: Optional[ConfigFileInfo] = None, model_detail_widgets=None
    ):
        """
        Method to show dialog with hyperparameters for model.

        Args:
            cfg_info: The `ConfigFileInfo` for the model to show; if None,
                then show for model currently selected in table. If nothing
                is selected either, this does nothing.
            model_detail_widgets: Not user param; cache for widgets so that we
                don't create new window if user views same model twice.
                Defaults to a cache shared by all instances of this dialog.
        """
        if model_detail_widgets is None:
            model_detail_widgets = self._model_detail_widgets

        cfg_info = self._resolve_cfg_info(cfg_info)
        if cfg_info is None:
            return

        key = cfg_info.path
        if key not in model_detail_widgets:
            model_detail_widgets[key] = TrainingEditorWidget.from_trained_config(
                cfg_info, self._cfg_getter
            )

        widget = model_detail_widgets[key]
        widget.show()
        widget.raise_()
        widget.activateWindow()

    def _show_metric_details(
        self, cfg_info: Optional[ConfigFileInfo] = None, metric_detail_widgets=None
    ):
        """
        Method to show dialog with metrics for model.

        Args:
            cfg_info: The `ConfigFileInfo` for the model to show; if None,
                then show for model currently selected in table. If nothing
                is selected either (e.g. the state callback fired with None),
                this does nothing.
            metric_detail_widgets: Not user param; cache for widgets so that we
                don't create new window if user views same model twice.
                Defaults to a cache shared by all instances of this dialog.
        """
        if metric_detail_widgets is None:
            metric_detail_widgets = self._metric_detail_widgets

        cfg_info = self._resolve_cfg_info(cfg_info)
        if cfg_info is None:
            return

        key = cfg_info.path
        if key not in metric_detail_widgets:
            metric_detail_widgets[key] = DetailedMetricsDialog(cfg_info)

        widget = metric_detail_widgets[key]
        widget.show()
        widget.raise_()
        widget.activateWindow()

MetricsTableModel

Bases: GenericTableModel

Model (i.e. Qt model/view) for table in MetricsTableDialog.

Source code in sleap/gui/dialogs/metrics.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
class MetricsTableModel(GenericTableModel):
    """
    Model (i.e. Qt model/view) for table in MetricsTableDialog.

    Each item is a `ConfigFileInfo`; `item_to_data` converts it into a dict
    whose keys are the column names listed in `properties`.
    """

    properties = (
        "Path",
        "Timestamp",
        # "Run Name",
        "Model Type",
        "Architecture",
        "Training Instances",
        "Validation Instances",
        "OKS mAP",
        "Vis Precision",
        "Vis Recall",
        "Dist: 95%",
        "Dist: 75%",
        "Dist: Avg",
    )
    show_row_numbers = False

    # (table column, key in `ConfigFileInfo.metrics`) for the numeric columns.
    _metric_columns = (
        ("OKS mAP", "oks_voc.mAP"),
        ("Vis Precision", "vis.precision"),
        ("Vis Recall", "vis.recall"),
        ("Dist: 95%", "dist.p95"),
        ("Dist: 75%", "dist.p75"),
        ("Dist: Avg", "dist.avg"),
    )

    def item_to_data(self, obj, cfg: ConfigFileInfo):
        """Builds the row data for one trained model.

        Args:
            obj: Unused here; part of the `GenericTableModel` interface.
            cfg: The model's config info.

        Returns:
            Dict mapping column names to display strings. Metric columns are
            formatted with five decimals. A metric column is omitted when the
            model has no metrics or that particular metric is missing or None,
            so one absent key no longer breaks the whole row.
        """
        if cfg.training_frame_count:
            n_train_str = (
                f"{cfg.training_instance_count} ({cfg.training_frame_count} frames)"
            )
        else:
            n_train_str = ""

        if cfg.validation_frame_count:
            n_val_str = (
                f"{cfg.validation_instance_count} ({cfg.validation_frame_count} frames)"
            )
        else:
            n_val_str = ""

        arch_str = get_backbone_from_omegaconf(cfg.config)

        backbone = cfg.config.model_config.backbone_config[arch_str]
        if "max_stride" in backbone:
            arch_str = f"{arch_str}, max stride: {backbone.max_stride}"
        if "filters" in backbone:
            arch_str = f"{arch_str}, filters: {backbone.filters}"

        # scale = cfg.config.data.preprocessing.input_scaling
        # if scale != 1.0:
        #     arch_str = f"{arch_str}, scale: {scale}"

        item_data = {
            "Timestamp": str(cfg.timestamp),
            # "Run Name": cfg.config.outputs.run_name,
            "Path": cfg.path_dir,
            "Model Type": cfg.head_name,
            "Architecture": arch_str,
            "Training Instances": n_train_str,
            "Validation Instances": n_val_str,
        }

        metrics = cfg.metrics

        if metrics:
            for column, metric_key in self._metric_columns:
                value = metrics.get(metric_key)
                if value is not None:
                    item_data[column] = f"{value:.5f}"

        return item_data