Skip to content

pafs

sleap.gui.overlays.pafs

Overlay for part affinity fields.

Currently a DataOverlay gets data from a model (i.e., it runs inference on the current frame) and then uses a MultiQuiverPlot object to show the resulting part affinity fields.

Classes:

Name Description
MultiQuiverPlot

QtWidgets.QGraphicsObject to display multiple quiver plots in a QtWidgets.QGraphicsView.

QuiverPlot

QtWidgets.QGraphicsObject for drawing single quiver plot.

MultiQuiverPlot

Bases: QGraphicsObject

QtWidgets.QGraphicsObject to display multiple quiver plots in a QtWidgets.QGraphicsView.

Parameters:

Name Type Description Default
frame array

Data for one frame of quiver plot data. Shape of array should be (height, width, 2 * channels), where the last axis interleaves the x and y components of each channel.

None
show list

List of channels to show. If None, show all channels.

None
decimation int

Decimation factor. If 1, show every arrow.

1

Returns:

Type Description

None.

Note

Each channel corresponds to two (h, w) arrays: x and y for the vector.

When initialized, creates one child QuiverPlot item for each channel.

Methods:

Name Description
boundingRect

Method required by Qt.

paint

Method required by Qt.

Source code in sleap/gui/overlays/pafs.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class MultiQuiverPlot(QtWidgets.QGraphicsObject):
    """QtWidgets.QGraphicsObject to display multiple quiver plots in a
    QtWidgets.QGraphicsView.

    Args:
        frame (numpy.array): Data for one frame of quiver plot data.
            Shape of array should be (height, width, 2 * channels): the last
            axis interleaves the x and y components of each channel.
            If None, no quiver plots are created.
        show (list, optional): List of channels to show. If None, show all channels.
            Channels not present in ``frame`` are skipped.
        decimation (int, optional): Decimation factor. If 1, show every arrow.
        scale (float, optional): Scale factor passed through to each child
            QuiverPlot (applied to arrow positions and lengths).

    Note:
        Each channel corresponds to two (h, w) arrays: x and y for the vector.

    When initialized, creates one child QuiverPlot item for each channel.
    """

    def __init__(
        self,
        frame: np.array = None,
        show: list = None,
        decimation: int = 1,
        scale: float = 1.0,
        *args,
        **kwargs,
    ):
        super(MultiQuiverPlot, self).__init__(*args, **kwargs)
        self.frame = None if frame is None else np.asarray(frame)
        self.affinity_field = []
        self.decimation = decimation
        self.scale = scale

        if self.frame is None:
            # No data to plot: np.ptp and .shape below would fail on None.
            self.show_list = range(0) if show is None else show
            return

        # Values in [-1, 1] span at most 2 peak-to-peak, so a larger range
        # means the data is not unit-scaled; assume it's [-255, 255] and scale.
        # (np.ptp raises on zero-size input, hence the size check.)
        if self.frame.size > 0 and np.ptp(self.frame) > 4:
            self.frame = self.frame.astype(np.float64) / 255

        # The last axis holds (x, y) pairs, one pair per channel.
        n_channels = self.frame.shape[-1] // 2
        self.show_list = range(n_channels) if show is None else show

        for channel in self.show_list:
            if channel >= n_channels:
                # Requested channel is not present in this frame.
                continue
            color_map = h5_colors[channel % len(h5_colors)]
            aff_field_item = QuiverPlot(
                field_x=self.frame[..., channel * 2],
                field_y=self.frame[..., channel * 2 + 1],
                color=color_map,
                decimation=self.decimation,
                scale=self.scale,
                parent=self,
            )
            self.affinity_field.append(aff_field_item)

    def boundingRect(self) -> QtCore.QRectF:
        """Method required by Qt; returns an empty rectangle."""
        return QtCore.QRectF()

    def paint(self, painter, option, widget=None):
        """Method required by Qt; this item draws nothing itself."""
        pass

boundingRect()

Method required by Qt.

Source code in sleap/gui/overlays/pafs.py
75
76
77
def boundingRect(self) -> QtCore.QRectF:
    """Qt hook for the item's bounding rectangle; always an empty QRectF."""
    empty_rect = QtCore.QRectF()
    return empty_rect

paint(painter, option, widget=None)

Method required by Qt.

Source code in sleap/gui/overlays/pafs.py
79
80
81
def paint(self, painter, option, widget=None):
    """Qt paint hook; a no-op, since this item has nothing of its own to draw."""
    return None

QuiverPlot

Bases: QGraphicsObject

QtWidgets.QGraphicsObject for drawing single quiver plot.

Parameters:

Name Type Description Default
field_x array

(h, w) array of x component of vectors.

None
field_y array

(h, w) array of y component of vectors.

None
color list

Arrow color. Format as (r, g, b) array.

[255, 255, 255]
decimation int

Decimation factor. If 1, show every arrow.

1

Returns:

Type Description

None.

Methods:

Name Description
boundingRect

Method called by Qt in order to determine whether object is in visible frame.

paint

Method called by Qt to draw object.

Source code in sleap/gui/overlays/pafs.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
class QuiverPlot(QtWidgets.QGraphicsObject):
    """QtWidgets.QGraphicsObject for drawing single quiver plot.

    Args:
        field_x (numpy.array): (h, w) array of x component of vectors.
        field_y (numpy.array): (h, w) array of y component of vectors.
        color (sequence, optional): Arrow color. Format as (r, g, b).
        decimation (int, optional): Decimation factor. If 1, show every arrow;
            otherwise each arrow shows the mean vector of a
            decimation x decimation box of the field.
        scale (float, optional): Factor mapping field coordinates and vector
            components to scene coordinates.
    """

    def __init__(
        self,
        field_x: np.array = None,
        field_y: np.array = None,
        color=(255, 255, 255),
        decimation=1,
        scale=1,
        *args,
        **kwargs,
    ):
        super(QuiverPlot, self).__init__(*args, **kwargs)

        self.field_x, self.field_y = None, None
        self.color = color
        self.decimation = decimation
        self.scale = scale
        # Pen width grows with log base 20 of the decimation, clamped to [0.1, 4].
        pen_width = min(4, max(0.1, math.log(self.decimation, 20)))
        self.pen = QtGui.QPen(QtGui.QColor(*self.color), pen_width)
        self.points = []
        self.rect = QtCore.QRectF()

        if field_x is not None and field_y is not None:
            self.field_x, self.field_y = field_x, field_y

            h, w = self.field_x.shape
            h, w = int(h * self.scale), int(w * self.scale)

            self.rect = QtCore.QRectF(0, 0, w, h)

            self._add_arrows()

    def _add_arrows(self, min_length=0.01):
        """Fill ``self.points`` with six QPointF per arrow.

        The points form three line segments per arrow (for ``drawLines``):
        the shaft and the two arrow-head barbs, all ending at the tip.
        Arrows whose scaled vector length is not greater than ``min_length``
        are skipped. Does nothing if either field is missing.
        """
        if self.field_x is None or self.field_y is None:
            return

        raw_delta_yx = np.stack((self.field_y, self.field_x), axis=-1)

        dim_0 = self.field_x.shape[0] // self.decimation * self.decimation
        dim_1 = self.field_x.shape[1] // self.decimation * self.decimation

        grid = np.mgrid[0 : dim_0 : self.decimation, 0 : dim_1 : self.decimation]
        loc_yx = np.moveaxis(grid, 0, -1)

        if self.decimation > 1:
            delta_yx = self._decimate(raw_delta_yx, self.decimation)

            # Shift locations to midpoint of decimation square. This is done in
            # field coordinates, before scaling, so the shift scales too.
            loc_yx = loc_yx + self.decimation // 2
        else:
            delta_yx = raw_delta_yx

        # Map from field coordinates to scene coordinates. Vectors are cast to
        # float so the np.divide(out=...) calls below never target an int array.
        loc_yx = loc_yx * self.scale
        delta_yx = np.asarray(delta_yx, dtype=np.float64) * self.scale

        # Split into x,y matrices
        loc_y, loc_x = loc_yx[..., 0], loc_yx[..., 1]
        delta_y, delta_x = delta_yx[..., 0], delta_yx[..., 1]

        # Determine vector endpoint
        x2 = delta_x * self.decimation + loc_x
        y2 = delta_y * self.decimation + loc_y
        line_length = (delta_x**2 + delta_y**2) ** 0.5

        # Determine points for arrow
        arrow_head_size = line_length / 4

        u_dx = np.divide(
            delta_x, line_length, out=np.zeros_like(delta_x), where=line_length != 0
        )
        u_dy = np.divide(
            delta_y, line_length, out=np.zeros_like(delta_y), where=line_length != 0
        )
        p1_x = x2 - u_dx * arrow_head_size - u_dy * arrow_head_size
        p1_y = y2 - u_dy * arrow_head_size + u_dx * arrow_head_size

        p2_x = x2 - u_dx * arrow_head_size + u_dy * arrow_head_size
        p2_y = y2 - u_dy * arrow_head_size - u_dx * arrow_head_size

        # Per arrow, in row-major (y, x) order: shaft start -> tip,
        # barb 1 -> tip, barb 2 -> tip.
        keep = line_length > min_length
        coords = np.stack(
            [loc_x, loc_y, x2, y2, p1_x, p1_y, x2, y2, p2_x, p2_y, x2, y2],
            axis=-1,
        )[keep].reshape(-1, 2)
        self.points = [QtCore.QPointF(x, y) for x, y in coords.tolist()]

    def _decimate(self, image: np.array, box: int):
        """Average non-overlapping ``box`` x ``box`` tiles of ``image``.

        Args:
            image: (rows, cols, depth) array.
            box: Tile edge length.

        Returns:
            (rows // box, cols // box, depth) array of tile means. Rows and
            columns at the bottom/right edges that do not fill a whole tile
            are dropped.
        """
        nrows = image.shape[0] // box
        ncols = image.shape[1] // box
        depth = image.shape[2]
        # reshape respects logical indexing for any memory layout (including
        # column-major arrays such as those read from h5py), unlike
        # as_strided on raw memory.
        tiles = image[: nrows * box, : ncols * box].reshape(
            nrows, box, ncols, box, depth
        )
        return tiles.mean(axis=(1, 3))

    def boundingRect(self) -> QtCore.QRectF:
        """Method called by Qt in order to determine whether object is in
        visible frame."""
        return QtCore.QRectF(self.rect)

    def paint(self, painter, option, widget=None):
        """Method called by Qt to draw object."""
        if self.pen is not None:
            painter.setPen(self.pen)
        painter.drawLines(self.points)

boundingRect()

Method called by Qt in order to determine whether object is in visible frame.

Source code in sleap/gui/overlays/pafs.py
220
221
222
223
def boundingRect(self) -> QtCore.QRectF:
    """Qt hook used to decide whether the item lies in the visible frame.

    Returns a new QRectF built from ``self.rect``.
    """
    bounds = QtCore.QRectF(self.rect)
    return bounds

paint(painter, option, widget=None)

Method called by Qt to draw object.

Source code in sleap/gui/overlays/pafs.py
225
226
227
228
229
230
def paint(self, painter, option, widget=None):
    """Method called by Qt to draw object.

    Sets ``self.pen`` on the painter (when it is not None), then draws
    ``self.points`` as line segments (consecutive pairs of points).
    """
    if self.pen is not None:
        painter.setPen(self.pen)
    painter.drawLines(self.points)