diff --git a/crates/amalthea/src/comm/plot_comm.rs b/crates/amalthea/src/comm/plot_comm.rs
index efa3cc813..85d09d4a5 100644
--- a/crates/amalthea/src/comm/plot_comm.rs
+++ b/crates/amalthea/src/comm/plot_comm.rs
@@ -27,6 +27,22 @@ pub struct IntrinsicSize {
     pub source: String
 }
 
+/// The plot's metadata
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub struct PlotMetadata {
+    /// A human-readable name for the plot
+    pub name: String,
+
+    /// The kind of plot e.g. 'Matplotlib', 'ggplot2', etc.
+    pub kind: String,
+
+    /// The ID of the code fragment that produced the plot
+    pub execution_id: String,
+
+    /// The code fragment that produced the plot
+    pub code: String
+}
+
 /// A rendered plot
 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 pub struct PlotResult {
@@ -133,6 +149,12 @@ pub enum PlotBackendRequest {
     #[serde(rename = "get_intrinsic_size")]
     GetIntrinsicSize,
 
+    /// Get metadata for the plot
+    ///
+    /// Get metadata for the plot
+    #[serde(rename = "get_metadata")]
+    GetMetadata,
+
     /// Render a plot
     ///
     /// Requests a plot to be rendered. The plot data is returned in a
@@ -151,6 +173,9 @@ pub enum PlotBackendReply {
     /// The intrinsic size of a plot, if known
     GetIntrinsicSizeReply(Option<IntrinsicSize>),
 
+    /// The plot's metadata
+    GetMetadataReply(PlotMetadata),
+
     /// A rendered plot
     RenderReply(PlotResult),
 
diff --git a/crates/ark/src/interface.rs b/crates/ark/src/interface.rs
index 2531a35c4..77903eee2 100644
--- a/crates/ark/src/interface.rs
+++ b/crates/ark/src/interface.rs
@@ -876,6 +876,14 @@ impl RMain {
         &self.iopub_tx
     }
 
+    /// Get the current execution context if an active request exists.
+    /// Returns (execution_id, code) tuple where execution_id is the Jupyter message ID.
+    pub fn get_execution_context(&self) -> Option<(String, String)> {
+        self.active_request
+            .as_ref()
+            .map(|req| (req.originator.header.msg_id.clone(), req.request.code.clone()))
+    }
+
     fn init_execute_request(&mut self, req: &ExecuteRequest) -> (ConsoleInput, u32) {
         // Reset the autoprint buffer
         self.autoprint_output = String::new();
@@ -1306,11 +1314,18 @@ impl RMain {
                 // Save `ExecuteCode` request so we can respond to it at next prompt
                 self.active_request = Some(ActiveReadConsoleRequest {
                     exec_count,
                     request: exec_req,
                     originator,
                     reply_tx,
                 });
 
+                // Push execution context to graphics device for plot attribution.
+                // Read it back from the request stored just above instead of cloning
+                // the whole request and originator only to extract two strings.
+                if let Some((execution_id, code)) = self.get_execution_context() {
+                    graphics_device::on_execute_request(execution_id, code);
+                }
+
                 input
             },
 
diff --git a/crates/ark/src/modules/positron/graphics.R b/crates/ark/src/modules/positron/graphics.R
index 7b9132884..fe1c4cc9f 100644
--- a/crates/ark/src/modules/positron/graphics.R
+++ b/crates/ark/src/modules/positron/graphics.R
@@ -529,3 +529,168 @@ render_path <- function(id, format) {
     file <- paste0("render-", id, ".", format)
     file.path(directory, file)
 }
+
+#' Detect the kind of plot from a recording
+#'
+#' Uses multiple strategies to determine plot type:
+#' 1. Check .Last.value for high-level plot objects (ggplot2, lattice)
+#' 2. Check recording's display list for base graphics patterns
+#' 3.
Fall back to generic "plot"
+#'
+#' @param id The plot ID
+#' @return A string describing the plot kind
+#' @export
+.ps.graphics.detect_plot_kind <- function(id) {
+    # Strategy 1: Check .Last.value for recognizable plot objects
+    # This works for ggplot2, lattice, and some other packages
+    value <- tryCatch(
+        get(".Last.value", envir = globalenv()),
+        error = function(e) NULL
+    )
+
+    if (!is.null(value)) {
+        kind <- detect_kind_from_value(value)
+        if (!is.null(kind)) {
+            return(kind)
+        }
+    }
+
+    # Strategy 2: Check the recording itself
+    recording <- get_recording(id)
+    if (!is.null(recording)) {
+        # recordPlot() stores display list in first element
+        dl <- recording[[1]]
+        if (length(dl) > 0) {
+            kind <- detect_kind_from_display_list(dl)
+            if (!is.null(kind)) {
+                return(kind)
+            }
+        }
+    }
+
+    # Default fallback
+    "plot"
+}
+
+# Detect plot kind from .Last.value
+# Returns plot kind string or NULL
+detect_kind_from_value <- function(value) {
+    # ggplot2
+    if (inherits(value, "ggplot")) {
+        return(detect_ggplot_kind(value))
+    }
+
+    # lattice
+    if (inherits(value, "trellis")) {
+        # Extract lattice plot type from call; keep the last part so `lattice::xyplot` works
+        call_fn <- rev(as.character(value$call[[1]]))[1]
+        kind_map <- c(
+            "xyplot" = "scatter plot",
+            "bwplot" = "box plot",
+            "histogram" = "histogram",
+            "densityplot" = "density plot",
+            "barchart" = "bar chart",
+            "dotplot" = "dot plot",
+            "levelplot" = "heatmap",
+            "contourplot" = "contour plot",
+            "cloud" = "3D scatter",
+            "wireframe" = "3D surface"
+        )
+        if (call_fn %in% names(kind_map)) {
+            return(paste0("lattice ", kind_map[[call_fn]]))
+        }
+        return("lattice")
+    }
+
+    # Base R objects that have class
+    if (inherits(value, "histogram")) {
+        return("histogram")
+    }
+    if (inherits(value, "density")) {
+        return("density")
+    }
+    if (inherits(value, "hclust")) {
+        return("dendrogram")
+    }
+    if (inherits(value, "acf")) {
+        return("autocorrelation")
+    }
+
+    NULL
+}
+
+# Detect ggplot2 plot kind from geom layers
+# Returns plot kind string
+detect_ggplot_kind <- function(gg) {
+    if (length(gg$layers) == 0) {
+        return("ggplot2")
+    }
+
+    # Get the first layer's geom class
+    geom_class <- class(gg$layers[[1]]$geom)[1]
+    geom_name <- tolower(gsub("^Geom", "", geom_class))
+
+    kind_map <- c(
+        "point" = "scatter plot",
+        "line" = "line chart",
+        "bar" = "bar chart",
+        "col" = "bar chart",
+        "histogram" = "histogram",
+        "boxplot" = "box plot",
+        "violin" = "violin plot",
+        "density" = "density plot",
+        "area" = "area chart",
+        "tile" = "heatmap",
+        "raster" = "raster",
+        "contour" = "contour plot",
+        "smooth" = "smoothed line",
+        "text" = "text",
+        "label" = "labels",
+        "path" = "path",
+        "polygon" = "polygon",
+        "ribbon" = "ribbon",
+        "segment" = "segments",
+        "abline" = "reference lines",
+        "hline" = "horizontal lines",
+        "vline" = "vertical lines"
+    )
+
+    if (geom_name %in% names(kind_map)) {
+        return(paste0("ggplot2 ", kind_map[geom_name]))
+    }
+
+    "ggplot2"
+}
+
+# Detect plot kind from display list (base graphics)
+# Returns plot kind string or NULL
+detect_kind_from_display_list <- function(dl) {
+    # Each entry is list(<primitive>, list(<NativeSymbolInfo>, ...)): read the symbol's name
+    call_names <- vapply(dl, function(x) {
+        # Also accept a bare string in first position; anything else maps to ""
+        name <- tryCatch(
+            if (is.character(x[[1]])) x[[1]] else x[[2]][[1]]$name,
+            error = function(e) NULL
+        )
+        if (is.list(x) && is.character(name) && length(name) == 1) name else ""
+    }, character(1))
+
+    # Base graphics C functions to plot types
+    if (any(call_names == "C_plotHist")) return("histogram")
+    if (any(call_names == "C_image")) return("image")
+    if (any(call_names == "C_contour")) return("contour")
+    if (any(call_names == "C_persp")) return("3D surface")
+    if (any(call_names == "C_filledcontour")) return("filled contour")
+
+    # Check for grid graphics (ggplot2, lattice)
+    if (any(grepl("^L_", call_names))) {
+        return("grid")
+    }
+
+    # Check for base graphics
+    if (any(grepl("^C_", call_names))) {
+        return("base")
+    }
+
+    NULL
+}
diff --git a/crates/ark/src/plots/graphics_device.rs b/crates/ark/src/plots/graphics_device.rs
index c1ba1b5c0..051db139c 100644
---
a/crates/ark/src/plots/graphics_device.rs
+++ b/crates/ark/src/plots/graphics_device.rs
@@ -19,6 +19,7 @@ use amalthea::comm::event::CommManagerEvent;
 use amalthea::comm::plot_comm::PlotBackendReply;
 use amalthea::comm::plot_comm::PlotBackendRequest;
 use amalthea::comm::plot_comm::PlotFrontendEvent;
+use amalthea::comm::plot_comm::PlotMetadata;
 use amalthea::comm::plot_comm::PlotRenderFormat;
 use amalthea::comm::plot_comm::PlotRenderSettings;
 use amalthea::comm::plot_comm::PlotResult;
@@ -121,6 +122,19 @@ struct WrappedDeviceCallbacks {
 #[derive(Debug, Clone, Hash, PartialEq, Eq)]
 struct PlotId(String);
 
+/// Metadata captured at plot creation time
+#[derive(Debug, Clone)]
+struct PlotMetadataInfo {
+    /// Human-readable name (e.g., "ggplot2 scatter plot 1")
+    name: String,
+    /// Kind of plot (e.g., "ggplot2 scatter plot", "histogram")
+    kind: String,
+    /// The Jupyter message ID of the execute_request that produced the plot
+    execution_id: String,
+    /// Code that produced the plot
+    code: String,
+}
+
 struct DeviceContext {
     /// Channel for sending [CommManagerEvent]s to Positron when plot events occur
     comm_manager_tx: Sender<CommManagerEvent>,
@@ -171,11 +185,22 @@ struct DeviceContext {
     /// rendered results to the frontend.
     sockets: RefCell<HashMap<PlotId, CommSocket>>,
 
+    /// Mapping of plot ID to its metadata (captured at creation time)
+    metadata: RefCell<HashMap<PlotId, PlotMetadataInfo>>,
+
+    /// Counters for generating unique plot names by kind
+    kind_counters: RefCell<HashMap<String, u32>>,
+
     /// The callbacks of the wrapped device, initialized on graphics device creation
     wrapped_callbacks: WrappedDeviceCallbacks,
 
     /// The settings used for pre-renderings of new plots.
     prerender_settings: Cell<PlotRenderSettings>,
+
+    /// The current execution context (execution_id, code) from the active request.
+    /// Pushed here when an execute request starts via `on_execute_request()`,
+    /// cleared when the request completes.
+    execution_context: RefCell<Option<(String, String)>>,
 }
 
 impl DeviceContext {
@@ -189,6 +214,8 @@ impl DeviceContext {
             should_render: Cell::new(true),
             id: RefCell::new(Self::new_id()),
             sockets: RefCell::new(HashMap::new()),
+            metadata: RefCell::new(HashMap::new()),
+            kind_counters: RefCell::new(HashMap::new()),
             wrapped_callbacks: WrappedDeviceCallbacks::default(),
             prerender_settings: Cell::new(PlotRenderSettings {
                 size: PlotSize {
@@ -198,9 +225,25 @@ impl DeviceContext {
                 pixel_ratio: 1.,
                 format: PlotRenderFormat::Png,
             }),
+            execution_context: RefCell::new(None),
         }
     }
 
+    /// Set the current execution context (called when an execute request starts)
+    fn set_execution_context(&self, execution_id: String, code: String) {
+        *self.execution_context.borrow_mut() = Some((execution_id, code));
+    }
+
+    /// Clear the current execution context (called when an execute request completes)
+    fn clear_execution_context(&self) {
+        *self.execution_context.borrow_mut() = None;
+    }
+
+    /// Get the current execution context (clones the value)
+    fn get_execution_context(&self) -> Option<(String, String)> {
+        self.execution_context.borrow().clone()
+    }
+
     /// Create a new id for this new plot page (from Positron's perspective)
     /// and note that this is a new page
     fn new_positron_page(&self) {
@@ -283,6 +326,57 @@ impl DeviceContext {
         PlotId(Uuid::new_v4().to_string())
     }
 
+    /// Capture the current execution context for a new plot.
+    ///
+    /// First checks for context pushed via `on_execute_request()`, then falls back
+    /// to getting context from RMain's active request (for backwards compatibility
+    /// and edge cases).
+    fn capture_execution_context(&self) -> (String, String) {
+        // First, check if we have a stored execution context from on_execute_request()
+        if let Some(ctx) = self.get_execution_context() {
+            return ctx;
+        }
+
+        // Fall back to getting context from RMain (for edge cases)
+        RMain::with(|main| {
+            main.get_execution_context().unwrap_or_else(|| {
+                // No active request - might be during startup or from R code
+                (String::new(), String::new())
+            })
+        })
+    }
+
+    /// Detect the kind of plot from the recording.
+    ///
+    /// Calls into R to inspect the plot recording and/or `.Last.value`.
+    fn detect_plot_kind(&self, id: &PlotId) -> String {
+        let result = RFunction::from(".ps.graphics.detect_plot_kind")
+            .param("id", id)
+            .call();
+
+        match result {
+            Ok(kind) => {
+                // SAFETY: We just called an R function that returns a string
+                unsafe { kind.to::<String>() }.unwrap_or_else(|err| {
+                    log::warn!("Failed to convert plot kind to string: {err:?}");
+                    "plot".to_string()
+                })
+            },
+            Err(err) => {
+                log::warn!("Failed to detect plot kind: {err:?}");
+                "plot".to_string()
+            },
+        }
+    }
+
+    /// Generate a unique name for a plot of the given kind
+    fn generate_plot_name(&self, kind: &str) -> String {
+        let mut counters = self.kind_counters.borrow_mut();
+        let counter = counters.entry(kind.to_string()).or_insert(0);
+        *counter += 1;
+        format!("{} {}", kind, counter)
+    }
+
     /// Process outstanding RPC requests received from Positron
     ///
     /// At idle time we loop through our set of plot channels and check if Positron has
@@ -399,6 +493,34 @@ impl DeviceContext {
                 log::trace!("PlotBackendRequest::GetIntrinsicSize");
                 Ok(PlotBackendReply::GetIntrinsicSizeReply(None))
             },
+            PlotBackendRequest::GetMetadata => {
+                log::trace!("PlotBackendRequest::GetMetadata");
+
+                // Metadata was captured at plot creation time, just retrieve it
+                let stored_metadata = self.metadata.borrow();
+                let info = stored_metadata.get(id);
+
+                let plot_metadata = match info {
+                    Some(info) => PlotMetadata {
+                        name: info.name.clone(),
+                        kind: info.kind.clone(),
+                        execution_id: info.execution_id.clone(),
+                        code: info.code.clone(),
+                    },
+                    None => {
+                        // Fallback if metadata wasn't captured (shouldn't happen)
+                        log::warn!("No metadata found for plot id {id}");
+                        PlotMetadata {
+                            name: "plot".to_string(),
+                            kind: "plot".to_string(),
+                            execution_id: String::new(),
+                            code: String::new(),
+                        }
+                    },
+                };
+
+                Ok(PlotBackendReply::GetMetadataReply(plot_metadata))
+            },
             PlotBackendRequest::Render(plot_meta) => {
                 log::trace!("PlotBackendRequest::Render");
 
@@ -432,6 +554,9 @@ impl DeviceContext {
         // RefCell safety: Short borrows in the file
         self.sockets.borrow_mut().remove(id);
 
+        // Remove metadata for this plot
+        self.metadata.borrow_mut().remove(id);
+
         // The plot data is stored at R level. Assumes we're called on the R
         // thread at idle time so there's no race issues (see
         // `on_process_idle_events()`).
@@ -459,6 +584,10 @@ impl DeviceContext {
         }
     }
 
+    /// Process outstanding plot changes
+    ///
+    /// Uses execution context stored via `on_execute_request()` or falls back to
+    /// getting context from RMain's active request.
     #[tracing::instrument(level = "trace", skip_all)]
     fn process_changes(&self) {
         let id = self.id();
@@ -501,6 +630,19 @@ impl DeviceContext {
     fn process_new_plot_positron(&self, id: &PlotId) {
         log::trace!("Notifying Positron of new plot");
 
+        let (execution_id, code) = self.capture_execution_context();
+        let kind = self.detect_plot_kind(id);
+        let name = self.generate_plot_name(&kind);
+
+        self.metadata
+            .borrow_mut()
+            .insert(id.clone(), PlotMetadataInfo {
+                name,
+                kind,
+                execution_id,
+                code,
+            });
+
         // Let Positron know that we just created a new plot.
         let socket = CommSocket::new(
             CommInitiator::BackEnd,
@@ -543,6 +685,22 @@ impl DeviceContext {
     fn process_new_plot_jupyter_protocol(&self, id: &PlotId) {
         log::trace!("Notifying Jupyter frontend of new plot");
 
+        // NOTE(review): metadata entries are only removed when a plot's comm socket
+        // closes. Confirm that plots notified through this path cannot accumulate
+        // entries for the lifetime of the session.
+        let (execution_id, code) = self.capture_execution_context();
+        let kind = self.detect_plot_kind(id);
+        let name = self.generate_plot_name(&kind);
+
+        self.metadata
+            .borrow_mut()
+            .insert(id.clone(), PlotMetadataInfo {
+                name,
+                kind,
+                execution_id,
+                code,
+            });
+
         let data = unwrap!(self.create_display_data_plot(id), Err(error) => {
             log::error!("Failed to create plot due to: {error}.");
             return;
@@ -794,6 +952,19 @@ pub(crate) fn on_process_idle_events() {
     DEVICE_CONTEXT.with_borrow(|cell| cell.process_rpc_requests());
 }
 
+/// Hook applied when an execute request starts
+///
+/// Pushes the execution context (execution_id, code) to the graphics device
+/// so it can be captured when new plots are created. This allows plots to be
+/// correctly attributed to the code that generated them.
+///
+/// Called from `handle_execute_request()` after setting the active request.
+#[tracing::instrument(level = "trace", skip_all)]
+pub(crate) fn on_execute_request(execution_id: String, code: String) {
+    log::trace!("Entering on_execute_request");
+    DEVICE_CONTEXT.with_borrow(|cell| cell.set_execution_context(execution_id, code));
+}
+
 /// Hook applied after a code chunk has finished executing
 ///
 /// Not an official graphics device hook, instead we run this manually after
@@ -814,7 +985,10 @@ pub(crate) fn on_process_idle_events() {
 #[tracing::instrument(level = "trace", skip_all)]
 pub(crate) fn on_did_execute_request() {
     log::trace!("Entering on_did_execute_request");
-    DEVICE_CONTEXT.with_borrow(|cell| cell.process_changes());
+    DEVICE_CONTEXT.with_borrow(|cell| {
+        cell.process_changes();
+        cell.clear_execution_context();
+    });
 }
 
 /// Activation callback