Skip to content

config_utils

sleap.gui.config_utils

Functions:

Name Description
apply_cfg_transforms_to_key_val_dict

Transforms data from the training-editor form into the correct data types before it is converted to a config object.

filter_cfg

Filter out keys that start with underscore to get training config.

find_backbone_name_from_key_val_dict

Find the backbone model name from the config dictionary.

get_backbone_from_omegaconf

Get the backbone model name from the config.

get_head_from_omegaconf

Get the head model name from the config.

get_keyval_dict_from_omegaconf

Get a flat dictionary from an OmegaConf object.

get_omegaconf_from_gui_form

Get an OmegaConf object from a flat dictionary.

get_skeleton_from_config

Create sleap-io Skeleton objects from a config.

resolve_strides_from_key_val_dict

Find the valid max and output strides from the config dictionary.

apply_cfg_transforms_to_key_val_dict(key_val_dict)

Transforms data from the training-editor form into the correct data types before it is converted to a config object.

Parameters:

Name Type Description Default
key_val_dict

Flat dictionary from `TrainingEditorWidget`.

required
Source code in sleap/gui/config_utils.py (lines 104-144)
def apply_cfg_transforms_to_key_val_dict(key_val_dict):
    """
    Transforms data from form to correct data types before converting to object.

    The dictionary is updated in place, in this order:

    * If ``"_ensure_channels"`` is present, its lower-cased value sets
      ``data_config.preprocessing.ensure_rgb`` (when it equals ``"rgb"``) and
      ``data_config.preprocessing.ensure_grayscale`` (when it equals
      ``"grayscale"``). Any other value sets both flags to ``False``.
    * If any ``model_config.backbone_config.*`` key exists, that backbone's
      ``output_stride`` and ``max_stride`` entries are overwritten with the
      values returned by ``resolve_strides_from_key_val_dict``.
    * The validation loader's ``batch_size`` and ``num_workers`` are copied
      from the training loader's entries.

    Args:
        key_val_dict: Flat dictionary from :py:class:`TrainingEditorWidget`.

    Returns:
        None, modifies dict in place.

    Raises:
        KeyError: If ``trainer_config.train_data_loader.batch_size`` or
            ``trainer_config.train_data_loader.num_workers`` is missing.
    """
    if "_ensure_channels" in key_val_dict:
        channel_mode = key_val_dict["_ensure_channels"].lower()
        # At most one of the two flags can be True; unknown modes disable both.
        key_val_dict["data_config.preprocessing.ensure_rgb"] = channel_mode == "rgb"
        key_val_dict["data_config.preprocessing.ensure_grayscale"] = (
            channel_mode == "grayscale"
        )

    backbone_name = find_backbone_name_from_key_val_dict(key_val_dict)
    if backbone_name is not None:
        # Overwrite backbone strides so they agree with the head strides.
        resolved_max, resolved_output = resolve_strides_from_key_val_dict(
            key_val_dict, backbone_name
        )
        key_val_dict[f"model_config.backbone_config.{backbone_name}.output_stride"] = (
            resolved_output
        )
        key_val_dict[f"model_config.backbone_config.{backbone_name}.max_stride"] = (
            resolved_max
        )

    # The validation loader mirrors the training loader's batching settings.
    for val_key, train_key in (
        (
            "trainer_config.val_data_loader.batch_size",
            "trainer_config.train_data_loader.batch_size",
        ),
        (
            "trainer_config.val_data_loader.num_workers",
            "trainer_config.train_data_loader.num_workers",
        ),
    ):
        key_val_dict[val_key] = key_val_dict[train_key]

filter_cfg(cfg)

Filter out keys that start with underscore to get training config.

Source code in sleap/gui/config_utils.py (lines 6-13)
def filter_cfg(cfg):
    """Filter out keys that start with underscore to get training config.

    Recursively walks ``cfg`` and deletes every leaf entry whose key is a
    string starting with ``"_"``. Nested mappings (``DictConfig`` or plain
    ``dict``) are never deleted; they are filtered recursively instead, even
    when their own key starts with ``"_"``.

    Args:
        cfg: A ``DictConfig`` or plain ``dict`` to filter. It is modified in
            place.

    Returns:
        The same ``cfg`` object, after filtering.
    """
    # Snapshot the items before deleting. Mutating a plain dict while
    # iterating a live view raises RuntimeError. Relying on DictConfig.items()
    # returning a copy is an implementation detail.
    for key, value in list(cfg.items()):
        if isinstance(value, (DictConfig, dict)):
            filter_cfg(value)
        elif isinstance(key, str) and key.startswith("_"):
            # Non-string keys (e.g. ints allowed by OmegaConf) cannot be
            # "underscore" keys; skip them instead of crashing on .startswith.
            del cfg[key]
    return cfg

find_backbone_name_from_key_val_dict(key_val_dict)

Find the backbone model name from the config dictionary.

Source code in sleap/gui/config_utils.py (lines 56-63)
def find_backbone_name_from_key_val_dict(key_val_dict: dict):
    """Find the backbone model name from the config dictionary.

    Scans the flat keys for ones starting with
    ``"model_config.backbone_config."`` and takes the third dot-separated
    component as the backbone name.

    Args:
        key_val_dict: Flat dictionary with dot-separated string keys.

    Returns:
        The backbone name taken from the *last* matching key in iteration
        order, or ``None`` if no key matches.
    """
    candidates = [
        flat_key.split(".")[2]
        for flat_key in key_val_dict
        if flat_key.startswith("model_config.backbone_config.")
    ]
    return candidates[-1] if candidates else None

get_backbone_from_omegaconf(cfg)

Get the backbone model name from the config.

Source code in sleap/gui/config_utils.py (lines 40-45)
def get_backbone_from_omegaconf(cfg: OmegaConf):
    """Get the backbone model name from the config.

    Args:
        cfg: Config exposing ``cfg.model_config.backbone_config`` as a
            mapping of backbone name to its sub-config.

    Returns:
        The first backbone name (in iteration order) whose sub-config is not
        ``None``. An empty sub-config still counts. Returns ``None`` if every
        entry is ``None``.
    """
    backbones = cfg.model_config.backbone_config
    return next(
        (name for name, section in backbones.items() if section is not None),
        None,
    )

get_head_from_omegaconf(cfg)

Get the head model name from the config.

Source code in sleap/gui/config_utils.py (lines 48-53)
def get_head_from_omegaconf(cfg: OmegaConf):
    """Get the head model name from the config.

    Args:
        cfg: Config exposing ``cfg.model_config.head_configs`` as a mapping
            of head type name to its sub-config.

    Returns:
        The first head type name (in iteration order) whose sub-config is not
        ``None``. An empty sub-config still counts. Returns ``None`` if every
        entry is ``None``.
    """
    heads = cfg.model_config.head_configs
    return next(
        (name for name, section in heads.items() if section is not None),
        None,
    )

get_keyval_dict_from_omegaconf(cfg, parent_key='', sep='.')

Get a flat dictionary from an OmegaConf object.

Source code in sleap/gui/config_utils.py (lines 16-25)
def get_keyval_dict_from_omegaconf(cfg, parent_key="", sep="."):
    """Get a flat dictionary from an OmegaConf object.

    Nested ``DictConfig`` sections are flattened recursively. Each key is
    joined to its parent key with ``sep``. Any value that is not a
    ``DictConfig`` is stored as-is under its joined key.

    Args:
        cfg: The config (or sub-config) to flatten; must provide ``items()``.
        parent_key: Prefix for the keys at this level; empty at the top level.
        sep: Separator placed between key components.

    Returns:
        A new flat ``dict`` mapping joined keys to leaf values.
    """
    flat = {}
    for key, value in cfg.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else key
        if not isinstance(value, DictConfig):
            flat[full_key] = value
            continue
        # Descend into the sub-section, carrying the joined key as its prefix.
        flat.update(get_keyval_dict_from_omegaconf(value, full_key, sep=sep))
    return flat

get_omegaconf_from_gui_form(flat_dict)

Get an OmegaConf object from a flat dictionary.

Source code in sleap/gui/config_utils.py (lines 28-37)
def get_omegaconf_from_gui_form(flat_dict):
    """Get an OmegaConf object from a flat dictionary.

    Each key is split on ``"."``, and its value is stored in nested
    dictionaries following those parts. For example, ``{"a.b": 1}`` becomes
    ``{"a": {"b": 1}}``. The nested dict is then passed to
    ``OmegaConf.create``.

    NOTE(review): if a key is both a leaf and a prefix of another key, the
    result depends on order. A later leaf overwrites an existing sub-dict. A
    leaf seen first makes ``setdefault`` fail on the non-dict value.

    Args:
        flat_dict: Mapping of dot-separated string keys to values.

    Returns:
        The result of ``OmegaConf.create`` on the nested dictionary.
    """
    nested = {}
    for dotted_key, value in flat_dict.items():
        *parents, leaf = dotted_key.split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return OmegaConf.create(nested)

get_skeleton_from_config(skeleton_config)

Create sleap-io Skeleton objects from a config.

Parameters:

Name Type Description Default
skeleton_config OmegaConf

OmegaConf object containing the skeleton config.

required

Returns:

Type Description

A list of sio.Skeleton objects created from the skeleton config stored in training_config.yaml.

Source code in sleap/gui/config_utils.py (lines 147-172)
def get_skeleton_from_config(skeleton_config: OmegaConf):
    """Create sleap-io Skeleton objects from config.

    Each entry of ``skeleton_config`` is expected to provide ``name``,
    ``nodes`` (items with a ``"name"``), ``edges`` (items with
    ``["source"]["name"]`` and ``["destination"]["name"]``) and
    ``symmetries`` (pairs of node items with a ``"name"``; may be empty or
    ``None``).

    Args:
        skeleton_config: OmegaConf object containing the skeleton config.

    Returns:
        A list of `sio.Skeleton` objects, one per entry, in the same order,
        created from the skeleton config stored in `training_config.yaml`.
    """

    def _build_skeleton(entry):
        # Nodes first, so the edge and symmetry calls can refer to them by name.
        node_names = [node["name"] for node in entry.nodes]
        skeleton = sio.Skeleton(nodes=node_names, name=entry.name)
        edge_pairs = [
            (edge["source"]["name"], edge["destination"]["name"])
            for edge in entry.edges
        ]
        skeleton.add_edges(edge_pairs)
        for first, second in entry.symmetries or []:
            skeleton.add_symmetry(first["name"], second["name"])
        return skeleton

    return [_build_skeleton(entry) for entry in skeleton_config]

resolve_strides_from_key_val_dict(key_val_dict, backbone_name)

Find the valid max and output strides from the config dictionary.

Source code in sleap/gui/config_utils.py (lines 66-101)
def resolve_strides_from_key_val_dict(
    key_val_dict: dict, backbone_name: str
) -> Tuple[int, int]:
    """Find the valid max and output strides from the config dictionary.

    Processing starts from the backbone's own ``max_stride`` and
    ``output_stride`` entries; either may be absent. Every head
    ``output_stride`` present in ``key_val_dict`` is then folded in. The max
    stride is raised to the largest head stride, and the output stride is
    lowered to the smallest one. Head strides, and any value compared
    against them, are converted with ``int``.

    Args:
        key_val_dict: Flat dictionary with dot-separated keys (presumably the
            GUI form values).
        backbone_name: Backbone key under ``model_config.backbone_config``.

    Returns:
        ``(max_stride, output_stride)``. A missing output stride falls back to
        ``max_stride``. If no head stride is present, the backbone values are
        returned unconverted. Both may be ``None`` if nothing is found.
    """
    head_stride_keys = (
        "model_config.head_configs.single_instance.confmaps.output_stride",
        "model_config.head_configs.centered_instance.confmaps.output_stride",
        "model_config.head_configs.centroid.confmaps.output_stride",
        "model_config.head_configs.bottomup.confmaps.output_stride",
        "model_config.head_configs.bottomup.pafs.output_stride",
        "model_config.head_configs.multi_class_topdown.confmaps.output_stride",
        "model_config.head_configs.multi_class_bottomup.confmaps.output_stride",
        "model_config.head_configs.multi_class_bottomup.class_maps.output_stride",
    )

    max_stride = key_val_dict.get(
        f"model_config.backbone_config.{backbone_name}.max_stride"
    )
    output_stride = key_val_dict.get(
        f"model_config.backbone_config.{backbone_name}.output_stride"
    )

    for head_key in head_stride_keys:
        raw_stride = key_val_dict.get(head_key)
        if raw_stride is None:
            continue
        head_stride = int(raw_stride)

        if max_stride is None:
            max_stride = head_stride
        else:
            max_stride = max(int(max_stride), head_stride)

        if output_stride is None:
            output_stride = head_stride
        else:
            output_stride = min(int(output_stride), head_stride)

    return max_stride, (output_stride if output_stride is not None else max_stride)