From b591b2c9198f670d172975c4043f9d9240642e4b Mon Sep 17 00:00:00 2001 From: Alan Pawlak Date: Thu, 19 Jan 2023 10:29:49 +0000 Subject: [PATCH] Add support for passing 'ax' parameter to function for plotting on specific axes (matplotlib) --- src/sphstat/plotting.py | 40 ++++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/src/sphstat/plotting.py b/src/sphstat/plotting.py index 74eafd1..def2e34 100644 --- a/src/sphstat/plotting.py +++ b/src/sphstat/plotting.py @@ -34,6 +34,7 @@ """ import numpy as np +import matplotlib from matplotlib import pyplot as plt from .descriptives import mediandir, rotationmatrix, pointsonanellipse @@ -70,7 +71,7 @@ def plotmapping(input: list): return output -def plotdata(sample: dict, proj: str='mollweide', mflag: bool = False) -> bool: +def plotdata(sample: dict, proj: str='mollweide', mflag: bool = False, ax: matplotlib.axes.Axes=None) -> bool: """ Plot a sample represented in polar (colat, long) format in Mollweide projection @@ -80,6 +81,8 @@ def plotdata(sample: dict, proj: str='mollweide', mflag: bool = False) -> bool: :type proj: str :param mflag: Flag to plot the median and the 95% cone of confidence :type mflag: bool + :param ax: Axes to plot on + :type ax: matplotlib.axes.Axes :return: bool (True) """ try: @@ -92,8 +95,12 @@ def plotdata(sample: dict, proj: str='mollweide', mflag: bool = False) -> bool: except AssertionError: raise AssertionError('Unknown projection type!') - fig = plt.figure(figsize=(10, 5)) - ax = fig.add_subplot(111, projection=proj) + if ax is None: + show_flag = True + fig = plt.figure(figsize=(10, 5)) + ax = fig.add_subplot(111, projection=proj) + else: + show_flag = False phis = plotmapping(sample['phis']) ax.scatter(np.array(phis), np.pi / 2 - np.array(sample['tetas']), s=1.5 * plt.rcParams['lines.markersize'] ** 1.5, @@ -139,11 +146,15 @@ def plotdata(sample: dict, proj: str='mollweide', mflag: bool = False) -> bool: ax.set_ylabel("Colatitude [deg]") ax.yaxis.label.set_fontsize(12) ax.grid(True) - plt.show() + + if show_flag: + plt.show() + return True -def plotdatalist(samplelist: list, labels: list=None, proj: str='mollweide', mflag: bool = False) -> bool: +def plotdatalist(samplelist: list, labels: list=None, proj: str='mollweide', mflag: bool = False, + ax: matplotlib.axes.Axes=None) -> bool: """ Superimposed plot of a list of samples @@ -155,18 +166,24 @@ def plotdatalist(samplelist: list, labels: list=None, proj: str='mollweide', mfl :type proj: str :param mflag: Flag to plot the median and the 95% cone of confidence :type mflag: bool + :param ax: Axes to plot on + :type ax: matplotlib.axes.Axes :return: Return True when completed :rtype: bool """ - fig = plt.figure(figsize=(10, 5)) - ax = fig.add_subplot(111, projection=proj) - try: assert len(labels) == len(samplelist) except AssertionError: raise AssertionError('Number of labels should match the number of samples in samplelist') + if ax is None: + show_flag = True + fig = plt.figure(figsize=(10, 5)) + ax = fig.add_subplot(111, projection=proj) + else: + show_flag = False + ind = -1 for sample in samplelist: ind += 1 @@ -219,6 +236,9 @@ def plotdatalist(samplelist: list, labels: list=None, proj: str='mollweide', mfl ax.set_ylabel('Colatitude [deg]') ax.yaxis.label.set_fontsize(16) ax.grid(True) - plt.legend(fontsize='large', loc=1) - plt.show() + + if show_flag: + plt.legend(fontsize='large', loc=1) + plt.show() + return True