Skip to content

Plotting Utils

SimplePlot

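Context manager that cleanly plots a figure created with matplotlib.pyplot (e.g. `fig, ax = plt.subplots()`): on exit, the current figure is saved and/or shown as requested, then closed.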

with gd.SPlot():
    fig, ax = plt.subplots()
    gd.despine(fig)
    ax.plot()
    fig.tight_layout()
Source code in gdutils/utils/plotting.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
class SimplePlot:
    """
    Context manager that saves, shows and closes the current pyplot figure.

    Create the figure inside the ``with`` block (e.g. ``fig, ax = plt.subplots()``).
    On exit, the current figure is saved and/or shown as configured and is
    then always closed, even if saving or showing raises.

        with gd.SPlot():
            fig, ax = plt.subplots()
            gd.despine(fig)
            ax.plot()
            fig.tight_layout()

    """

    def __init__(
        self,
        fname: Optional[str] = None,
        show: bool = True,
        save: Optional[bool] = None,
    ):
        """
        :param fname: figure filename
        :param show: if `plt.show` must be called
        :param save: if `fig.savefig` must be called; when ``None`` it
            defaults to ``True`` exactly when ``fname`` is given
        """
        self._fname = fname
        self._save = fname is not None if save is None else save
        self._show = show

    def __enter__(self) -> "SimplePlot":
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """
        Save/show the current figure if requested, then close it.

        If the ``with`` body raised, nothing is saved or shown: the current
        figure is closed and the exception propagates (this method returns
        ``None``). The figure is also closed when ``savefig`` or ``plt.show``
        itself raises, so a failed save does not leak the figure.
        """
        # If an exception occurred, don't try to save/show the plot, just cleanup
        if exc_type is not None:
            plt.close(plt.gcf())
            return

        fig = plt.gcf()
        try:
            # Saving requires both a filename and the save flag; ``save=True``
            # without a filename silently skips saving.
            if self._fname is not None and self._save:
                fig.savefig(self._fname, bbox_inches="tight")

            if self._show:
                plt.show()
        finally:
            # Release the figure even if savefig/show raised.
            plt.close(fig)

    def save(self, fname: str) -> None:
        """Give a filename to save the figure and enable saving"""
        self._fname = fname
        self._save = True

__exit__(exc_type, exc_val, exc_tb)

Get the current figure, save/show it if specified and finally close the figure properly

Source code in gdutils/utils/plotting.py
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
    """
    Save/show the current figure if requested, then close it.

    If the ``with`` body raised, nothing is saved or shown: the current
    figure is closed and the exception propagates (this method returns
    ``None``). The figure is also closed when ``savefig`` or ``plt.show``
    itself raises, so a failed save does not leak the figure.
    """
    # If an exception occurred, don't try to save/show the plot, just cleanup
    if exc_type is not None:
        plt.close(plt.gcf())
        return

    fig = plt.gcf()
    try:
        # Saving requires both a filename and the save flag; ``save=True``
        # without a filename silently skips saving.
        if self._fname is not None and self._save:
            fig.savefig(self._fname, bbox_inches="tight")

        if self._show:
            plt.show()
    finally:
        # Release the figure even if savefig/show raised.
        plt.close(fig)

__init__(fname=None, show=True, save=None)

- fname: figure filename
- show: whether `plt.show` must be called
- save: whether `fig.savefig` must be called (when left as `None`, it defaults to `True` exactly when `fname` is given)

Source code in gdutils/utils/plotting.py
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def __init__(
    self,
    fname: Optional[str] = None,
    show: bool = True,
    save: Optional[bool] = None,
):
    """
    Record what should happen to the current figure when the block exits.

    :param fname: path the figure is written to (``None`` for no file)
    :param show: whether ``plt.show`` is called on exit
    :param save: whether ``fig.savefig`` is called on exit; when left as
        ``None`` it becomes ``True`` exactly when ``fname`` was given
    """
    if save is None:
        # Default: save only when a destination filename was supplied.
        save = fname is not None
    self._show = show
    self._fname = fname
    self._save = save

save(fname)

Give a filename to save the figure

Source code in gdutils/utils/plotting.py
253
254
255
256
def save(self, fname: str) -> None:
    """Register ``fname`` as the output path and switch saving on."""
    self._save, self._fname = True, fname

despine(fig=None, ax=None, top=True, right=True, left=False, bottom=False, offset=None, trim=False)

Remove the top and right spines from plot(s).

- fig : matplotlib figure, optional — Figure to despine all axes of; defaults to the current figure.
- ax : matplotlib axes, optional — Specific axes object to despine. Ignored if fig is provided.
- top, right, left, bottom : boolean, optional — If True, remove that spine.
- offset : int or dict, optional — Absolute distance, in points, spines should be moved away from the axes (negative values move spines inward). A single value applies to all spines; a dict can be used to set offset values per side.
- trim : bool, optional — If True, limit spines to the smallest and largest major tick on each non-despined axis.

Returns

None

Source code in gdutils/utils/plotting.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def despine(
    fig: Optional[mpl.figure.Figure] = None,
    ax: Optional[mpl.axes.Axes] = None,
    top: bool = True,
    right: bool = True,
    left: bool = False,
    bottom: bool = False,
    offset: Optional[Union[int, Dict[str, int]]] = None,
    trim: bool = False,
) -> None:
    """Remove the top and right spines from plot(s).

    fig : matplotlib figure, optional
        Figure to despine all axes of, defaults to the current figure.
    ax : matplotlib axes, optional
        Specific axes object to despine. Ignored if fig is provided.
    top, right, left, bottom : boolean, optional
        If True, remove that spine.
    offset : int or dict, optional
        Absolute distance, in points, spines should be moved away
        from the axes (negative values move spines inward). A single value
        applies to all spines; a dict can be used to set offset values per
        side (sides missing from the dict get an offset of 0).
    trim : bool, optional
        If True, limit spines to the smallest and largest major tick
        on each non-despined axis. An axis that has no major tick inside
        its view limits is left untrimmed instead of raising.

    Returns
    -------
    None

    """
    # An explicit figure wins over a single axes; with neither, use every
    # axes of the current figure.
    if fig is not None:
        axes = fig.axes
    elif ax is not None:
        axes = [ax]
    else:
        axes = plt.gcf().axes

    # Explicit side -> "remove this spine" mapping (instead of looking the
    # flags up by name through ``locals()``).
    remove_side = {"top": top, "right": right, "left": left, "bottom": bottom}

    def _mirror_ticks(axis: Any, position: str) -> None:
        # Move the ticks of ``axis`` to ``position`` and keep them visible
        # there if they were visible on the original side.
        maj_on = any(t.tick1line.get_visible() for t in axis.majorTicks)
        min_on = any(t.tick1line.get_visible() for t in axis.minorTicks)
        axis.set_ticks_position(position)
        for t in axis.majorTicks:
            t.tick2line.set_visible(maj_on)
        for t in axis.minorTicks:
            t.tick2line.set_visible(min_on)

    def _trim_bounds(ticks: Any, lim: Any) -> Optional[tuple]:
        # Return (first, last, kept_ticks): the first tick at/above the lower
        # limit, the last tick at/below the upper limit, and the ticks between
        # them. Return None when either side has no tick, which previously
        # caused an IndexError.
        ticks = np.asarray(ticks)
        at_or_above = ticks[ticks >= min(lim)]
        at_or_below = ticks[ticks <= max(lim)]
        if not (at_or_above.size and at_or_below.size):
            return None
        first, last = at_or_above[0], at_or_below[-1]
        kept = ticks[(ticks >= first) & (ticks <= last)]
        return first, last, kept

    for ax_i in axes:
        for side, hide in remove_side.items():
            # Toggle the spine objects
            is_visible = not hide
            spine = ax_i.spines[side]
            spine.set_visible(is_visible)
            if offset is not None and is_visible:
                # Dict-like offsets are per side; any other value applies to all.
                try:
                    val = offset.get(side, 0)
                except AttributeError:
                    val = offset
                spine.set_position(("outward", val))

        # With only the left spine removed, move the y ticks to the right side.
        if left and not right:
            _mirror_ticks(ax_i.yaxis, "right")

        # With only the bottom spine removed, move the x ticks to the top side.
        if bottom and not top:
            _mirror_ticks(ax_i.xaxis, "top")

        if trim:
            # Clip off the parts of the spines that extend past the major ticks
            x_trim = _trim_bounds(ax_i.get_xticks(), ax_i.get_xlim())
            if x_trim is not None:
                first, last, kept = x_trim
                ax_i.spines["bottom"].set_bounds(first, last)
                ax_i.spines["top"].set_bounds(first, last)
                ax_i.set_xticks(kept)

            y_trim = _trim_bounds(ax_i.get_yticks(), ax_i.get_ylim())
            if y_trim is not None:
                first, last, kept = y_trim
                ax_i.spines["left"].set_bounds(first, last)
                ax_i.spines["right"].set_bounds(first, last)
                ax_i.set_yticks(kept)

get_color_cycle()

Return the list of colors in the current matplotlib color cycle

Parameters

None

Returns

colors : list — List of matplotlib colors in the current property cycle, or `[".15"]` (dark gray) if the property cycle defines no `color` key.

Source code in gdutils/utils/plotting.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def get_color_cycle() -> List[str]:
    """Return the colors of the active matplotlib property cycle

    Parameters
    ----------
    None

    Returns
    -------
    colors : list
        The ``color`` entries of ``rcParams["axes.prop_cycle"]``, or a
        single dark gray (``".15"``) when that cycle has no ``color`` key.
    """
    prop_cycle = mpl.rcParams["axes.prop_cycle"]
    if "color" not in prop_cycle.keys:
        # No colors in the cycle: fall back to a single dark gray.
        return [".15"]
    return prop_cycle.by_key()["color"]

move_legend(obj, loc, **kwargs)

Recreate a plot's legend at a new location.

The name is a slight misnomer. Matplotlib legends do not expose public control over their position parameters. So this function creates a new legend, copying over the data from the original object, which is then removed.

Parameters

obj : the object with the plot. This argument can be either a seaborn or matplotlib object:

- :class:`seaborn.FacetGrid` or :class:`seaborn.PairGrid`
- :class:`matplotlib.axes.Axes` or :class:`matplotlib.figure.Figure`

(Note: the implementation shown below only accepts matplotlib Axes and Figure objects and raises `TypeError` for anything else.)

loc : str or int — Location argument, as in :meth:`matplotlib.axes.Axes.legend`.

kwargs — Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.legend`.

Examples

(The examples are pulled in from `../docstrings/move_legend.rst` by an RST `include` directive, which is not rendered on this page.)

Source code in gdutils/utils/plotting.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def move_legend(obj: Any, loc: Union[str, int], **kwargs: Any) -> None:
    """
    Recreate a plot's legend at a new location.

    The name is a slight misnomer. Matplotlib legends do not expose public
    control over their position parameters. So this function creates a new legend,
    copying over the data from the original object, which is then removed.

    Parameters
    ----------
    obj : matplotlib.axes.Axes or matplotlib.figure.Figure
        The object holding the legend. For a Figure, its most recently added
        figure-level legend is moved. Seaborn grids are not accepted directly;
        pass the matplotlib Figure or Axes that holds the grid's legend.

    loc : str or int
        Location argument, as in :meth:`matplotlib.axes.Axes.legend`.

    kwargs
        Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.legend`.
        ``title`` and ``title_*`` keys (e.g. ``title_fontsize``) update the
        copied legend title instead.

    Raises
    ------
    TypeError
        If ``obj`` is neither a matplotlib Axes nor a Figure.
    ValueError
        If ``obj`` has no legend attached.

    Examples
    --------
    ::

        fig, ax = plt.subplots()
        ax.plot([0, 1], label="line")
        ax.legend()
        move_legend(ax, "upper left", title="Series", frameon=False)

    """
    # This is a somewhat hackish solution that will hopefully be obviated by
    # upstream improvements to matplotlib legends that make them easier to
    # modify after creation.

    # Locate the legend object and a method to recreate the legend
    if isinstance(obj, mpl.axes.Axes):
        old_legend = obj.legend_
        legend_func = obj.legend
    elif isinstance(obj, mpl.figure.Figure):
        old_legend = obj.legends[-1] if obj.legends else None
        legend_func = obj.legend
    else:
        err = "`obj` must be a matplotlib Axes or Figure instance."
        raise TypeError(err)

    if old_legend is None:
        err = f"{obj} has no legend attached."
        raise ValueError(err)

    # Extract the components of the legend we need to reuse.
    # Matplotlib 3.7 renamed ``legendHandles`` to ``legend_handles``.
    try:
        handles = old_legend.legend_handles
    except AttributeError:
        handles = old_legend.legendHandles

    labels = [t.get_text() for t in old_legend.get_texts()]

    # Extract legend properties that can be passed to the recreation method
    # (Vexingly, these don't all round-trip)
    legend_kws = inspect.signature(mpl.legend.Legend).parameters
    props = {k: v for k, v in old_legend.properties().items() if k in legend_kws}

    # Delegate default bbox_to_anchor rules to matplotlib
    props.pop("bbox_to_anchor", None)

    # Try to propagate the existing title and font properties; respect new ones too
    title = props.pop("title")
    if "title" in kwargs:
        title.set_text(kwargs.pop("title"))
    prefix = "title_"
    title_kwargs = {k: v for k, v in kwargs.items() if k.startswith(prefix)}
    for key, val in title_kwargs.items():
        # e.g. ``title_fontsize=12`` becomes ``title.set(fontsize=12)``
        title.set(**{key[len(prefix):]: val})
        kwargs.pop(key)

    # Try to respect the frame visibility
    kwargs.setdefault("frameon", old_legend.legendPatch.get_visible())

    # Remove the old legend and create the new one
    props.update(kwargs)
    old_legend.remove()
    new_legend = legend_func(handles, labels, loc=loc, **props)
    new_legend.set_title(title.get_text(), title.get_fontproperties())