# Source code for upsetplot.plotting

from __future__ import print_function, division, absolute_import

import itertools

import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.tight_layout import get_renderer


def _process_data(df, sort_by, sort_sets_by, sum_over):
    if df.ndim == 1:
        data = df
        df = pd.DataFrame({'_value': df})

        if not data.index.is_unique:
            data = (data
                    .groupby(level=list(range(data.index.nlevels)))
                    .sum())
        if sum_over is not None:
            raise ValueError('sum_over is not applicable when the input is a '
                             'Series')
    elif sum_over is None:
        raise ValueError('sum_over must be False or a column name when a '
                         'DataFrame is input')
    else:
        gb = df.groupby(level=list(range(df.index.nlevels)))
        if sum_over is False:
            data = gb.size()
            data.name = 'size'
        elif hasattr(sum_over, 'lower'):
            data = gb[sum_over].sum()
        else:
            raise ValueError('Unsupported value for sum_over: %r' % sum_over)

    # check all indices are boolean
    assert all(set([True, False]) >= set(level) for level in data.index.levels)

    totals = [data[data.index.get_level_values(name).values.astype(bool)].sum()
              for name in data.index.names]
    totals = pd.Series(totals, index=data.index.names)
    if sort_sets_by == 'cardinality':
        totals.sort_values(ascending=False, inplace=True)
    elif sort_sets_by is not None:
        raise ValueError('Unknown sort_sets_by: %r' % sort_sets_by)
    df = df.reorder_levels(totals.index.values)
    data = data.reorder_levels(totals.index.values)

    if sort_by == 'cardinality':
        data = data.sort_values(ascending=False)
    elif sort_by == 'degree':
        comb = itertools.combinations
        o = pd.DataFrame([{name: True for name in names}
                          for i in range(data.index.nlevels + 1)
                          for names in comb(data.index.names, i)],
                         columns=data.index.names)
        o.fillna(False, inplace=True)
        o = o.astype(bool)
        o.set_index(data.index.names, inplace=True)
        data = data.reindex(index=o.index)
    else:
        raise ValueError('Unknown sort_by: %r' % sort_by)

    min_value = 0
    max_value = np.inf
    data = data[np.logical_and(data >= min_value, data <= max_value)]

    # add '_bin' to df indicating index in data
    # XXX: ugly!
    def _pack_binary(X):
        X = pd.DataFrame(X)
        out = 0
        for i, (_, col) in enumerate(X.items()):
            out *= 2
            out += col
        return out

    df_packed = _pack_binary(df.index.to_frame())
    data_packed = _pack_binary(data.index.to_frame())
    df['_bin'] = pd.Series(df_packed).map(
        pd.Series(np.arange(len(data_packed)),
                  index=data_packed))

    return df, data, totals


class _Transposed:
    """Wrap an object in order to transpose some plotting operations

    Attributes of obj will be mapped.
    Keyword arguments when calling obj will be mapped.

    The mapping is not recursive: callable attributes need to be _Transposed
    again.
    """

    def __init__(self, obj):
        self.__obj = obj

    def __getattr__(self, key):
        return getattr(self.__obj, self._NAME_TRANSPOSE.get(key, key))

    def __call__(self, *args, **kwargs):
        return self.__obj(*args, **{self._NAME_TRANSPOSE.get(k, k): v
                                    for k, v in kwargs.items()})

    _NAME_TRANSPOSE = {
        'width': 'height',
        'height': 'width',
        'hspace': 'wspace',
        'wspace': 'hspace',
        'hlines': 'vlines',
        'vlines': 'hlines',
        'bar': 'barh',
        'barh': 'bar',
        'xaxis': 'yaxis',
        'yaxis': 'xaxis',
        'left': 'bottom',
        'right': 'top',
        'top': 'right',
        'bottom': 'left',
        'sharex': 'sharey',
        'sharey': 'sharex',
        'get_figwidth': 'get_figheight',
        'get_figheight': 'get_figwidth',
        'set_figwidth': 'set_figheight',
        'set_figheight': 'set_figwidth',
        'set_xlabel': 'set_ylabel',
        'set_ylabel': 'set_xlabel',
    }


def _transpose(obj):
    """Transpose a name, or wrap an object so that its names are transposed

    A string is translated through ``_Transposed._NAME_TRANSPOSE`` (and
    returned unchanged when it has no entry); anything else is wrapped in a
    ``_Transposed`` proxy.
    """
    if not isinstance(obj, str):
        return _Transposed(obj)
    return _Transposed._NAME_TRANSPOSE.get(obj, obj)


def _identity(obj):
    """Return `obj` unchanged"""
    return obj


class UpSet:
    """Manage the data and drawing for a basic UpSet plot

    Primary public method is :meth:`plot`.

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Values for each set to plot.
        Should have multi-index where each level is binary,
        corresponding to set membership.
        If a DataFrame, `sum_over` must be a string or False.
    orientation : {'horizontal' (default), 'vertical'}
        If horizontal, intersections are listed from left to right.
        Any value other than 'horizontal' is drawn vertically.
    sort_by : {'cardinality', 'degree'}
        If 'cardinality', set intersections are listed from largest to
        smallest value.
        If 'degree', they are listed in order of the number of sets
        intersected.
    sort_sets_by : {'cardinality', None}
        Whether to sort the overall sets by total cardinality, or leave them
        in the provided order.
    sum_over : str, False or None (default)
        Must be specified when `data` is a DataFrame. If False, the
        intersection plot will show the count of each subset. Otherwise, it
        shows the sum of the specified field.
    facecolor : str
        Color for bar charts and dots.
    with_lines : bool
        Whether to show lines joining dots in the matrix, to mark multiple
        sets being intersected.
    element_size : float or None
        Side length in pt. If None, size is estimated to fit figure
    intersection_plot_elements : int
        The intersections plot should be large enough to fit this many matrix
        elements.
    totals_plot_elements : int
        The totals plot should be large enough to fit this many matrix
        elements.
    show_counts : bool or str, default=''
        Whether to label the intersection size bars with the cardinality
        of the intersection. When a string, this formats the number.
        For example, '%d' is equivalent to True. Any falsy value, such as
        the default empty string, disables the labels.
    """
    _default_figsize = (10, 6)

    def __init__(self, data, orientation='horizontal', sort_by='degree',
                 sort_sets_by='cardinality', sum_over=None,
                 facecolor='black',
                 with_lines=True, element_size=32,
                 intersection_plot_elements=6, totals_plot_elements=2,
                 show_counts=''):

        self._horizontal = orientation == 'horizontal'
        # vertical plots are drawn with the horizontal code, with axis
        # related names swapped by _transpose
        self._reorient = _identity if self._horizontal else _transpose
        self._facecolor = facecolor
        self._with_lines = with_lines
        self._element_size = element_size
        self._totals_plot_elements = totals_plot_elements
        # plots drawn over the intersections; add_catplot appends to this
        self._subset_plots = [{'type': 'default',
                               'id': 'intersections',
                               'elements': intersection_plot_elements}]
        self._show_counts = show_counts

        (self._df, self.intersections,
         self.totals) = _process_data(data,
                                      sort_by=sort_by,
                                      sort_sets_by=sort_sets_by,
                                      sum_over=sum_over)
        if not self._horizontal:
            self.intersections = self.intersections[::-1]

    def _swapaxes(self, x, y):
        """Return (x, y) when horizontal, (y, x) otherwise"""
        if self._horizontal:
            return x, y
        return y, x

    def add_catplot(self, kind, value=None, elements=3, **kw):
        """Add a seaborn catplot over subsets when :func:`plot` is called.

        Parameters
        ----------
        kind : str
            One of {"point", "bar", "strip", "swarm", "box", "violin", "boxen"}
        value : str, optional
            Column name for the value to plot (i.e. y if
            orientation='horizontal'), required if `data` is a DataFrame.
        elements : int, default=3
            Size of the axes counted in number of matrix elements.
        **kw : dict
            Additional keywords to pass to :func:`seaborn.catplot`.

            Our implementation automatically determines 'ax', 'data', 'x',
            'y' and 'orient', so these are prohibited keys in `kw`.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `value` is omitted although `data` was a DataFrame, or if it
            does not name a column of the data.
        """
        reserved = {'ax', 'data', 'x', 'y', 'orient'}
        assert not set(kw) & reserved, (
            'prohibited keys in kw: %r' % sorted(set(kw) & reserved))
        if value is None:
            # the '_value' column only exists when the input was a Series
            if '_value' not in self._df.columns:
                raise ValueError('value must be set when data is a '
                                 'DataFrame. Got %r' % value)
        else:
            if value not in self._df.columns:
                raise ValueError('value %r is not a column in data' % value)
        self._subset_plots.append({'type': 'catplot',
                                   'value': value,
                                   'kind': kind,
                                   'id': 'extra%d' % len(self._subset_plots),
                                   'elements': elements,
                                   'kw': kw})

    def _plot_catplot(self, ax, value, kind, kw):
        """Draw the seaborn ``<kind>plot`` of `value` per intersection on ax
        """
        df = self._df
        if value is None and '_value' in df.columns:
            value = '_value'
        elif value is None:
            raise ValueError('value can only be None when data is a Series')
        kw = kw.copy()
        if self._horizontal:
            kw['orient'] = 'v'
            kw['x'] = '_bin'
            kw['y'] = value
        else:
            kw['orient'] = 'h'
            kw['x'] = value
            kw['y'] = '_bin'
        # imported here so that seaborn is only needed when catplots are used
        import seaborn
        kw['ax'] = ax
        getattr(seaborn, kind + 'plot')(data=df, **kw)

        ax = self._reorient(ax)
        if value == '_value':
            ax.set_ylabel('')

        ax.xaxis.set_visible(False)
        for x in ['top', 'bottom', 'right']:
            ax.spines[self._reorient(x)].set_visible(False)

        tick_axis = ax.yaxis
        tick_axis.grid(True)

    def make_grid(self, fig=None):
        """Get a SubplotSpec for each Axes, accounting for label text width

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Defaults to the current figure. It is resized when the
            `element_size` given at construction is not None.

        Returns
        -------
        specs : dict
            Entries 'matrix', 'shading' and 'totals', one entry per subset
            plot keyed by its id, and the GridSpec itself under 'gs'.
        """
        n_cats = len(self.totals)
        n_inters = len(self.intersections)

        if fig is None:
            fig = plt.gcf()

        # Determine text size to determine figure size / spacing
        r = get_renderer(fig)
        t = fig.text(0, 0, '\n'.join(self.totals.index.values))
        textw = t.get_window_extent(renderer=r).width
        t.remove()

        MAGIC_MARGIN = 10  # FIXME
        figw = self._reorient(
            fig.get_window_extent(renderer=r)).width

        sizes = np.asarray([p['elements'] for p in self._subset_plots])

        if self._element_size is None:
            colw = (figw - textw - MAGIC_MARGIN) / (
                n_inters + self._totals_plot_elements)
        else:
            fig = self._reorient(fig)
            render_ratio = figw / fig.get_figwidth()
            # element_size is in points, of which there are 72 per inch
            colw = self._element_size / 72 * render_ratio
            figw = (colw * (n_inters + self._totals_plot_elements) +
                    MAGIC_MARGIN + textw)
            fig.set_figwidth(figw / render_ratio)
            fig.set_figheight((colw * (n_cats + sizes.sum())) /
                              render_ratio)

        # number of grid cells left over for the set name labels
        text_nelems = int(np.ceil(figw / colw -
                                  (n_inters + self._totals_plot_elements)))

        GS = self._reorient(matplotlib.gridspec.GridSpec)
        gridspec = GS(*self._swapaxes(n_cats + sizes.sum(),
                                      n_inters + text_nelems +
                                      self._totals_plot_elements),
                      hspace=1)
        if self._horizontal:
            out = {'matrix': gridspec[-n_cats:, -n_inters:],
                   'shading': gridspec[-n_cats:, :],
                   'totals': gridspec[-n_cats:, :self._totals_plot_elements],
                   'gs': gridspec}
            # subset plots are stacked above the matrix, last one topmost
            cumsizes = np.cumsum(sizes[::-1])
            for start, stop, subplot in zip(np.hstack([[0], cumsizes]),
                                            cumsizes,
                                            self._subset_plots[::-1]):
                out[subplot['id']] = gridspec[start:stop, -n_inters:]
        else:
            out = {'matrix': gridspec[-n_inters:, :n_cats],
                   'shading': gridspec[:, :n_cats],
                   'totals': gridspec[:self._totals_plot_elements, :n_cats],
                   'gs': gridspec}
            # subset plots are laid out beside the matrix, in order
            cumsizes = np.cumsum(sizes)
            for start, stop, subplot in zip(np.hstack([[0], cumsizes]),
                                            cumsizes,
                                            self._subset_plots):
                out[subplot['id']] = gridspec[-n_inters:,
                                              start + n_cats:stop + n_cats]
        return out

    def plot_matrix(self, ax):
        """Plot the matrix of intersection indicators onto ax
        """
        ax = self._reorient(ax)
        data = self.intersections
        n_sets = data.index.nlevels

        # flat positions of the dots marking sets present in an intersection
        idx = np.flatnonzero(data.index.to_frame()[data.index.names].values)
        c = np.array(['lightgrey'] * len(data) * n_sets, dtype='O')
        c[idx] = self._facecolor
        x = np.repeat(np.arange(len(data)), n_sets)
        y = np.tile(np.arange(n_sets), len(data))
        if self._element_size is not None:
            s = (self._element_size * .35) ** 2
        else:
            # TODO: make s relative to colw
            s = 200
        ax.scatter(*self._swapaxes(x, y), c=c.tolist(), linewidth=0, s=s)

        if self._with_lines:
            # join the first and last marked set of every intersection
            line_data = (pd.Series(y[idx], index=x[idx])
                         .groupby(level=0)
                         .aggregate(['min', 'max']))
            ax.vlines(line_data.index.values,
                      line_data['min'], line_data['max'],
                      lw=2, colors=self._facecolor)

        tick_axis = ax.yaxis
        tick_axis.set_ticks(np.arange(n_sets))
        tick_axis.set_ticklabels(data.index.names,
                                 rotation=0 if self._horizontal else -90)
        ax.xaxis.set_visible(False)
        ax.tick_params(axis='both', which='both', length=0)
        if not self._horizontal:
            ax.yaxis.set_ticks_position('top')
        ax.set_frame_on(False)

    def plot_intersections(self, ax):
        """Plot bars indicating intersection size
        """
        ax = self._reorient(ax)
        rects = ax.bar(np.arange(len(self.intersections)), self.intersections,
                       .5, color=self._facecolor, zorder=10, align='center')

        self._label_sizes(ax, rects, 'top' if self._horizontal else 'right')

        ax.xaxis.set_visible(False)
        for x in ['top', 'bottom', 'right']:
            ax.spines[self._reorient(x)].set_visible(False)

        tick_axis = ax.yaxis
        tick_axis.grid(True)
        # Axis.set_label only sets the artist (legend) label, so the title
        # of the value axis is set through the Axes instead
        ax.set_ylabel('Intersection size')

    def _label_sizes(self, ax, rects, where):
        """Annotate each bar in rects with its size, if show_counts is set

        `where` is the side of the bar on which the text is placed: 'left',
        'right' or 'top'. `ax` must be the real (not transposed) view of the
        bars, since their width or height is read accordingly.
        """
        if not self._show_counts:
            return
        fmt = '%d' if self._show_counts is True else self._show_counts
        if where in ('right', 'left'):
            lo, hi = ax.get_xlim()
            margin = 0.01 * abs(hi - lo)
            # 'left' is used on an axis that gets inverted afterwards, so
            # the offset is the same and only the alignment differs
            ha = 'left' if where == 'right' else 'right'
            for rect in rects:
                width = rect.get_width()
                ax.text(width + margin,
                        rect.get_y() + rect.get_height() * .5,
                        fmt % width,
                        ha=ha, va='center')
        elif where == 'top':
            lo, hi = ax.get_ylim()
            margin = 0.01 * abs(hi - lo)
            for rect in rects:
                height = rect.get_height()
                ax.text(rect.get_x() + rect.get_width() * .5,
                        height + margin,
                        fmt % height,
                        ha='center', va='bottom')
        else:
            raise NotImplementedError('unhandled where: %r' % where)

    def plot_totals(self, ax):
        """Plot bars indicating total set size
        """
        orig_ax = ax
        ax = self._reorient(ax)
        rects = ax.barh(np.arange(len(self.totals.index.values)), self.totals,
                        .5, color=self._facecolor, align='center')
        self._label_sizes(ax, rects, 'left' if self._horizontal else 'top')

        max_total = self.totals.max()
        if self._horizontal:
            # invert the axis so that the bars grow away from the matrix
            orig_ax.set_xlim(max_total, 0)
        for x in ['top', 'left', 'right']:
            ax.spines[self._reorient(x)].set_visible(False)
        ax.yaxis.set_visible(False)
        ax.xaxis.grid(True)
        ax.patch.set_visible(False)

    def plot_shading(self, ax):
        """Shade every other set row (column if vertical) and hide the axes
        """
        # alternating row shading (XXX: use add_patch(Rectangle)?)
        for i in range(0, len(self.totals), 2):
            rect = plt.Rectangle(self._swapaxes(0, i - .4),
                                 *self._swapaxes(*(1, .8)),
                                 facecolor='#f5f5f5', lw=0, zorder=0)
            ax.add_patch(rect)
        ax.set_frame_on(False)
        ax.tick_params(
            axis='both', which='both',
            left=False, right=False, bottom=False, top=False,
            labelbottom=False, labelleft=False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    def plot(self, fig=None):
        """Draw all parts of the plot onto fig or a new figure

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Defaults to a new figure.

        Returns
        -------
        subplots : dict of matplotlib.axes.Axes
            Keys are 'matrix', 'intersections', 'totals', 'shading', plus
            one 'extra<N>' key for each plot added with :meth:`add_catplot`.
        """
        if fig is None:
            fig = plt.figure(figsize=self._default_figsize)
        specs = self.make_grid(fig)
        shading_ax = fig.add_subplot(specs['shading'])
        self.plot_shading(shading_ax)
        matrix_ax = self._reorient(fig.add_subplot)(specs['matrix'],
                                                    sharey=shading_ax)
        self.plot_matrix(matrix_ax)
        totals_ax = self._reorient(fig.add_subplot)(specs['totals'],
                                                    sharey=matrix_ax)
        self.plot_totals(totals_ax)
        out = {'matrix': matrix_ax,
               'shading': shading_ax,
               'totals': totals_ax}

        for subplot in self._subset_plots:
            ax = self._reorient(fig.add_subplot)(specs[subplot['id']],
                                                 sharex=matrix_ax)
            if subplot['type'] == 'default':
                self.plot_intersections(ax)
            elif subplot['type'] == 'catplot':
                self._plot_catplot(ax, subplot['value'], subplot['kind'],
                                   subplot['kw'])
            else:
                raise ValueError('Unknown subset plot type: %r' %
                                 subplot['type'])
            out[subplot['id']] = ax
        return out

    def _repr_html_(self):
        """Plot onto a new default-sized figure and return its HTML repr"""
        fig = plt.figure(figsize=self._default_figsize)
        self.plot(fig=fig)
        return fig._repr_html_()
def plot(data, fig=None, **kwargs):
    """Make an UpSet plot of data on fig

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Values for each set to plot.
        Should have multi-index where each level is binary,
        corresponding to set membership.
        If a DataFrame, `sum_over` must be a string or False.
    fig : matplotlib.figure.Figure, optional
        Defaults to a new figure.
    kwargs
        Other arguments for :class:`UpSet`

    Returns
    -------
    subplots : dict of matplotlib.axes.Axes
        Keys are 'matrix', 'intersections', 'totals', 'shading'
    """
    return UpSet(data, **kwargs).plot(fig)