Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/docstrings/pydeseq2.dds.DeseqDataSet.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
~DeseqDataSet.fit_genewise_dispersions
~DeseqDataSet.fit_size_factors
~DeseqDataSet.plot_dispersions
~DeseqDataSet.plot_rle
~DeseqDataSet.refit
~DeseqDataSet.vst

1 change: 1 addition & 0 deletions docs/source/api/docstrings/pydeseq2.ds.DeseqStats.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
~DeseqStats.lfc_shrink
~DeseqStats.run_wald_test
~DeseqStats.summary
~DeseqStats.plot_MA



Expand Down
31 changes: 31 additions & 0 deletions pydeseq2/dds.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pydeseq2.preprocessing import deseq2_norm_transform
from pydeseq2.utils import build_design_matrix
from pydeseq2.utils import dispersion_trend
from pydeseq2.utils import make_rle_plot
from pydeseq2.utils import make_scatter
from pydeseq2.utils import mean_absolute_deviation
from pydeseq2.utils import n_or_more_replicates
Expand Down Expand Up @@ -1010,6 +1011,36 @@ def plot_dispersions(
**kwargs,
)

def plot_rle(
    self,
    normalize: bool = False,
    save_path: str | None = None,
    **kwargs,
):
    """Draw a relative log expression (RLE) boxplot with one box per sample.

    Thin wrapper around ``make_rle_plot`` that forwards the dataset's count
    matrix (``self.X``) and sample names (``self.obs_names``). Helpful for
    eyeballing sample-to-sample variation.

    Parameters
    ----------
    normalize : bool, optional
        Whether the counts should be normalized before plotting
        (default: ``False``).

    save_path : str, optional
        Destination path for the figure. When left as ``None`` the plot is
        not saved (default: ``None``).

    **kwargs
        Extra keyword arguments forwarded unchanged to ``make_rle_plot``.
    """
    # Gather the plotting options first so the call below only has to name
    # the two pieces of data that come from the dataset itself.
    plot_options = {"normalize": normalize, "save_path": save_path, **kwargs}
    make_rle_plot(
        count_matrix=self.X,
        sample_ids=self.obs_names,
        **plot_options,
    )

def _fit_parametric_dispersion_trend(self, vst: bool = False):
r"""Fit the dispersion curve according to a parametric model.

Expand Down
60 changes: 60 additions & 0 deletions pydeseq2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,3 +1608,63 @@ def lowess(
delta = (1 - delta**2) ** 2

return yest


def make_rle_plot(
    count_matrix: np.ndarray,
    sample_ids: np.ndarray,
    normalize: bool = False,
    save_path: str | None = None,
    **kwargs,
) -> None:
    """
    Create a relative log expression (RLE) plot using matplotlib.

    For every gene, the log2 ratio between each sample's count and the
    median count of that gene across samples is computed; the distribution
    of these ratios is then drawn as one boxplot per sample.

    Parameters
    ----------
    count_matrix : ndarray
        An mxn matrix of count data, where m is the number of samples (rows),
        and n is the number of genes (columns).

    sample_ids : ndarray
        An array of sample identifiers, used as x-axis tick labels. Must
        contain one entry per row of ``count_matrix``.

    normalize : bool
        Whether to normalize the count matrix before plotting. When ``True``,
        each sample is divided by a size factor computed as the median, over
        genes, of its counts divided by the per-gene geometric mean of
        ``count + 1`` (default: ``False``).

    save_path : str, optional
        The path where to save the plot. If left None, the plot won't be saved
        and is shown instead (default: ``None``).

    **kwargs :
        Additional keyword arguments passed to matplotlib's boxplot function.
        ``alpha`` is intercepted and applied to the box faces
        (default: ``0.5``).

    Notes
    -----
    Genes whose median count is zero, and zero counts themselves, yield
    non-finite RLE values; these are handed to matplotlib unchanged.
    """
    if normalize:
        geometric_mean = np.exp(np.mean(np.log(count_matrix + 1), axis=0))
        size_factors = np.median(count_matrix / geometric_mean, axis=1)
        count_matrix = count_matrix / size_factors[:, np.newaxis]

    # Compute everything before touching matplotlib, so that a failure here
    # does not leave an empty figure behind.
    gene_medians = np.median(count_matrix, axis=0)
    rle_values = np.log2(count_matrix / gene_medians)

    # ``alpha`` is not a boxplot argument: route it to the box patches.
    alpha = kwargs.pop("alpha", 0.5)
    boxprops = {"facecolor": "lightgray", "alpha": alpha}

    # Boxes are drawn at 1..m by default (not 0..m-1). The tick locations must
    # match the box locations, otherwise every label is shifted by one sample.
    tick_positions = kwargs.get("positions")
    if tick_positions is None:
        tick_positions = np.arange(1, len(sample_ids) + 1)

    # Scope the font size to this figure instead of permanently overwriting
    # the caller's global rcParams. Rendering (save/show) stays inside the
    # context so lazily-created tick labels pick up the same setting.
    with plt.rc_context({"font.size": 10}):
        _, ax = plt.subplots(figsize=(15, 8), dpi=600)

        ax.boxplot(rle_values.T, patch_artist=True, boxprops=boxprops, **kwargs)

        ax.axhline(0, color="red", linestyle="--", linewidth=1, alpha=0.5, zorder=3)
        ax.set_xlabel("Sample")
        ax.set_ylabel("Relative Log Expression")
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(sample_ids, rotation=90)
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, bbox_inches="tight")
        else:
            plt.show()
13 changes: 13 additions & 0 deletions tests/test_pydeseq2.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,3 +875,16 @@ def assert_res_almost_equal(py_res, r_res, tol=0.02):
).max() < tol
assert (abs(r_res.pvalue - py_res.pvalue) / r_res.pvalue).max() < tol
assert (abs(r_res.padj - py_res.padj) / r_res.padj).max() < tol


def test_plot_rle(train_counts, train_metadata):
    """Check that RLE plotting runs cleanly on raw and on normalized counts."""
    dataset = DeseqDataSet(
        counts=train_counts,
        metadata=train_metadata,
        design="~condition",
    )

    # Exercise both code paths of the plotting helper, raw counts first.
    for use_normalization in (False, True):
        dataset.plot_rle(normalize=use_normalization)
Loading