diff --git a/src/pathpyG/visualisations/_d3js/backend.py b/src/pathpyG/visualisations/_d3js/backend.py index 76c56bc7..cf317d4c 100644 --- a/src/pathpyG/visualisations/_d3js/backend.py +++ b/src/pathpyG/visualisations/_d3js/backend.py @@ -21,7 +21,6 @@ import uuid import webbrowser from copy import deepcopy -from string import Template from pathpyG.utils.config import config from pathpyG.visualisations.network_plot import NetworkPlot @@ -39,7 +38,7 @@ TemporalNetworkPlot: "temporal", TimeUnfoldedNetworkPlot: "unfolded", } -_CDN_URL = "https://d3js.org/d3.v7.min.js" +_CDN_URL = "https://cdn.jsdelivr.net/npm/d3@7/+esm" class D3jsBackend(PlotBackend): @@ -116,8 +115,8 @@ def save(self, filename: str) -> None: - Embedded in websites or documentation - Shared without additional dependencies """ - # Default to the CDN version of d3js since browsers may block local scripts - self.config["d3js_local"] = self.config.get("d3js_local", False) + # Default to embedded local version to obtain a self-contained file + self.config["d3js_local"] = config.get("d3js_local", True) with open(filename, "w+") as new: new.write(self.to_html()) @@ -133,13 +132,13 @@ def show(self) -> None: and choose appropriate display method automatically. """ # Default to CDN version if reachable - # Check if CDN is reachable try: - urllib.request.urlopen(_CDN_URL, timeout=2) - self.config["d3js_local"] = self.config.get("d3js_local", False) + # Attempt to access the CDN URL to check if it's reachable + urllib.request.urlopen(urllib.request.Request(_CDN_URL, headers={"User-Agent": "Mozilla/5.0"}), timeout=2) + self.config["d3js_local"] = config.get("d3js_local", False) except (urllib.error.URLError, urllib.error.HTTPError): - self.config["d3js_local"] = self.config.get("d3js_local", True) - + self.config["d3js_local"] = config.get("d3js_local", True) + if config["environment"]["interactive"]: from IPython.display import display_html, HTML # noqa I001 @@ -168,15 +167,21 @@ def _prepare_data(self) -> dict: **Edges**: Include uid, source/target references, and styling """ node_data = self.data["nodes"].copy() - node_data["uid"] = self.data["nodes"].index.map(lambda x: f"({x[0]},{x[1]})" if isinstance(x, tuple) else str(x)) + node_data["uid"] = self.data["nodes"].index.map( + lambda x: f"({x[0]},{x[1]})" if isinstance(x, tuple) else str(x) + ) node_data = node_data.rename(columns={"x": "xpos", "y": "ypos"}) if self._kind == "unfolded": node_data["ypos"] = 1 - node_data["ypos"] # Invert y-axis for unfolded layout edge_data = self.data["edges"].copy() edge_data["uid"] = self.data["edges"].index.map(lambda x: f"{x[0]}-{x[1]}") if len(edge_data) > 0: - edge_data["source"] = edge_data.index.to_frame()["source"].map(lambda x: f"({x[0]},{x[1]})" if isinstance(x, tuple) else str(x)) - edge_data["target"] = edge_data.index.to_frame()["target"].map(lambda x: f"({x[0]},{x[1]})" if isinstance(x, tuple) else str(x)) + edge_data["source"] = edge_data.index.to_frame()["source"].map( + lambda x: f"({x[0]},{x[1]})" if isinstance(x, tuple) else str(x) + ) + edge_data["target"] = edge_data.index.to_frame()["target"].map( + lambda x: f"({x[0]},{x[1]})" if isinstance(x, tuple) else str(x) + ) data_dict = { "nodes": node_data.to_dict(orient="records"), "edges": edge_data.to_dict(orient="records"), @@ -253,17 +258,8 @@ def to_html(self) -> str: os.path.normpath("_d3js/templates"), ) - # get d3js library path - if self.config.get("d3js_local", False): - d3js = os.path.join(template_dir, "d3.v7.min.js") - else: - d3js = _CDN_URL - js_template = self.get_template(template_dir) - with open(os.path.join(template_dir, "setup.js")) as template: - setup_template = template.read() - with open(os.path.join(template_dir, "styles.css")) as template: css_template = template.read() @@ -277,17 +273,16 @@ def to_html(self) -> str: # div environment for the plot object html += f'\n