Skip to content

util

sleap.util

A miscellaneous set of utility functions.

Note: to avoid circular imports, this file is used for utility functions that do not depend on any other modules in the package.

Try not to put things in here unless they really have no other place.

Classes:

Name Description
RateColumn

Renders the progress rate.

Functions:

Name Description
attr_to_dtype

Converts classes with basic types to numpy composite dtypes.

dict_cut

Helper function for creating subdictionary by numeric indexing of items.

find_files_by_suffix

Returns list of files matching suffix, optionally searching in subdirs.

frame_list

Converts 'n-m' string to list of ints.

get_config_file

Returns the full path to the specified config file.

get_package_file

Returns full path to specified file within sleap package.

imgfig

Create a tight figure for image plotting.

json_dumps

A simple wrapper around the JSON encoder we are using.

json_loads

A simple wrapper around the JSON decoder we are using.

make_scoped_dictionary

Converts dictionary with scoped keys to dictionary of dictionaries.

parse_uri_path

Parse a URI starting with 'file:///' to a posix path.

plot_img

Plot an image in a tight figure.

plot_instance

Plot a single instance with edge coloring.

plot_instances

Plot a list of instances with identity coloring.

resize_image

Resizes single image with shape (height, width, channels).

save_dict_to_hdf5

Saves dictionary to an HDF5 file.

uniquify

Returns unique elements from list, preserving order.

usable_cpu_count

Gets number of CPUs usable by the current process.

weak_filename_match

Check if paths probably point to same file.

RateColumn

Bases: ProgressColumn

Renders the progress rate.

Methods:

Name Description
render

Show progress rate.

Source code in sleap/util.py
45
46
47
48
49
50
51
52
53
class RateColumn(rich.progress.ProgressColumn):
    """Progress column that shows how fast a task is advancing, in FPS."""

    def render(self, task: Task) -> rich.progress.Text:
        """Return the task's speed as "<rate> FPS", or "?" while it is unknown."""
        rate = task.speed
        label = "?" if rate is None else f"{rate:.1f} FPS"
        return rich.progress.Text(label, style="progress.data.speed")

render(task)

Show progress rate.

Source code in sleap/util.py
48
49
50
51
52
53
def render(self, task: Task) -> rich.progress.Text:
    """Return the task's speed as "<rate> FPS", or "?" while it is unknown."""
    rate = task.speed
    label = "?" if rate is None else f"{rate:.1f} FPS"
    return rich.progress.Text(label, style="progress.data.speed")

attr_to_dtype(cls)

Converts classes with basic types to numpy composite dtypes.

Parameters:

Name Type Description Default
cls Any

class to convert

required

Returns:

Type Description

numpy dtype.

Source code in sleap/util.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def attr_to_dtype(cls: Any):
    """Converts classes with basic types to numpy composite dtypes.

    Each field of the attrs class becomes one named field of the resulting
    structured dtype. `str` fields are mapped to an h5py variable-length string
    dtype; `int`, `float` and `bool` fields are used as the field type directly.

    Arguments:
        cls: attrs class to convert.

    Returns:
        numpy dtype.

    Raises:
        TypeError: If a field has no type annotation, or its type is not one of
            `str`, `int`, `float` or `bool`.
    """
    dtype_list = []
    for field in attr.fields(cls):
        if field.type is None:
            raise TypeError(
                f"numpy dtype for {cls} cannot be constructed because no "
                + "type information found. Make sure each field is type annotated."
            )
        if field.type == str:
            # Strings need a variable-length dtype to be storable in HDF5.
            dtype_list.append((field.name, h5.special_dtype(vlen=str)))
        elif field.type in (int, float, bool):
            dtype_list.append((field.name, field.type))
        else:
            raise TypeError(
                f"numpy dtype for {cls} cannot be constructed because type "
                + f"{field.type} is not supported."
            )

    return np.dtype(dtype_list)

dict_cut(d, a, b)

Helper function for creating subdictionary by numeric indexing of items.

Assumes that dict.items() will have a fixed order.

Parameters:

Name Type Description Default
d Dict

The dictionary to "split"

required
a int

Start index of range of items to include in result.

required
b int

End index of range of items to include in result.

required

Returns:

Type Description
Dict

A dictionary that contains a subset of the items in the original dict.

Source code in sleap/util.py
270
271
272
273
274
275
276
277
278
279
280
281
282
283
def dict_cut(d: Dict, a: int, b: int) -> Dict:
    """Return a new dict holding the items of `d` at positions `a` up to `b`.

    Positions follow the dict's iteration (insertion) order. `a` and `b` behave
    exactly like list slice bounds: `b` is exclusive, and negative or
    out-of-range values are allowed.

    Args:
        d: The dictionary to "split".
        a: Start index of range of items to include in result.
        b: End index (exclusive) of range of items to include in result.

    Returns:
        A dictionary that contains a subset of the items in the original dict.
    """
    ordered_items = list(d.items())
    return {key: value for key, value in ordered_items[a:b]}

find_files_by_suffix(root_dir, suffix, prefix='', depth=0)

Returns list of files matching suffix, optionally searching in subdirs.

Parameters:

Name Type Description Default
root_dir str

Path to directory where we start searching

required
suffix str

File suffix to match (e.g., '.json')

required
prefix str

Optional file prefix to match

''
depth int

How many subdirectories deep to keep searching

0

Returns:

Type Description
List[DirEntry]

List of os.DirEntry objects.

Source code in sleap/util.py
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
def find_files_by_suffix(
    root_dir: str, suffix: str, prefix: str = "", depth: int = 0
) -> List[os.DirEntry]:
    """Returns list of files matching suffix, optionally searching in subdirs.

    Matches from `root_dir` itself come first, in `os.scandir` order. They are
    followed by the matches of each subdirectory, searched recursively.

    Args:
        root_dir: Path to directory where we start searching.
        suffix: File suffix to match (e.g., '.json').
        prefix: Optional file prefix to match; an empty string matches all.
        depth: How many subdirectories deep to keep searching. Recursion
            happens for any non-zero value and decrements it by one per
            level. A negative value therefore never reaches zero and searches
            the whole tree.

    Returns:
        List of os.DirEntry objects.
    """
    with os.scandir(root_dir) as scan:
        entries = list(scan)

    found: List[os.DirEntry] = []
    for entry in entries:
        if not entry.is_file():
            continue
        if not entry.name.endswith(suffix):
            continue
        if prefix and not entry.name.startswith(prefix):
            continue
        found.append(entry)

    if depth:
        child_dirs = [entry.path for entry in entries if entry.is_dir()]
        for child in child_dirs:
            found.extend(find_files_by_suffix(child, suffix, prefix, depth=depth - 1))

    return found

frame_list(frame_str)

Converts 'n-m' string to list of ints.

Parameters:

Name Type Description Default
frame_str str

string representing range

required

Returns:

Type Description
Optional[List[int]]

List of ints, or None if string does not represent valid range.

Source code in sleap/util.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def frame_list(frame_str: str) -> Optional[List[int]]:
    """Converts an 'n-m' range or 'a,b,c' list string to a list of ints.

    A string containing "-" is treated as an inclusive range "start-end". A
    comma directly before the dash ("1,-200") is tolerated, and anything after
    a second dash is ignored. Otherwise the string is read as comma-separated
    integers.

    Args:
        frame_str: String representing a range or list of frames.

    Returns:
        List of ints, or None if `frame_str` is empty.

    Raises:
        ValueError: If a component is not an integer.
    """
    if "-" in frame_str:
        bounds = frame_str.split("-")
        first = int(bounds[0].rstrip(","))
        last = int(bounds[1])
        return list(range(first, last + 1))

    if len(frame_str) == 0:
        return None
    return list(map(int, frame_str.split(",")))

get_config_file(shortname, ignore_file_not_found=False, get_defaults=False)

Returns the full path to the specified config file.

The config file will be at ~/.sleap/<version>/<shortname>

If that file doesn't yet exist, we'll look for a <shortname> file inside the package config directory (sleap/config) and copy the file into the user's config directory (creating the directory if needed).

Parameters:

Name Type Description Default
shortname str

The short filename, e.g., shortcuts.yaml

required
ignore_file_not_found bool

If True, then return path for config file regardless of whether it exists.

False
get_defaults bool

If True, then just return the path to default config file.

False

Raises:

Type Description
FileNotFoundError

If the specified config file cannot be found.

Returns:

Type Description
Path

The full path to the specified config file.

Source code in sleap/util.py
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
def get_config_file(
    shortname: str, ignore_file_not_found: bool = False, get_defaults: bool = False
) -> Path:
    """Returns the full path to the specified config file.

    The config file will be at ~/.sleap/<version>/<shortname>

    If that file doesn't yet exist, we'll look for a <shortname> file inside
    the package config directory (sleap/config) and copy the file into the
    user's config directory (creating the directory if needed).

    Note that ~/.sleap/<version>/ is always created, even when only the
    package default path is requested. `ignore_file_not_found` is checked
    first, so it takes precedence over `get_defaults`.

    Args:
        shortname: The short filename, e.g., shortcuts.yaml
        ignore_file_not_found: If True, then return the user config path
            regardless of whether it exists (no package lookup or copy).
        get_defaults: If True, then just return the path to the default config
            file inside the package.

    Raises:
        FileNotFoundError: If the package copy of the config file is needed
            (user copy missing, or `get_defaults` is True) but does not exist.

    Returns:
        The full path to the specified config file, as a `pathlib.Path`.
    """

    desired_path = Path.home() / f".sleap/{sleap_version.__version__}/{shortname}"

    # Make sure there's a ~/.sleap/<version>/ directory to store user version of
    # the config file.
    desired_path.parent.mkdir(parents=True, exist_ok=True)

    # If we don't care whether the file exists, just return the path
    if ignore_file_not_found:
        return desired_path

    # If we do care whether the file exists, check the package version of the
    # config file if we can't find the user version.
    if get_defaults or not desired_path.exists():
        package_path = Path(get_package_file(f"config/{shortname}"))
        if not package_path.exists():
            raise FileNotFoundError(
                f"Cannot locate {shortname} config file at {desired_path} or "
                f"{package_path}."
            )

        if get_defaults:
            return package_path

        # Copy package version of config file into user config directory.
        shutil.copy(package_path, desired_path)

    return desired_path

get_package_file(filename)

Returns full path to specified file within sleap package.

Source code in sleap/util.py
286
287
288
289
290
def get_package_file(filename: str) -> str:
    """Returns full path to specified file within sleap package.

    Args:
        filename: Path of the file relative to the sleap package root, e.g.
            "config/shortcuts.yaml".

    Returns:
        The path as a string with forward slashes (via `as_posix`).
    """
    # `files` returns an importlib.resources Traversable, not necessarily a Path.
    package_root = files("sleap")
    return package_root.joinpath(filename).as_posix()

imgfig(size=6, dpi=72, scale=1.0)

Create a tight figure for image plotting.

Parameters:

Name Type Description Default
size float | tuple

Scalar or 2-tuple specifying the (width, height) of the figure in inches. If scalar, will assume equal width and height.

6
dpi int

Dots per inch, controlling the resolution of the image.

72
scale float

Factor to scale the size of the figure by. This is a convenience for increasing the size of the plot at the same DPI.

1.0

Returns:

Type Description
Figure

A matplotlib.figure.Figure to use for plotting.

Source code in sleap/util.py
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
def imgfig(
    size: float | tuple = 6, dpi: int = 72, scale: float = 1.0
) -> matplotlib.figure.Figure:
    """Create a figure whose single axes fills the canvas, for showing images.

    Args:
        size: Width and height of the figure in inches, either as a scalar
            (square figure) or as a tuple/list whose first two entries are
            (width, height).
        dpi: Dots per inch, controlling the resolution of the image.
        scale: Multiplier applied to both figure dimensions, to enlarge the
            plot without changing the DPI.

    Returns:
        A matplotlib.figure.Figure to use for plotting.
    """
    if isinstance(size, (tuple, list)):
        width, height = size[0], size[1]
    else:
        width = height = size

    fig = plt.figure(figsize=(scale * width, scale * height), dpi=dpi)

    # One frameless axes covering the whole figure, with no axis decorations.
    ax = fig.add_axes([0, 0, 1, 1], frameon=False)
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_visible(False)
    plt.autoscale(tight=True)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    return fig

json_dumps(d, filename=None)

A simple wrapper around the JSON encoder we are using.

Parameters:

Name Type Description Default
d Dict

The dict to write.

required
filename str

The filename to write to.

None

Returns:

Type Description

The JSON string if filename is not given; otherwise None.

Source code in sleap/util.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def json_dumps(d: Dict, filename: str = None):
    """A simple wrapper around the JSON encoder we are using.

    Args:
        d: The dict to write.
        filename: The filename to write to. If not given, nothing is written
            and the encoded string is returned instead.

    Returns:
        The JSON string when `filename` is not given; otherwise None.
    """

    encoder = rapidjson

    if filename:
        # ensure_ascii=False emits raw non-ASCII characters, so the file must be
        # UTF-8 rather than the platform's default locale encoding.
        with open(filename, "w", encoding="utf-8") as f:
            encoder.dump(d, f, ensure_ascii=False)
        return None

    return encoder.dumps(d)

json_loads(json_str)

A simple wrapper around the JSON decoder we are using.

Parameters:

Name Type Description Default
json_str str

JSON string to decode.

required

Returns:

Type Description
Dict

Result of decoding JSON string.

Source code in sleap/util.py
56
57
58
59
60
61
62
63
64
65
66
67
68
def json_loads(json_str: str) -> Dict:
    """Decode a JSON string, preferring rapidjson over the stdlib decoder.

    rapidjson is tried first. If it raises for any reason, the standard
    library `json` decoder is used instead. Any error raised at that point
    comes from the stdlib decoder.

    Args:
        json_str: JSON string to decode.

    Returns:
        Result of decoding JSON string.
    """
    try:
        decoded = rapidjson.loads(json_str)
    except Exception:
        decoded = json.loads(json_str)
    return decoded

make_scoped_dictionary(flat_dict, exclude_nones=True)

Converts dictionary with scoped keys to dictionary of dictionaries.

Parameters:

Name Type Description Default
flat_dict Dict[str, Any]

The dictionary to convert. Keys should be strings with scope.foo format.

required
exclude_nones bool

Whether to exclude items where value is None.

True

Returns:

Type Description
Dict[str, Dict[str, Any]]

Dictionary in which keys are scope and values are dictionary with foo (etc) as keys and original value of scope.foo as value.

Source code in sleap/util.py
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
def make_scoped_dictionary(
    flat_dict: Dict[str, Any], exclude_nones: bool = True
) -> Dict[str, Dict[str, Any]]:
    """Converts dictionary with scoped keys to dictionary of dictionaries.

    Keys without a "." are dropped. Only the first "." separates the scope
    from the subkey, so "scope.foo.bar" yields `{"scope": {"foo.bar": ...}}`.

    Args:
        flat_dict: The dictionary to convert. Keys should be strings with
            `scope.foo` format.
        exclude_nones: Whether to exclude items where value is None.

    Returns:
        A `defaultdict(dict)` in which keys are `scope` and values are
        dictionaries with `foo` (etc) as keys and the original value of
        `scope.foo` as value.
    """
    scoped_dict = defaultdict(dict)

    for key, val in flat_dict.items():
        if "." not in key:
            continue
        if exclude_nones and val is None:
            continue
        # maxsplit=1 so keys containing more than one "." don't fail to unpack.
        scope, subkey = key.split(".", 1)
        scoped_dict[scope][subkey] = val

    return scoped_dict

parse_uri_path(uri)

Parse a URI starting with 'file:///' to a posix path.

Source code in sleap/util.py
421
422
423
def parse_uri_path(uri: str) -> str:
    """Parse a URI starting with 'file:///' to a posix path.

    Percent-escapes are decoded exactly once, by `url2pathname`, after the URI
    has been split into components. Decoding before `urlparse` would turn an
    encoded "%23" or "%3F" into "#"/"?", so the rest of the path would be
    discarded as a fragment/query. It would also double-decode sequences
    such as "%2541".

    Args:
        uri: A file URI, e.g. "file:///home/user/my%20video.mp4".

    Returns:
        The local path, using forward slashes as separators.
    """
    return Path(url2pathname(urlparse(uri).path)).as_posix()

plot_img(img, dpi=72, scale=1.0)

Plot an image in a tight figure.

Source code in sleap/util.py
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
def plot_img(
    img: np.ndarray, dpi: int = 72, scale: float = 1.0
) -> matplotlib.figure.Figure:
    """Plot an image in a tight figure.

    Args:
        img: Image array, or an object with a `numpy()` method. A leading
            singleton axis is squeezed out as a batch axis, and a trailing
            singleton axis is squeezed out as a grayscale channel. Presumably
            (1, H, W, C) or (H, W, C) — TODO confirm.
        dpi: Dots per inch of the figure; one image pixel maps to one figure
            pixel at `scale=1`.
        scale: Multiplier applied to the figure size.

    Returns:
        The matplotlib figure containing the image.
    """
    if hasattr(img, "numpy"):
        img = img.numpy()

    if img.shape[0] == 1:
        # Squeeze out batch singleton dimension.
        img = img.squeeze(axis=0)

    # Check if image is grayscale (single channel).
    grayscale = img.shape[-1] == 1
    if grayscale:
        # Squeeze out singleton channel.
        img = img.squeeze(axis=-1)

    # Normalize the range of pixel values to [0, 1] when it falls outside it.
    img_min = img.min()
    img_max = img.max()
    if img_min < 0.0 or img_max > 1.0:
        img_range = img_max - img_min
        if img_range > 0:
            img = (img - img_min) / img_range
        else:
            # Constant image: avoid 0/0 (an all-NaN image); show it as uniform black.
            img = np.zeros(img.shape, dtype="float32")

    fig = imgfig(
        size=(float(img.shape[1]) / dpi, float(img.shape[0]) / dpi),
        dpi=dpi,
        scale=scale,
    )

    ax = fig.gca()
    ax.imshow(
        img,
        cmap="gray" if grayscale else None,
        origin="upper",
        extent=[-0.5, img.shape[1] - 0.5, img.shape[0] - 0.5, -0.5],
    )
    return fig

plot_instance(instance, skeleton=None, cmap=None, color_by_node=False, lw=2, ms=10, bbox=None, scale=1.0, **kwargs)

Plot a single instance with edge coloring.

Source code in sleap/util.py
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
def plot_instance(
    instance,
    skeleton=None,
    cmap=None,
    color_by_node=False,
    lw=2,
    ms=10,
    bbox=None,
    scale=1.0,
    **kwargs,
):
    """Plot a single instance with edge coloring.

    Args:
        instance: An object with a `numpy()` method (e.g. a SLEAP instance) or
            an indexable sequence of (x, y) rows — assumed (n_nodes, 2); TODO
            confirm.
        skeleton: Skeleton whose `edges` are drawn. Defaults to
            `instance.skeleton` when available.
        cmap: Sequence of colors, cycled by node or edge index. Defaults to the
            seaborn "tab20" palette.
        color_by_node: If True, draw each node as a marker colored by node
            index instead of drawing edges. Forced on when there is no skeleton
            or the skeleton has no edges.
        lw: Line width for edges.
        ms: Marker size.
        bbox: Optional offset: `bbox[1]` is subtracted from x and `bbox[0]`
            from y before scaling.
        scale: Factor applied to coordinates after the bbox offset.
        **kwargs: Passed through to `plt.plot`.

    Returns:
        List with the handles returned by each `plt.plot` call.
    """
    if cmap is None:
        cmap = sns.color_palette("tab20")

    if skeleton is None and hasattr(instance, "skeleton"):
        skeleton = instance.skeleton

    if skeleton is None or len(skeleton.edges) == 0:
        color_by_node = True

    if hasattr(instance, "numpy"):
        inst_pts = instance.numpy()
    else:
        inst_pts = instance

    h_lines = []
    if color_by_node:
        for k, (x, y) in enumerate(inst_pts):
            if bbox is not None:
                x -= bbox[1]
                y -= bbox[0]

            x *= scale
            y *= scale

            h_lines_k = plt.plot(x, y, ".", ms=ms, c=cmap[k % len(cmap)], **kwargs)
            h_lines.append(h_lines_k)

    else:
        # Index with the instance's own skeleton and points array when it has
        # them (same as before). Otherwise fall back to the given skeleton and
        # the raw points, so plain arrays can be drawn with an explicit skeleton.
        index_skeleton = getattr(instance, "skeleton", None)
        if index_skeleton is None:
            index_skeleton = skeleton
        edge_pts = getattr(instance, "points_array", inst_pts)

        for k, (src_node, dst_node) in enumerate(skeleton.edges):
            src_pt = edge_pts[index_skeleton.node_to_index(src_node)]
            dst_pt = edge_pts[index_skeleton.node_to_index(dst_node)]

            # Float dtype so the in-place offset/scale below also works when the
            # coordinates are integers (in-place float ops on int arrays raise).
            x = np.array([src_pt[0], dst_pt[0]], dtype="float64")
            y = np.array([src_pt[1], dst_pt[1]], dtype="float64")

            if bbox is not None:
                x -= bbox[1]
                y -= bbox[0]

            x *= scale
            y *= scale

            h_lines_k = plt.plot(
                x, y, ".-", ms=ms, lw=lw, c=cmap[k % len(cmap)], **kwargs
            )

            h_lines.append(h_lines_k)

    return h_lines

plot_instances(instances, skeleton=None, cmap=None, color_by_track=False, tracks=None, **kwargs)

Plot a list of instances with identity coloring.

Source code in sleap/util.py
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
def plot_instances(
    instances, skeleton=None, cmap=None, color_by_track=False, tracks=None, **kwargs
):
    """Plot a list of instances with identity coloring.

    Args:
        instances: Instances to plot, each drawn with `plot_instance`.
        skeleton: Optional skeleton passed through to `plot_instance`.
        cmap: Sequence of colors. Defaults to the seaborn "tab10" palette.
        color_by_track: If True, color each instance by the position of its
            `track` in `tracks`; otherwise color by position in `instances`.
        tracks: Ordered list of tracks used for coloring. If omitted while
            coloring by track, the instances' tracks are used, sorted by name.
        **kwargs: Passed through to `plot_instance`.

    Returns:
        List with the line handles returned by `plot_instance` for each instance.

    Raises:
        ValueError: When coloring by track and an instance has no track, or its
            track is not in `tracks`.
    """

    if cmap is None:
        cmap = sns.color_palette("tab10")

    if color_by_track and tracks is None:
        # Infer track ordering (sorted by track name). Instances without a track
        # are skipped here so the loop below raises a clear ValueError for them
        # instead of an AttributeError while sorting by `None.name`.
        tracks = sorted(
            {instance.track for instance in instances if instance.track is not None},
            key=lambda track: track.name,
        )

    h_lines = []
    for i, instance in enumerate(instances):
        if color_by_track:
            if instance.track is None:
                raise ValueError(
                    "Instances must have a set track when coloring by track."
                )

            if instance.track not in tracks:
                raise ValueError("Instance has a track not found in specified tracks.")

            color = cmap[tracks.index(instance.track) % len(cmap)]

        else:
            # Color by identity (order in list).
            color = cmap[i % len(cmap)]

        h_lines_i = plot_instance(instance, skeleton=skeleton, cmap=[color], **kwargs)
        h_lines.append(h_lines_i)

    return h_lines

resize_image(img, scale)

Resizes single image with shape (height, width, channels).

Source code in sleap/util.py
205
206
207
208
209
210
211
212
213
214
215
216
217
218
def resize_image(img: np.ndarray, scale: float) -> np.ndarray:
    """Resize a single (height, width, channels) image by `scale`.

    The output size is computed by floor-dividing each dimension by
    `1 / scale`. Single-channel input keeps its trailing channel axis in the
    output.
    """
    height, width, channels = img.shape
    inverse_scale = 1 / scale

    # OpenCV's dsize argument is ordered (width, height).
    dsize = (int(width // inverse_scale), int(height // inverse_scale))
    resized = cv2.resize(img, dsize)

    if channels == 1:
        # OpenCV returns a 2D array for single-channel input; restore the axis.
        resized = resized[..., None]

    return resized

save_dict_to_hdf5(h5file, path, dic)

Saves dictionary to an HDF5 file.

Calls itself recursively for items that are nested dictionaries; raises ValueError for item types it cannot save.

Parameters:

Name Type Description Default
h5file File

The open HDF5 file object to save the data to.

required
path str

The path to group save the dict under.

required
dic dict

The dict to save.

required

Raises:

Type Description
ValueError

If type for item in dict cannot be saved.

Returns:

Type Description

None

Source code in sleap/util.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def save_dict_to_hdf5(h5file: "h5.File", path: str, dic: dict):
    """Saves dictionary to an HDF5 file.

    Each item is written to `h5file[path + key]`, converted as follows:

    - `None` is stored as an empty string.
    - `bool` is stored as an `int` (0 or 1).
    - A `list` is stored as a numpy array, with `str` elements UTF-8 encoded.
    - `str` is stored as UTF-8 encoded bytes.
    - `np.ndarray`, numpy scalars, `bytes`, `float` and `int` are stored as-is.
    - A nested `dict` is saved recursively under `path + key + "/"`.

    Args:
        h5file: The open HDF5 file object to save the data to.
        path: The path to group save the dict under.
        dic: The dict to save.

    Raises:
        ValueError: If type for item in dict cannot be saved.
    """
    for key, item in dic.items():
        dest = path + key
        if item is None:
            h5file[dest] = ""
        elif isinstance(item, bool):
            # Checked before int because bool is a subclass of int.
            h5file[dest] = int(item)
        elif isinstance(item, list):
            h5file[dest] = np.asarray(
                [it.encode("utf8") if isinstance(it, str) else it for it in item]
            )
        elif isinstance(item, str):
            h5file[dest] = item.encode("utf8")
        elif isinstance(item, (np.ndarray, np.generic, bytes, float, int)):
            h5file[dest] = item
        elif isinstance(item, dict):
            save_dict_to_hdf5(h5file, dest + "/", item)
        else:
            raise ValueError("Cannot save %s type" % type(item))

uniquify(seq)

Returns unique elements from list, preserving order.

Note: This will not work on Python 3.5 or lower since dicts don't preserve order.

Parameters:

Name Type Description Default
seq Iterable[Hashable]

The list to remove duplicates from.

required

Returns:

Type Description
List

The unique elements from the input list extracted in original order.

Source code in sleap/util.py
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
def uniquify(seq: Iterable[Hashable]) -> List:
    """Return the distinct elements of `seq` in order of first appearance.

    Relies on dicts preserving insertion order, which is guaranteed from
    Python 3.7. Among equal elements, the first occurrence is the one kept.

    Args:
        seq: The elements to remove duplicates from.

    Returns:
        A list of the unique elements, in their original order.
    """
    first_seen: dict = {}
    for element in seq:
        first_seen.setdefault(element, None)
    return list(first_seen)

usable_cpu_count()

Gets number of CPUs usable by the current process.

Takes into consideration cpusets restrictions.

Returns:

Type Description
int

The number of usable cpus

Source code in sleap/util.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def usable_cpu_count() -> int:
    """Gets number of CPUs usable by the current process.

    Takes into consideration cpusets restrictions. Tries, in order:
    `os.sched_getaffinity`, then `psutil.Process().cpu_affinity()`, then
    `os.cpu_count()`.

    Returns:
        The number of usable cpus (at least 1).
    """
    try:
        result = len(os.sched_getaffinity(0))
    except AttributeError:
        # os.sched_getaffinity is not available on every platform.
        try:
            result = len(psutil.Process().cpu_affinity())
        except AttributeError:
            # cpu_affinity is not available on every platform either.
            # os.cpu_count() may return None when the count is undeterminable.
            result = os.cpu_count() or 1
    return result

weak_filename_match(filename_a, filename_b)

Check if paths probably point to same file.

Compares the filename and names of two directories up.

Parameters:

Name Type Description Default
filename_a str

first path to check

required
filename_b str

path to check against first path

required

Returns:

Type Description
bool

True if the paths probably match.

Source code in sleap/util.py
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
def weak_filename_match(filename_a: str, filename_b: str) -> bool:
    """Check if paths probably point to same file.

    Compares the filename and names of two directories up. Backslashes are
    treated as separators. The numeric id in temp directories named
    "tmp_<digits>_..." is ignored, so extractions of the same archive into
    different temp dirs still match.

    Args:
        filename_a: first path to check
        filename_b: path to check against first path

    Returns:
        True if the paths probably match.
    """
    # convert all path separators to /
    filename_a = filename_a.replace("\\", "/")
    filename_b = filename_b.replace("\\", "/")

    # Remove the unique pid so we can match tmp directories for the same zip.
    # Keep the leading "/" so the tmp directory stays its own path component.
    filename_a = re.sub(r"/tmp_\d+_", "/tmp_", filename_a)
    filename_b = re.sub(r"/tmp_\d+_", "/tmp_", filename_b)

    # check if last three parts of path match
    return filename_a.split("/")[-3:] == filename_b.split("/")[-3:]