Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "topic-wizard"
version = "1.1.1"
version = "1.1.2"
description = "Pretty and opinionated topic model visualization in Python."
authors = ["Márton Kardos <power.up1163@gmail.com>"]
license = "MIT"
Expand All @@ -13,7 +13,7 @@ dash = "^2.7.1"
dash-extensions = "^1.0.4"
dash-mantine-components = "~0.12.1"
dash-iconify = "~0.1.2"
joblib = "~1.2.0"
joblib = "^1.2.0"
scikit-learn = "^1.2.0"
scipy = ">=1.8.0"
umap-learn = ">=0.5.3"
Expand Down
20 changes: 16 additions & 4 deletions topicwizard/figures/documents.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""External API for creating self-contained figures for documents."""

from typing import List, Optional, Union

import numpy as np
Expand Down Expand Up @@ -56,7 +57,10 @@ def document_map(


def document_topic_distribution(
topic_data: TopicData, documents: Union[List[str], str], top_n: int = 8
topic_data: TopicData,
documents: Union[List[str], str],
top_n: int = 8,
color_scheme: str = "Portland",
) -> go.Figure:
"""Displays topic distribution on a bar plot for a document
or a set of documents.
Expand All @@ -69,6 +73,8 @@ def document_topic_distribution(
Documents to display topic distribution for.
top_n: int, default 8
Number of topics to display at most.
color_scheme: str, default 'Portland'
Name of the Plotly color scheme to use for the plot.
"""
transform = topic_data["transform"]
if transform is None:
Expand All @@ -80,7 +86,7 @@ def document_topic_distribution(
topic_importances = prepare.document_topic_importances(transform(documents))
topic_importances = topic_importances.groupby(["topic_id"]).sum().reset_index()
n_topics = topic_data["document_topic_matrix"].shape[-1]
twilight = colors.get_colorscale("Portland")
twilight = colors.get_colorscale(color_scheme)
topic_colors = colors.sample_colorscale(twilight, np.arange(n_topics) / n_topics)
topic_colors = np.array(topic_colors)
return plots.document_topic_barplot(
Expand All @@ -89,7 +95,11 @@ def document_topic_distribution(


def document_topic_timeline(
topic_data: TopicData, document: str, window_size: int = 10, step_size: int = 1
topic_data: TopicData,
document: str,
window_size: int = 10,
step_size: int = 1,
color_scheme: str = "Portland",
) -> go.Figure:
"""Projects documents into 2d space and displays them on a scatter plot.

Expand All @@ -103,6 +113,8 @@ def document_topic_timeline(
The windows over which topic inference should be run.
step_size: int, default 1
Size of the steps for the rolling window.
color_scheme: str, default 'Portland'
Name of Plotly color scheme to use for the plot.
"""
timeline = prepare.calculate_timeline(
doc_id=0,
Expand All @@ -113,7 +125,7 @@ def document_topic_timeline(
)
topic_names = topic_data["topic_names"]
n_topics = len(topic_names)
twilight = colors.get_colorscale("Portland")
twilight = colors.get_colorscale(color_scheme)
topic_colors = colors.sample_colorscale(twilight, np.arange(n_topics) / n_topics)
topic_colors = np.array(topic_colors)
return plots.document_timeline(timeline, topic_names, topic_colors)
11 changes: 9 additions & 2 deletions topicwizard/figures/groups.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""External API for creating self-contained figures for groups."""

from typing import List

import numpy as np
Expand Down Expand Up @@ -68,7 +69,11 @@ def group_map(topic_data: TopicData, group_labels: List[str]) -> go.Figure:


def group_topic_barcharts(
topic_data: TopicData, group_labels: List[str], top_n: int = 5, n_columns: int = 4
topic_data: TopicData,
group_labels: List[str],
top_n: int = 5,
n_columns: int = 4,
color_scheme: str = "Portland",
):
"""Displays the most important topics for each group.

Expand All @@ -82,6 +87,8 @@ def group_topic_barcharts(
Maximum number of topics to display for each group.
n_columns: int, default 4
Indicates how many columns the faceted plot should have.
color_scheme: str, default 'Portland'
Name of the plotly color scheme to use for the figure.
"""
# Factorizing group labels
group_id_labels, group_names = pd.factorize(group_labels)
Expand All @@ -105,7 +112,7 @@ def group_topic_barcharts(
horizontal_spacing=0.01,
)
n_topics = len(topic_data["topic_names"])
color_scheme = colors.get_colorscale("Portland")
color_scheme = colors.get_colorscale(color_scheme)
topic_colors = colors.sample_colorscale(
color_scheme, np.arange(n_topics) / n_topics, low=0.25, high=1.0
)
Expand Down
5 changes: 4 additions & 1 deletion topicwizard/figures/topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def topic_wordclouds(
topic_data: TopicData,
top_n: int = 30,
n_columns: int = 4,
color_scheme: str = "copper",
) -> go.Figure:
"""Plots most relevant words as word clouds for every topic.

Expand All @@ -121,6 +122,8 @@ def topic_wordclouds(
Specifies the number of words to show for each topic.
n_columns: int, default 4
Number of columns in the subplot grid.
color_scheme: str, default 'copper'
Matplotlib color scheme to use for the wordcloud.
"""
n_topics = topic_data["topic_term_matrix"].shape[0]
(
Expand All @@ -147,7 +150,7 @@ def topic_wordclouds(
components=topic_term_importances,
vocab=topic_data["vocab"],
)
subfig = plots.wordcloud(top_words)
subfig = plots.wordcloud(top_words, color_scheme=color_scheme)
row, column = (topic_id // n_columns) + 1, (topic_id % n_columns) + 1
fig.add_trace(subfig.data[0], row=row, col=column)
fig.update_layout(
Expand Down
10 changes: 8 additions & 2 deletions topicwizard/figures/words.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def word_map(
topic_data: TopicData,
z_threshold: float = 2.0,
topic_axes: Optional[Tuple[Union[str, int], Union[str, int]]] = None,
color_scheme: str = "tempo",
) -> go.Figure:
"""Plots words on a scatter plot based on UMAP projections
of their importances in topics into 2D space or by two topic axes.
Expand All @@ -38,6 +39,8 @@ def word_map(
The topic axes along which the words should be displayed.
If not specified, the axes on the graph are going to be
UMAP projections' dimensions.
color_scheme: str, default 'tempo'
Name of the Plotly color scheme to use for the plot.
"""
topic_names = topic_data["topic_names"]
if topic_axes is None:
Expand All @@ -58,7 +61,7 @@ def word_map(
freq_z = zscore(word_frequencies)
dominant_topic = prepare.dominant_topic(topic_data["topic_term_matrix"])
dominant_topic = np.array(topic_data["topic_names"])[dominant_topic]
tempo = colors.get_colorscale("tempo")
tempo = colors.get_colorscale(color_scheme)
n_topics = len(topic_data["topic_names"])
topic_colors = colors.sample_colorscale(tempo, np.arange(n_topics) / n_topics)
topic_colors = np.array(topic_colors)
Expand Down Expand Up @@ -99,6 +102,7 @@ def word_association_barchart(
words: Union[List[str], str],
n_association: int = 0,
top_n: int = 20,
color_scheme: str = "Rainbow",
):
"""Plots bar chart of most important topics for the given words and their closest
associations in topic space.
Expand All @@ -114,6 +118,8 @@ def word_association_barchart(
None get displayed by default.
top_n: int = 20
Top N topics to display.
color_scheme: str, default 'Rainbow'
Name of the Plotly color scheme to use for the plot.
"""
if isinstance(words, str):
words = [words]
Expand All @@ -128,7 +134,7 @@ def word_association_barchart(
word_ids, topic_data["topic_term_matrix"], n_association
)
n_topics = topic_data["topic_term_matrix"].shape[0]
tempo = colors.get_colorscale("Rainbow")
tempo = colors.get_colorscale(color_scheme)
topic_colors = colors.sample_colorscale(tempo, np.arange(n_topics) / n_topics)
topic_colors = np.array(topic_colors)
top_topics = prepare.top_topics(
Expand Down
5 changes: 3 additions & 2 deletions topicwizard/plots/topics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module containing plotting utilities for topics."""

from typing import List

import numpy as np
Expand Down Expand Up @@ -139,7 +140,7 @@ def topic_plot(top_words: pd.DataFrame):
return fig


def wordcloud(top_words: pd.DataFrame) -> go.Figure:
def wordcloud(top_words: pd.DataFrame, color_scheme: str = "copper") -> go.Figure:
"""Plots most relevant words for current topic as a worcloud."""
top_dict = {
word: importance
Expand All @@ -151,7 +152,7 @@ def wordcloud(top_words: pd.DataFrame) -> go.Figure:
width=800,
height=1060,
background_color="white",
colormap="copper",
colormap=color_scheme,
scale=4,
).generate_from_frequencies(top_dict)
image = cloud.to_image()
Expand Down
Loading