diff --git a/symbulate/distributions.py b/symbulate/distributions.py index b5670d7..f6c6f51 100644 --- a/symbulate/distributions.py +++ b/symbulate/distributions.py @@ -85,7 +85,8 @@ def plot(self, xlim=None, alpha=None, ax=None, **kwargs): ax.set_ylim(*ylim) # get next color in cycle - color = get_next_color(ax) + + color = get_next_color() # plot points for discrete distributions if self.discrete: diff --git a/symbulate/plot.py b/symbulate/plot.py index 6d76953..cf8f398 100644 --- a/symbulate/plot.py +++ b/symbulate/plot.py @@ -12,13 +12,17 @@ xlim = plt.xlim ylim = plt.ylim +color_index=0 +color_cycle = [c['color'] for c in plt.rcParams['axes.prop_cycle']] + def init_color(): hex_list = [colors.rgb2hex(rgb) for rgb in plt.cm.get_cmap('tab10').colors] plt.rcParams["axes.prop_cycle"] = cycler('color', hex_list) -def get_next_color(axes): - color_cycle = axes._get_lines.prop_cycler - color = next(color_cycle)["color"] +def get_next_color(): + global color_index + color = color_cycle[color_index] + color_index = (color_index + 1) % len(color_cycle) return color def configure_axes(axes, xdata, ydata, xlabel = None, ylabel = None): diff --git a/symbulate/results.py b/symbulate/results.py index 0f54b0f..96345cf 100644 --- a/symbulate/results.py +++ b/symbulate/results.py @@ -473,7 +473,7 @@ def plot(self, type=None, alpha=None, normalize=True, jitter=False, # initialize figure fig = plt.gcf() ax = plt.gca() - color = get_next_color(ax) + color = get_next_color() if 'density' in type: if discrete: @@ -545,27 +545,27 @@ def plot(self, type=None, alpha=None, normalize=True, jitter=False, x_lines = np.linspace(min(x), max(x), 1000) y_lines = np.linspace(min(y), max(y), 1000) ax_marg_x.plot(x_lines, densityX(x_lines), linewidth=2, - color=get_next_color(ax)) + color=get_next_color()) ax_marg_y.plot(y_lines, densityY(y_lines), linewidth=2, - color=get_next_color(ax), + color=get_next_color(), transform=Affine2D().rotate_deg(270) + ax_marg_y.transData) else: if discrete_x: - make_marginal_impulse(x_count, get_next_color(ax), ax_marg_x, alpha, 'x') + make_marginal_impulse(x_count, get_next_color(), ax_marg_x, alpha, 'x') else: - ax_marg_x.hist(x, color=get_next_color(ax), density=normalize, + ax_marg_x.hist(x, color=get_next_color(), density=normalize, alpha=alpha, bins=bins) if discrete_y: - make_marginal_impulse(y_count, get_next_color(ax), ax_marg_y, alpha, 'y') + make_marginal_impulse(y_count, get_next_color(), ax_marg_y, alpha, 'y') else: - ax_marg_y.hist(y, color=get_next_color(ax), density=normalize, + ax_marg_y.hist(y, color=get_next_color(), density=normalize, alpha=alpha, bins=bins, orientation='horizontal') plt.setp(ax_marg_x.get_xticklabels(), visible=False) plt.setp(ax_marg_y.get_yticklabels(), visible=False) else: fig = plt.gcf() ax = plt.gca() - color = get_next_color(ax) + color = get_next_color() nullfmt = NullFormatter() #removes labels on fig @@ -605,7 +605,7 @@ def plot(self, type=None, alpha=None, normalize=True, jitter=False, if alpha is None: alpha = np.log(2) / np.log(len(self) + 1) ax = plt.gca() - color = get_next_color(ax) + color = get_next_color() for result in self.results: result.plot(alpha=alpha, color=color, **kwargs) plt.xlabel("Index")