Skip to content

configs

sleap.gui.learning.configs

Find, load, and show lists of saved TrainingJobConfig.

Classes:

Name Description
ConfigFileInfo

Object to represent a saved :py:class:TrainingJobConfig

TrainingConfigFilesWidget

Widget to show list of saved :py:class:TrainingJobConfig files.

TrainingConfigsGetter

Searches for and loads :py:class:TrainingJobConfig files.

ConfigFileInfo

Object to represent a saved :py:class:TrainingJobConfig

The :py:class:TrainingJobConfig class holds information about the model and can be saved as a file. This class holds information about that file, e.g., the path, and also provides some properties/methods that make it easier to access certain data in or about the file.

Attributes:

Name Type Description
config OmegaConf

the :py:class:TrainingJobConfig

path Optional[Text]

path to the :py:class:TrainingJobConfig

filename Optional[Text]

just the filename, not the full path

head_name Optional[Text]

string which should match name of model_config.head_configs key

dont_retrain bool

allows us to keep track of whether we should retrain this config

Source code in sleap/gui/learning/configs.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
@attr.s(auto_attribs=True, slots=True)
class ConfigFileInfo:
    """
    Object to represent a saved :py:class:`TrainingJobConfig`

    The :py:class:`TrainingJobConfig` class holds information about the model
    and can be saved as a file. This class holds information about that file,
    e.g., the path, and also provides some properties/methods that make it
    easier to access certain data in or about the file.

    Attributes:
        config: the :py:class:`TrainingJobConfig`
        path: path to the :py:class:`TrainingJobConfig`
        filename: just the filename, not the full path
        head_name: string which should match name of model_config.head_configs key
        dont_retrain: allows us to keep track of whether we should retrain
            this config
    """

    config: OmegaConf
    path: Optional[Text] = None
    filename: Optional[Text] = None
    head_name: Optional[Text] = None
    dont_retrain: bool = False
    # Lazily-populated caches used by `skeleton` and `_get_dataset_len`.
    _skeleton: Optional[Skeleton] = None
    _tried_finding_skeleton: bool = False
    _dset_len_cache: dict = attr.ib(factory=dict)

    @property
    def has_trained_model(self) -> bool:
        """
        Whether a "best" model checkpoint exists for this config.

        Looks for either ``best.ckpt`` (torch weights) or ``best_model.h5``
        (keras weights) in the checkpoint dir or the config's directory.
        """
        # TODO: inference only checks for the best model, so that's also
        #  what we'll do here, but both should check for other models
        #  depending on the training config settings.

        # allow to run inference on both torch weights (`.ckpt`) and keras weights
        # (`.h5`). sleap-nn supports running inference on the keras weights
        # (Note: currently only for unet models).
        # TODO: add support for running inference on the keras weights for other models.
        return (
            self._get_file_path("best.ckpt") is not None
            or self._get_file_path("best_model.h5") is not None
        )

    @property
    def path_dir(self) -> Optional[Text]:
        """
        Directory associated with this config.

        If ``path`` ends with ``json``/``yaml``/``yml`` it is treated as a file
        and its parent directory is returned; otherwise ``path`` is returned
        unchanged. Returns None when ``path`` is not set.
        """
        if self.path is None:
            return None
        if self.path.endswith(("yaml", "json", "yml")):
            return os.path.dirname(self.path)
        return self.path

    def _get_file_path(self, shortname) -> Optional[Text]:
        """
        Check for specified file in various directories related config.

        Searches ``trainer_config.ckpt_dir`` (default ``"."``) first, then
        :py:attr:`path_dir`. Directories that are unset (None) are skipped.

        Args:
            shortname: Filename without path.
        Returns:
            Full path + filename if found, otherwise None.
        """
        candidate_dirs = (
            OmegaConf.select(self.config, "trainer_config.ckpt_dir", default="."),
            self.path_dir,
        )
        for search_dir in candidate_dirs:
            # An explicit null ckpt_dir or a missing `path` would otherwise make
            # os.path.join raise TypeError.
            if search_dir is None:
                continue
            full_path = os.path.join(search_dir, shortname)
            if os.path.exists(full_path):
                return full_path

        return None

    @property
    def metrics(self):
        """Validation metrics dict, or None if no metrics file was found."""
        return self._get_metrics("val")

    @property
    def skeleton(self):
        """
        First skeleton associated with this config, or None.

        Taken from ``data_config.skeletons`` when present, otherwise loaded from
        ``labels_gt.val.slp`` (slower). The lookup is performed only once.
        """
        # cache skeleton so we only search once
        if self._skeleton is None and not self._tried_finding_skeleton:
            # if skeleton was saved in config, great!
            if self.config.data_config.skeletons:
                skeletons = get_skeleton_from_config(self.config.data_config.skeletons)
                self._skeleton = skeletons[0] if skeletons else None

            # otherwise try loading it from validation labels (much slower!)
            else:
                filename = self._get_file_path("labels_gt.val.slp")
                if filename is not None:
                    val_labels = load_file(filename)
                    if val_labels.skeletons:
                        self._skeleton = val_labels.skeletons[0]

            # don't try loading again (needed in case it's still None)
            self._tried_finding_skeleton = True

        return self._skeleton

    @property
    def training_instance_count(self):
        """Number of instances in the training dataset"""
        return self._get_dataset_len("instances", "train")

    @property
    def validation_instance_count(self):
        """Number of instances in the validation dataset"""
        return self._get_dataset_len("instances", "val")

    @property
    def training_frame_count(self):
        """Number of labeled frames in the training dataset"""
        return self._get_dataset_len("frames", "train")

    @property
    def validation_frame_count(self):
        """Number of labeled frames in the validation dataset"""
        return self._get_dataset_len("frames", "val")

    @property
    def timestamp(self):
        """
        Timestamp parsed from ``trainer_config.run_name`` (not OS timestamp).

        Expects a ``YYMMDD_HHMMSS`` pattern somewhere in the run name; years are
        interpreted as 20YY. Returns None if the run name is unset or has no
        such pattern.
        """
        run_name = self.config.trainer_config.run_name
        if not run_name:
            # re.match raises TypeError on None; an unset run name has no timestamp.
            return None

        match = re.match(
            r".*?(?<!\d)(\d{2})(\d{2})(\d{2})_(\d{2})(\d{2})(\d{2})\b",
            run_name,
        )
        if match:
            year, month, day = int(match[1]), int(match[2]), int(match[3])
            hour, minute, sec = int(match[4]), int(match[5]), int(match[6])
            return datetime.datetime(2000 + year, month, day, hour, minute, sec)

        return None

    def _get_dataset_len(self, dset_name: Text, split_name: Text):
        """
        Length of an HDF5 dataset in the saved labels file for a split (cached).

        Args:
            dset_name: HDF5 dataset name, e.g. ``"instances"`` or ``"frames"``.
            split_name: Split name, e.g. ``"train"`` or ``"val"``.

        Returns:
            ``shape[0]`` of the dataset, or None if no labels file was found.
        """
        cache_key = (dset_name, split_name)
        if cache_key not in self._dset_len_cache:
            n = None
            # Look up each candidate once (the original looked up the first twice).
            filename = self._get_file_path(
                f"labels_gt.{split_name}.slp"
            ) or self._get_file_path(f"labels_{split_name}_gt_0.slp")
            if filename is not None:
                with h5py.File(filename, "r") as f:
                    n = f[dset_name].shape[0]

            self._dset_len_cache[cache_key] = n

        return self._dset_len_cache[cache_key]

    def _get_metrics(self, split_name: Text):
        """
        Load evaluation metrics for a data split.

        Looks for ``{split}_0_pred_metrics.npz`` first, then
        ``metrics.{split}.npz``.

        Args:
            split_name: Split name, e.g. ``"val"``.

        Returns:
            Flat dict of selected metrics, or None if no metrics file exists.
        """
        metrics_path = self._get_file_path(
            f"{split_name}_0_pred_metrics.npz"
        ) or self._get_file_path(f"metrics.{split_name}.npz")

        if metrics_path is None:
            # No metrics file (e.g. model not evaluated); don't call np.load(None).
            return None

        # NOTE: allow_pickle=True unpickles the stored object; only load metrics
        # files from trusted locations.
        with np.load(metrics_path, allow_pickle=True) as data:
            metric_data = data["metrics"].item()

        vis = metric_data["visibility_metrics"]
        dist = metric_data["distance_metrics"]
        pck = metric_data["pck_metrics"]
        oks = metric_data["mOKS"]
        voc = metric_data["voc_metrics"]

        return {
            "vis.tp": vis.get("tp"),
            "vis.fp": vis.get("fp"),
            "vis.tn": vis.get("tn"),
            "vis.fn": vis.get("fn"),
            "vis.precision": vis.get("precision"),
            "vis.recall": vis.get("recall"),
            "dist.dists": dist.get("dists"),
            "dist.avg": dist.get("avg"),
            "dist.p50": dist.get("p50"),
            "dist.p75": dist.get("p75"),
            "dist.p90": dist.get("p90"),
            "dist.p95": dist.get("p95"),
            "dist.p99": dist.get("p99"),
            "pck.mPCK": pck.get("mPCK"),
            "oks.mOKS": oks.get("mOKS"),
            "oks_voc.mAP": voc.get("oks_voc.mAP"),
            "oks_voc.mAR": voc.get("oks_voc.mAR"),
            "pck_voc.mAP": voc.get("pck_voc.mAP"),
            "pck_voc.mAR": voc.get("pck_voc.mAR"),
        }

    @classmethod
    def from_config_file(cls, path: Text) -> "ConfigFileInfo":
        """
        Load a config file and wrap it in a :py:class:`ConfigFileInfo`.

        ``.yaml``/``.yml`` files are loaded with OmegaConf; anything else is
        loaded with sleap-nn's ``TrainingJobConfig.load_sleap_config``.
        """
        if path.endswith(("yaml", "yml")):
            cfg = OmegaConf.load(path)

        else:
            from sleap_nn.config.training_job_config import (
                TrainingJobConfig as snn_TrainingJobConfig,
            )

            cfg = snn_TrainingJobConfig.load_sleap_config(path)
        head_name = get_head_from_omegaconf(cfg)
        filename = os.path.basename(path)
        return cls(config=cfg, path=path, filename=filename, head_name=head_name)

timestamp property

Timestamp on file; parsed from filename (not OS timestamp).

training_frame_count property

Number of labeled frames in the training dataset

training_instance_count property

Number of instances in the training dataset

validation_frame_count property

Number of labeled frames in the validation dataset

validation_instance_count property

Number of instances in the validation dataset

TrainingConfigFilesWidget

Bases: FieldComboWidget

Widget to show list of saved :py:class:TrainingJobConfig files.

This is used inside :py:class:TrainingEditorWidget.

Parameters:

Name Type Description Default
cfg_getter TrainingConfigsGetter

the :py:class:TrainingConfigsGetter from which menu is populated.

required
head_name Text

used to filter configs from cfg_getter.

required
require_trained bool

used to filter configs from cfg_getter.

False
Signals:

onConfigSelection: triggered when user selects a config file.

Methods:

Name Description
doFileSelection

Shows file browser to add training profile for given model type.

getConfigInfoByMenuIdx

Return ConfigFileInfo for menu item index.

getSelectedConfigInfo

Return currently selected ConfigFileInfo (if any, None otherwise).

onSelectionIdxChange

Handler for when user selects a menu item.

setUserConfigData

Sets the user config option from settings made by user.

update

Updates menu options, optionally selecting a specific config.

Source code in sleap/gui/learning/configs.py
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
class TrainingConfigFilesWidget(FieldComboWidget):
    """
    Widget to show list of saved :py:class:`TrainingJobConfig` files.

    This is used inside :py:class:`TrainingEditorWidget`.

    Arguments:
        cfg_getter: the :py:class:`TrainingConfigsGetter` from which menu
            is populated.
        head_name: used to filter configs from `cfg_getter`.
        require_trained: used to filter configs from `cfg_getter`.

    Signals:
        onConfigSelection: triggered when user selects a config file

    """

    onConfigSelection = QtCore.Signal(ConfigFileInfo)

    SELECT_FILE_OPTION = "Select training config file..."
    SHOW_INITIAL_BLANK = 0

    def __init__(
        self,
        cfg_getter: "TrainingConfigsGetter",
        head_name: Text,
        require_trained: bool = False,
        *args,
        **kwargs,
    ):
        super(TrainingConfigFilesWidget, self).__init__(*args, **kwargs)
        self._cfg_getter = cfg_getter
        # Configs currently shown in the menu, in menu order.
        self._cfg_list = []
        self._head_name = head_name
        self._require_trained = require_trained
        self._user_config_data_dict = None

        self.currentIndexChanged.connect(self.onSelectionIdxChange)

    def update(self, select: Optional[ConfigFileInfo] = None):
        """
        Updates menu options, optionally selecting a specific config.

        Rebuilds the menu from the getter's configs (filtered by this widget's
        head name and ``require_trained``). Each config row reads
        ``"[Trained] "`` (if a trained model exists) + ``"{run_name}({filename})"``;
        the list ends with a ``"---"`` separator and the "select file" option.

        Args:
            select: Config whose row should be selected after rebuilding;
                matched by comparing ``config`` objects.
        """
        cfg_list = self._cfg_getter.get_filtered_configs(
            head_filter=self._head_name, only_trained=self._require_trained
        )
        self._cfg_list = cfg_list

        select_key = None

        option_list = []
        # Optional leading blank row; `_menu_cfg_idx_offset` accounts for it
        # when mapping menu rows back to configs.
        if self.SHOW_INITIAL_BLANK or len(cfg_list) == 0:
            option_list.append("")

        # add options for config files
        for cfg_info in cfg_list:
            display_name = ""

            if cfg_info.has_trained_model:
                display_name += "[Trained] "

            # `default` only covers a missing key; `or ""` also covers an
            # explicit null run_name, which would otherwise render as "None".
            run_name = (
                OmegaConf.select(
                    cfg_info.config, "trainer_config.run_name", default=""
                )
                or ""
            )
            display_name += f"{run_name}({cfg_info.filename})"

            if select is not None and select.config == cfg_info.config:
                select_key = display_name

            option_list.append(display_name)

        option_list.append("---")
        option_list.append(self.SELECT_FILE_OPTION)

        self.set_options(option_list, select_item=select_key)

    @property
    def _menu_cfg_idx_offset(self):
        """Number of menu rows preceding the first config (1 if a blank row leads)."""
        if (
            hasattr(self, "options_list")
            and self.options_list
            and self.options_list[0] == ""
        ):
            return 1
        return 0

    def getConfigInfoByMenuIdx(self, menu_idx: int) -> Optional[ConfigFileInfo]:
        """Return `ConfigFileInfo` for menu item index (None for non-config rows)."""
        cfg_idx = menu_idx - self._menu_cfg_idx_offset
        return self._cfg_list[cfg_idx] if 0 <= cfg_idx < len(self._cfg_list) else None

    def getSelectedConfigInfo(self) -> Optional[ConfigFileInfo]:
        """
        Return currently selected `ConfigFileInfo` (if any, None otherwise).
        """
        return self.getConfigInfoByMenuIdx(self.currentIndex())

    def onSelectionIdxChange(self, menu_idx: int):
        """
        Handler for when user selects a menu item.

        Either allows selection of config using file browser, or emits
        `onConfigSelection` signal for selected config.
        """
        if self.value() == self.SELECT_FILE_OPTION:
            cfg_info = self.doFileSelection()
            self._add_file_selection_to_menu(cfg_info)

        elif menu_idx >= self._menu_cfg_idx_offset:
            cfg_info = self.getConfigInfoByMenuIdx(menu_idx)
            if cfg_info:
                self.onConfigSelection.emit(cfg_info)

    def setUserConfigData(self, cfg_data_dict: Dict[Text, Any]):
        """
        Sets the user config option from settings made by user.

        Stores the settings and, unless row 0 is already current, invokes the
        selection handler for row 0 (the combobox index itself is not changed).
        """
        self._user_config_data_dict = cfg_data_dict

        # Select the "user config" option in the combobox menu
        if self.currentIndex() != 0:
            self.onSelectionIdxChange(menu_idx=0)

    def doFileSelection(self):
        """
        Shows file browser to add training profile for given model type.

        Returns:
            Result of ``cfg_getter.try_loading_path`` for the chosen file, or
            None if no file was chosen.
        """
        filters = ["JSON (*.json)", "YAML (*.yaml;*.yml)"]
        filename, _ = FileDialog.open(
            None,
            dir=None,
            caption="Select training configuration file...",
            filter=";;".join(filters),
        )
        logging.debug(f"Selected training config file: {filename}")
        if not filename:
            logging.debug("No file selected for training config.")
            return None
        return self._cfg_getter.try_loading_path(filename)

    def _add_file_selection_to_menu(self, cfg_info: Optional[ConfigFileInfo] = None):
        """
        Add a config chosen via the file browser to the menu and select it.

        If ``cfg_info`` is None, the menu is reset to row 0 and the user is
        told the file was not a valid config. If the config is for a different
        head than this widget's, the user is warned (it is still added).
        """
        if cfg_info:
            # We were able to load config from selected file,
            # so add to options and select it.
            self._cfg_getter.insert_first(cfg_info)
            self.update(select=cfg_info)

            if cfg_info.head_name != self._head_name:
                QtWidgets.QMessageBox(
                    text=f"The file you selected was a training config for "
                    f"{cfg_info.head_name} and cannot be used for "
                    f"{self._head_name}."
                ).exec_()
        else:
            # We couldn't load a valid config, so change menu to initial
            # item since this is "user" config.
            self.setCurrentIndex(0)

            QtWidgets.QMessageBox(
                text="The file you selected was not a valid training config."
            ).exec_()

doFileSelection()

Shows file browser to add training profile for given model type.

Source code in sleap/gui/learning/configs.py
349
350
351
352
353
354
355
356
357
358
359
360
361
362
def doFileSelection(self):
    """Open a file browser so the user can pick a training config file.

    Returns:
        Whatever ``self._cfg_getter.try_loading_path`` returns for the chosen
        file, or None when the dialog produced no filename.
    """
    file_filters = ";;".join(["JSON (*.json)", "YAML (*.yaml;*.yml)"])
    chosen_path, _ = FileDialog.open(
        None,
        dir=None,
        caption="Select training configuration file...",
        filter=file_filters,
    )
    logging.debug(f"Selected training config file: {chosen_path}")
    if chosen_path:
        return self._cfg_getter.try_loading_path(chosen_path)
    logging.debug("No file selected for training config.")
    return None

getConfigInfoByMenuIdx(menu_idx)

Return ConfigFileInfo for menu item index.

Source code in sleap/gui/learning/configs.py
314
315
316
317
def getConfigInfoByMenuIdx(self, menu_idx: int) -> Optional[ConfigFileInfo]:
    """Map a menu row index to its `ConfigFileInfo`.

    Rows before the first config (offset by ``_menu_cfg_idx_offset``) and rows
    past the last config map to None.
    """
    position = menu_idx - self._menu_cfg_idx_offset
    if position < 0 or position >= len(self._cfg_list):
        return None
    return self._cfg_list[position]

getSelectedConfigInfo()

Return currently selected ConfigFileInfo (if any, None otherwise).

Source code in sleap/gui/learning/configs.py
319
320
321
322
323
def getSelectedConfigInfo(self) -> Optional[ConfigFileInfo]:
    """
    Return the `ConfigFileInfo` for the current menu row, or None when the
    current row is not a config entry.
    """
    current_row = self.currentIndex()
    return self.getConfigInfoByMenuIdx(current_row)

onSelectionIdxChange(menu_idx)

Handler for when user selects a menu item.

Either allows selection of config using file browser, or emits onConfigSelection signal for selected config.

Source code in sleap/gui/learning/configs.py
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
def onSelectionIdxChange(self, menu_idx: int):
    """
    React to the user picking a row in the config menu.

    Picking the "select file" row opens the file browser and hands its result
    to ``_add_file_selection_to_menu``. Picking a row that maps to a loaded
    config emits `onConfigSelection` with that config.
    """
    if self.value() != self.SELECT_FILE_OPTION:
        # Rows before the first config (e.g. a leading blank) are ignored.
        if menu_idx < self._menu_cfg_idx_offset:
            return
        selected = self.getConfigInfoByMenuIdx(menu_idx)
        if selected:
            self.onConfigSelection.emit(selected)
        return

    self._add_file_selection_to_menu(self.doFileSelection())

setUserConfigData(cfg_data_dict)

Sets the user config option from settings made by user.

Source code in sleap/gui/learning/configs.py
341
342
343
344
345
346
347
def setUserConfigData(self, cfg_data_dict: Dict[Text, Any]):
    """Store the user's config settings and route to the "user config" row.

    Row 0 is treated as the user-config entry: unless it is already the current
    row, the selection handler is invoked with ``menu_idx=0`` (the combobox's
    current index itself is not changed here).
    """
    self._user_config_data_dict = cfg_data_dict

    on_user_row = self.currentIndex() == 0
    if not on_user_row:
        self.onSelectionIdxChange(menu_idx=0)

update(select=None)

Updates menu options, optionally selecting a specific config.

Source code in sleap/gui/learning/configs.py
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
def update(self, select: Optional[ConfigFileInfo] = None):
    """
    Updates menu options, optionally selecting a specific config.

    Rebuilds the menu from the getter's configs (filtered by this widget's head
    name and ``require_trained``). Each config row reads ``"[Trained] "`` (if a
    trained model exists) + ``"{run_name}({filename})"``; the list ends with a
    ``"---"`` separator and the "select file" option.

    Args:
        select: Config whose row should be selected after rebuilding; matched
            by comparing ``config`` objects.
    """
    cfg_list = self._cfg_getter.get_filtered_configs(
        head_filter=self._head_name, only_trained=self._require_trained
    )
    self._cfg_list = cfg_list

    select_key = None

    option_list = []
    # Optional leading blank row; `_menu_cfg_idx_offset` accounts for it when
    # mapping menu rows back to configs.
    if self.SHOW_INITIAL_BLANK or len(cfg_list) == 0:
        option_list.append("")

    # add options for config files
    for cfg_info in cfg_list:
        display_name = ""

        if cfg_info.has_trained_model:
            display_name += "[Trained] "

        # `default` only covers a missing key; `or ""` also covers an explicit
        # null run_name, which would otherwise render as "None".
        run_name = (
            OmegaConf.select(cfg_info.config, "trainer_config.run_name", default="")
            or ""
        )
        display_name += f"{run_name}({cfg_info.filename})"

        if select is not None and select.config == cfg_info.config:
            select_key = display_name

        option_list.append(display_name)

    option_list.append("---")
    option_list.append(self.SELECT_FILE_OPTION)

    self.set_options(option_list, select_item=select_key)

TrainingConfigsGetter

Searches for and loads :py:class:TrainingJobConfig files.

Attributes:

Name Type Description
dir_paths List[Text]

List of paths in which to search for :py:class:TrainingJobConfig files.

head_filter Optional[Text]

Name of head type to use when filtering, e.g., "centered_instance".

search_depth int

How many subdirectories deep to search for config files.

Methods:

Name Description
find_configs

Load configs from all saved paths.

get_filtered_configs

Returns filtered subset of loaded configs.

get_first

Get first loaded config.

insert_first

Insert config at beginning of list.

make_from_labels_filename

Creates a getter that searches the dataset's default `models` subdirectory (next to the labels file) and the built-in training profiles for config files.

try_loading_path

Attempts to load config file and wrap in ConfigFileInfo object.

update

Re-searches paths and loads any previously unloaded config files.

Source code in sleap/gui/learning/configs.py
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
@attr.s(auto_attribs=True)
class TrainingConfigsGetter:
    """
    Searches for and loads :py:class:`TrainingJobConfig` files.

    Attributes:
        dir_paths: List of paths in which to search for
            :py:class:`TrainingJobConfig` files.
        head_filter: Name of head type to use when filtering,
            e.g., "centered_instance".
        search_depth: How many subdirectories deep to search for config files.
    """

    dir_paths: List[Text]
    head_filter: Optional[Text] = None
    search_depth: int = 1
    _configs: List[ConfigFileInfo] = attr.ib(default=attr.Factory(list))

    def __attrs_post_init__(self):
        self._configs = self.find_configs()

    def update(self):
        """Re-searches paths and loads any previously unloaded config files."""
        if len(self._configs) == 0:
            self._configs = self.find_configs()
        else:
            current_cfg_paths = {cfg.path for cfg in self._configs}
            new_cfgs = [
                cfg for cfg in self.find_configs() if cfg.path not in current_cfg_paths
            ]
            # Newly found configs go first.
            self._configs = new_cfgs + self._configs

    def find_configs(self) -> List[ConfigFileInfo]:
        """
        Load configs from all saved paths.

        Searches each existing directory in ``dir_paths`` (to ``search_depth``)
        for ``.json``, ``.yaml`` and ``.yml`` files. Files in the built-in
        ``sleap/training_profiles`` directory use a fixed display order (any
        unlisted file follows, alphabetically); files elsewhere are ordered
        most-recently-modified first. Files that fail to load are skipped.
        """
        builtin_order = [
            "baseline.centroid.yaml",
            "baseline_medium_rf.bottomup.yaml",
            "baseline_medium_rf.single.yaml",
            "baseline_medium_rf.topdown.yaml",
            "baseline_large_rf.bottomup.yaml",
            "baseline_large_rf.single.yaml",
            "baseline_large_rf.topdown.yaml",
        ]

        def _builtin_rank(entry):
            # list.index() would raise ValueError for an unlisted profile file.
            if entry.name in builtin_order:
                return (builtin_order.index(entry.name), entry.name)
            return (len(builtin_order), entry.name)

        configs = []

        for config_dir in filter(os.path.exists, self.dir_paths):
            # Find all config files in dir and subdirs to specified depth
            config_files = []
            for suffix in (".json", ".yaml", ".yml"):
                config_files.extend(
                    sleap_utils.find_files_by_suffix(
                        config_dir, suffix, depth=self.search_depth
                    )
                )

            if Path(config_dir).as_posix().endswith("sleap/training_profiles"):
                config_files.sort(key=_builtin_rank)
            else:
                # Sort files, starting with most recently modified
                config_files.sort(key=lambda f: f.stat().st_mtime, reverse=True)

            # Load the configs from files
            for entry in config_files:
                cfg_info = self.try_loading_path(entry.path)
                if cfg_info:
                    configs.append(cfg_info)

        return configs

    def get_filtered_configs(
        self, head_filter: Text = "", only_trained: bool = False
    ) -> List[ConfigFileInfo]:
        """
        Returns filtered subset of loaded configs.

        Args:
            head_filter: Keep only configs with this head name; empty means any.
            only_trained: Keep only configs that have a trained model.

        Returns:
            At most one config per directory (the first encountered), except
            that every config in the built-in profiles directory is kept.
        """
        base_config_dir = os.path.realpath(
            sleap_utils.get_package_file("training_profiles")
        )

        cfgs_to_return = []
        dirs_included = set()

        for cfg_info in self._configs:
            if head_filter and cfg_info.head_name != head_filter:
                continue
            if only_trained and not cfg_info.has_trained_model:
                continue

            # We just want a single config from each model directory.
            # Taking the first config we see in the directory means
            # we'll get the *trained* config if there is one, since
            # it will be newer and we've sorted by desc date modified.

            # TODO: check filenames since timestamp sort could be off
            #  if files were copied
            cfg_dir = os.path.realpath(os.path.dirname(cfg_info.path))

            if cfg_dir == base_config_dir or cfg_dir not in dirs_included:
                dirs_included.add(cfg_dir)
                cfgs_to_return.append(cfg_info)

        return cfgs_to_return

    def get_first(self) -> Optional[ConfigFileInfo]:
        """Get first loaded config."""
        return self._configs[0] if self._configs else None

    def insert_first(self, cfg_info: ConfigFileInfo):
        """Insert config at beginning of list."""
        self._configs.insert(0, cfg_info)

    def try_loading_path(self, path: Text) -> Optional[ConfigFileInfo]:
        """
        Attempts to load config file and wrap in `ConfigFileInfo` object.

        ``.yaml``/``.yml`` files are loaded with OmegaConf; anything else is
        loaded with sleap-nn's ``TrainingJobConfig.load_sleap_config``.

        Returns:
            A `ConfigFileInfo` if the file loads, its head can be determined,
            and it matches ``head_filter`` (or no filter is set); else None.
        """
        filename = os.path.basename(path)

        try:
            if path.endswith(("yaml", "yml")):
                cfg = OmegaConf.load(path)
                logging.debug(f"Loaded YAML config file: {filename}")
            else:
                from sleap_nn.config.training_job_config import (
                    TrainingJobConfig as snn_TrainingJobConfig,
                )

                cfg = snn_TrainingJobConfig.load_sleap_config(path)

            # Get the head from the model (i.e., what the model will predict)
            key = get_head_from_omegaconf(cfg)
        except Exception as e:
            # find_configs scans every json/yaml file in the search dirs and the
            # file browser accepts any file, so an unreadable or non-config file
            # is skipped rather than aborting the search or crashing the GUI.
            logging.warning(f"Could not load training config {path}: {e}")
            return None

        # If filter isn't set or matches head name, add config to list
        if self.head_filter not in (None, key):
            return None

        logging.debug(f"Config file matches head filter: {self.head_filter}")
        return ConfigFileInfo(path=path, filename=filename, config=cfg, head_name=key)

    @classmethod
    def make_from_labels_filename(
        cls, labels_filename: Text, head_filter: Optional[Text] = None
    ) -> "TrainingConfigsGetter":
        """
        Makes object which checks for models in default subdir for dataset.

        Searches ``<labels dir>/models`` (when a labels filename is given) and
        the built-in ``training_profiles`` directory.
        """
        dir_paths = []
        if labels_filename:
            labels_model_dir = os.path.join(os.path.dirname(labels_filename), "models")
            dir_paths.append(labels_model_dir)

        base_config_dir = sleap_utils.get_package_file("training_profiles")
        dir_paths.append(base_config_dir)

        return cls(dir_paths=dir_paths, head_filter=head_filter)

find_configs()

Load configs from all saved paths.

Source code in sleap/gui/learning/configs.py
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
def find_configs(self) -> List[ConfigFileInfo]:
    """Load configs from all saved paths.

    Every existing directory in ``self.dir_paths`` is searched (down to
    ``self.search_depth``) for ``.json``, ``.yaml`` and ``.yml`` files, and
    each file is passed to ``self.try_loading_path``. Files for which that
    returns a falsy value are skipped.

    Ordering within a directory:
      * the packaged ``sleap/training_profiles`` directory uses a fixed order
        of known built-in profile names; any other file found there is placed
        after the known profiles, alphabetically by name;
      * every other directory is ordered from most recently modified to least.

    Returns:
        The loaded configs, grouped by directory in ``self.dir_paths`` order.
    """
    # Fixed display order for the built-in training profiles.
    BUILTIN_ORDER = [
        "baseline.centroid.yaml",
        "baseline_medium_rf.bottomup.yaml",
        "baseline_medium_rf.single.yaml",
        "baseline_medium_rf.topdown.yaml",
        "baseline_large_rf.bottomup.yaml",
        "baseline_large_rf.single.yaml",
        "baseline_large_rf.topdown.yaml",
    ]
    builtin_rank = {name: rank for rank, name in enumerate(BUILTIN_ORDER)}
    unknown_rank = len(BUILTIN_ORDER)

    configs = []

    for config_dir in filter(os.path.exists, self.dir_paths):
        # Find all config files in dir and subdirs to specified depth
        config_files = []
        for suffix in (".json", ".yaml", ".yml"):
            config_files.extend(
                sleap_utils.find_files_by_suffix(
                    config_dir, suffix, depth=self.search_depth
                )
            )

        if Path(config_dir).as_posix().endswith("sleap/training_profiles"):
            # Use hardcoded sort. Files not in BUILTIN_ORDER (e.g. user-added
            # profiles) go after the known ones instead of making
            # `list.index` raise ValueError and abort the whole search.
            config_files.sort(
                key=lambda f: (builtin_rank.get(f.name, unknown_rank), f.name)
            )
        else:
            # Sort files, starting with most recently modified
            config_files.sort(key=lambda f: f.stat().st_mtime, reverse=True)

        # Load the configs from files
        for config_file in config_files:
            cfg_info = self.try_loading_path(config_file.path)
            if cfg_info:
                configs.append(cfg_info)

    return configs

get_filtered_configs(head_filter='', only_trained=False)

Returns filtered subset of loaded configs.

Source code in sleap/gui/learning/configs.py (lines 466–498)
def get_filtered_configs(
    self, head_filter: Text = "", only_trained: bool = False
) -> List[ConfigFileInfo]:
    """Returns filtered subset of loaded configs.

    Args:
        head_filter: Keep only configs whose ``head_name`` equals this value.
            An empty string (the default) keeps configs for every head.
        only_trained: If True, keep only configs whose ``has_trained_model``
            is truthy.

    Returns:
        Matching configs in their loaded order, with at most one config per
        directory. Configs in the packaged ``training_profiles`` directory are
        exempt from the one-per-directory rule. Configs whose ``path`` is
        ``None`` have no directory to deduplicate on and are always kept.
    """
    base_config_dir = os.path.realpath(
        sleap_utils.get_package_file("training_profiles")
    )

    cfgs_to_return = []
    # A set gives O(1) membership checks; output order is kept by the list.
    paths_included = set()

    for cfg_info in self._configs:
        if head_filter and cfg_info.head_name != head_filter:
            continue
        if only_trained and not cfg_info.has_trained_model:
            continue

        # At this point we know that config is appropriate
        # for this head type and is trained if that is required.

        if cfg_info.path is None:
            # os.path.dirname(None) would raise TypeError; keep the config.
            cfgs_to_return.append(cfg_info)
            continue

        # We just want a single config from each model directory.
        # Taking the first config we see in the directory means
        # we'll get the *trained* config if there is one, since
        # it will be newer and we've sorted by desc date modified.

        # TODO: check filenames since timestamp sort could be off
        #  if files were copied

        cfg_dir = os.path.realpath(os.path.dirname(cfg_info.path))

        if cfg_dir == base_config_dir or cfg_dir not in paths_included:
            paths_included.add(cfg_dir)
            cfgs_to_return.append(cfg_info)

    return cfgs_to_return

get_first()

Get first loaded config.

Source code in sleap/gui/learning/configs.py (lines 500–502)
def get_first(self) -> Optional[ConfigFileInfo]:
    """Return the first loaded config, or ``None`` when nothing is loaded."""
    if not self._configs:
        return None
    return self._configs[0]

insert_first(cfg_info)

Insert config at beginning of list.

Source code in sleap/gui/learning/configs.py (lines 504–506)
def insert_first(self, cfg_info: ConfigFileInfo):
    """Prepend a config so it becomes the first entry of the loaded list.

    Args:
        cfg_info: Config placed at index 0; existing entries shift back by one.
    """
    loaded = self._configs
    loaded.insert(0, cfg_info)

make_from_labels_filename(labels_filename, head_filter=None) classmethod

Makes object which checks for models in default subdir for dataset.

Source code in sleap/gui/learning/configs.py (lines 558–573)
@classmethod
def make_from_labels_filename(
    cls, labels_filename: Text, head_filter: Optional[Text] = None
) -> "TrainingConfigsGetter":
    """Build a getter that searches the dataset's model dir and built-in profiles.

    Args:
        labels_filename: Path to the labels file. When truthy, the ``models``
            directory next to it is searched first.
        head_filter: Optional head name forwarded to the new getter.

    Returns:
        A new instance whose ``dir_paths`` are the labels' ``models``
        subdirectory (when ``labels_filename`` is given) followed by the
        packaged ``training_profiles`` directory.
    """
    search_dirs = (
        [os.path.join(os.path.dirname(labels_filename), "models")]
        if labels_filename
        else []
    )
    search_dirs.append(sleap_utils.get_package_file("training_profiles"))
    return cls(dir_paths=search_dirs, head_filter=head_filter)

try_loading_path(path)

Attempts to load config file and wrap in ConfigFileInfo object.

Source code in sleap/gui/learning/configs.py (lines 508–556)
def try_loading_path(self, path: Text) -> Optional[ConfigFileInfo]:
    """Attempts to load config file and wrap in `ConfigFileInfo` object.

    Files ending in ``.yaml``/``.yml`` are loaded with ``OmegaConf.load``. Any
    other file is converted with sleap-nn's
    ``TrainingJobConfig.load_sleap_config``. The head name is then read from
    the loaded config with ``get_head_from_omegaconf``.

    Args:
        path: Path to the config file.

    Returns:
        A ``ConfigFileInfo`` when the file loads and its head name matches
        ``self.head_filter`` (or ``self.head_filter`` is ``None``). Otherwise
        ``None``. Load errors are logged rather than raised, so one unreadable
        file does not abort a directory scan.
    """
    filename = os.path.basename(path)
    is_yaml = path.endswith((".yaml", ".yml"))

    try:
        if is_yaml:
            from omegaconf import OmegaConf

            cfg = OmegaConf.load(path)
        else:
            from sleap_nn.config.training_job_config import (
                TrainingJobConfig as snn_TrainingJobConfig,
            )

            cfg = snn_TrainingJobConfig.load_sleap_config(path)

        # Get the head from the model (i.e., what the model will predict)
        key = get_head_from_omegaconf(cfg)
    except Exception as e:
        # Couldn't load (or not a usable training config) so just skip it.
        logging.warning(f"Could not load config file {path}: {e}")
        return None

    if is_yaml:
        logging.debug(f"Loaded YAML config file: {filename}")

    # If filter is set and doesn't match head name, skip this config
    if self.head_filter not in (None, key):
        return None

    if is_yaml:
        logging.debug(f"Config file matches head filter: {self.head_filter}")

    try:
        return ConfigFileInfo(
            path=path, filename=filename, config=cfg, head_name=key
        )
    except Exception as e:
        # Couldn't map so just ignore
        logging.error(f"Error mapping config {path}: {e}")
        return None

update()

Re-searches paths and loads any previously unloaded config files.

Source code in sleap/gui/learning/configs.py (lines 408–417)
def update(self):
    """Search the configured paths again and merge in newly found configs.

    If nothing is loaded yet, the full result of ``find_configs`` becomes the
    config list. Otherwise, only configs whose ``path`` is not already present
    are kept, and they are placed ahead of the existing entries.
    """
    if not self._configs:
        self._configs = self.find_configs()
        return

    known_paths = set(existing.path for existing in self._configs)
    fresh = [
        found for found in self.find_configs() if found.path not in known_paths
    ]
    self._configs = fresh + self._configs