diff --git a/DESCRIPTION b/DESCRIPTION index 3280baa0..2b422df5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -48,10 +48,12 @@ Imports: rlang (>= 0.4.7) Suggests: bayesplot, + cli, fs, ggplot2, knitr (>= 1.37), loo (>= 2.0.0), + progressr, qs2, rmarkdown, testthat (>= 2.1.0), diff --git a/NAMESPACE b/NAMESPACE index b157025d..9e12b908 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -41,6 +41,7 @@ export(print_example_program) export(read_cmdstan_csv) export(read_sample_csv) export(rebuild_cmdstan) +export(register_default_progress_handler) export(register_knitr_engine) export(set_cmdstan_path) export(set_num_threads) diff --git a/R/model.R b/R/model.R index bb1427ed..5a91214b 100644 --- a/R/model.R +++ b/R/model.R @@ -1116,6 +1116,14 @@ CmdStanModel$set("public", name = "format", value = format) #' #' @template model-common-args #' @template model-sample-args +#' @param show_progress_bar (logical). If TRUE, registers a progress bar to +#' display sampling progress via the `progressr` framework. The user is +#' responsible for registering a handler to display the progress bar. A +#' default handler, using the `cli` package, can be registered via the +#' `cmdstanr::register_default_progress_handler()`. Default: FALSE. +#' @param suppress_iteration_messages Suppress CmdStan output lines reporting +#' iterations, intended for use with the `show_progress_bar` argument. Defaults +#' to the value of `show_progress_bar`. #' @param cores,num_cores,num_chains,num_warmup,num_samples,save_extra_diagnostics,max_depth,stepsize,validate_csv #' Deprecated and will be removed in a future release. 
#' @@ -1157,6 +1165,8 @@ sample <- function(data = NULL, diagnostics = c("divergences", "treedepth", "ebfmi"), save_metric = NULL, save_cmdstan_config = NULL, + show_progress_bar = FALSE, + suppress_iteration_messages = NULL, # deprecated cores = NULL, num_cores = NULL, @@ -1221,12 +1231,45 @@ sample <- function(data = NULL, if (fixed_param) { save_warmup <- FALSE } + # Check for and create progressr::progressor object for progress reporting, if required. + # Pass default value for refresh + progress_bar <- NULL + if (show_progress_bar) { + if(requireNamespace("progressr", quietly = TRUE)) { + + # progressr only supports single-line progress bars at time of writing, + # so all chains must be combined into a single progress bar. + + # Calculate a total number of steps for progress as + # (chains*(iter_warmup+iter_sampling)). + # We will update the progress bar by 'refresh' steps each time. + + # As all the arguments to CmdStan can be NULL, we need to reproduce the + # defaults here manually. + + n_samples <- ifelse(is.null(iter_sampling), 1000, iter_sampling) + n_warmup <- ifelse(is.null(iter_warmup), 1000, iter_warmup) + n_chains <- ifelse(is.null(chains), 1, chains) + n_steps <- (n_chains*(n_samples+n_warmup)) + + progress_bar <- progressr::progressor(steps=n_steps, auto_finish=TRUE) + + } + else { + warning("'show_progress_bar=TRUE' requires the 'progressr' package. 
Please install 'progressr'.") + } + } procs <- CmdStanMCMCProcs$new( num_procs = checkmate::assert_integerish(chains, lower = 1, len = 1), + iter_warmup = checkmate::assert_integerish(iter_warmup, lower = 0, len = 1, null.ok = TRUE), + iter_sampling = checkmate::assert_integerish(iter_sampling, lower = 0, len = 1, null.ok = TRUE), parallel_procs = checkmate::assert_integerish(parallel_chains, lower = 1, null.ok = TRUE), threads_per_proc = assert_valid_threads(threads_per_chain, self$cpp_options(), multiple_chains = TRUE), show_stderr_messages = show_exceptions, - show_stdout_messages = show_messages + show_stdout_messages = show_messages, + progress_bar = progress_bar, + suppress_iteration_messages = suppress_iteration_messages, + refresh = refresh ) model_variables <- NULL if (is_variables_method_supported(self)) { @@ -2375,4 +2418,4 @@ resolve_exe_path <- function( exe <- self_exe_file } exe -} \ No newline at end of file +} diff --git a/R/run.R b/R/run.R index 9202f0d4..2bbb0482 100644 --- a/R/run.R +++ b/R/run.R @@ -511,6 +511,10 @@ check_target_exe <- function(exe) { } procs$check_finished() } + # Ensure at this point that any created progress bar is closed. 
+ if(!is.null(procs$progress_bar())){ + procs$progress_bar()(type="finish") + } procs$set_total_time(as.double((Sys.time() - start_time), units = "secs")) procs$report_time() } @@ -702,14 +706,32 @@ CmdStanProcs <- R6::R6Class( classname = "CmdStanProcs", public = list( initialize = function(num_procs, + iter_warmup = NULL, + iter_sampling = NULL, parallel_procs = NULL, threads_per_proc = NULL, show_stderr_messages = TRUE, - show_stdout_messages = TRUE) { + show_stdout_messages = TRUE, + progress_bar = NULL, + suppress_iteration_messages = NULL, + refresh = NULL ) { checkmate::assert_integerish(num_procs, lower = 1, len = 1, any.missing = FALSE) + checkmate::assert_integerish(iter_warmup, lower = 0, len = 1, any.missing = FALSE, null.ok = TRUE ) + checkmate::assert_integerish(iter_sampling, lower = 0, len = 1, any.missing = FALSE, null.ok = TRUE ) checkmate::assert_integerish(parallel_procs, lower = 1, len = 1, any.missing = FALSE, null.ok = TRUE) checkmate::assert_integerish(threads_per_proc, lower = 1, len = 1, null.ok = TRUE) private$num_procs_ <- as.integer(num_procs) + if (is.null(iter_warmup)) { + private$iter_warmup_ <- 1000 + } else { + private$iter_warmup_ <- as.integer(iter_warmup) + } + if (is.null(iter_sampling)) { + private$iter_sampling_ <- 1000 + } else { + private$iter_sampling_ <- as.integer(iter_sampling) + } + if (is.null(parallel_procs)) { private$parallel_procs_ <- private$num_procs_ } else { @@ -726,6 +748,27 @@ CmdStanProcs <- R6::R6Class( private$proc_total_time_ <- zeros private$show_stderr_messages_ <- show_stderr_messages private$show_stdout_messages_ <- show_stdout_messages + private$progress_bar_ <- progress_bar + + # Defaults when enabling the progress bar: + # - If 'progress_bar' is set, suppress iteration messages; + # - if `progress_bar` is unset, do not suppress iteration messages; + # - if 'suppress_iteration_messages' is set explicitly, honour that setting. 
+ if(is.null(progress_bar)) { + private$suppress_iteration_messages_ <- FALSE + } else { + private$suppress_iteration_messages_ <- TRUE + } + if(!is.null(suppress_iteration_messages)) { + private$suppress_iteration_messages_ <- suppress_iteration_messages + } + + if(is.null(refresh)) { + # Default to Stan default of 100 if refresh not set explicitly. + private$refresh_ <- 100 + } else { + private$refresh_ <- refresh + } invisible(self) }, show_stdout_messages = function () { @@ -734,9 +777,24 @@ CmdStanProcs <- R6::R6Class( show_stderr_messages = function () { private$show_stderr_messages_ }, + progress_bar = function() { + private$progress_bar_ + }, + suppress_iteration_messages = function () { + private$suppress_iteration_messages_ + }, + refresh = function () { + private$refresh_ + }, num_procs = function() { private$num_procs_ }, + iter_warmup = function() { + private$iter_warmup_ + }, + iter_sampling = function() { + private$iter_sampling_ + }, parallel_procs = function() { private$parallel_procs_ }, @@ -962,6 +1020,8 @@ CmdStanProcs <- R6::R6Class( processes_ = NULL, # will be list of processx::process objects proc_ids_ = integer(), num_procs_ = integer(), + iter_warmup_ = integer(), + iter_sampling_ = integer(), parallel_procs_ = integer(), active_procs_ = integer(), threads_per_proc_ = integer(), @@ -973,7 +1033,10 @@ CmdStanProcs <- R6::R6Class( proc_error_ouput_ = list(), total_time_ = numeric(), show_stderr_messages_ = TRUE, - show_stdout_messages_ = TRUE + show_stdout_messages_ = TRUE, + progress_bar_ = NULL, + suppress_iteration_messages_ = NULL, + refresh_ = 100 ) ) @@ -1050,6 +1113,53 @@ CmdStanMCMCProcs <- R6::R6Class( || grepl("stancflags", line, fixed = TRUE)) { ignore_line <- TRUE } + # Update progress bar + if (!ignore_line && !is.null(private$progress_bar_)) { + # Pass the current output line to the progress bar as a message, + # but only update the progress bar if the current line is an + # iteration message. 
+ progress_amount <- 0 + if(grepl("Iteration:", line, perl = TRUE)) { + # Calculating the amount by which to increment the progress bar + # is more complicated than it initially seems, due to occasional + # extra or awkward iteration reporting messages when starting + # sampling, moving from warmup to sampling, reaching the end of + # sampling where the number of samples is not a multiple of the + # refresh_rate. + + # Strategy: + # If the line's iteration value is divisible by refresh_rate, + # update the progress bar by refresh_rate; on the final + # iteration, update it by the remaining partial interval. + + # Additionally, when moving from warmup to sampling, iterations + # are reported starting from a baseline of the number of warmup + # iterations. (For example, if refresh_rate is 12 and iter_warmup + # is 100, the first reported iteration for sampling will be 112, + # not 108.) + + # Get the current iteration count and record whether it is the final + # iteration BEFORE subtracting iter_warmup from sampling iterations. + iter_current <- as.numeric(gsub( ".*Iteration:\\s*([0-9]+) \\/.*", "\\1", line, perl=TRUE )) + is_final_iter <- iter_current == private$iter_warmup_ + private$iter_sampling_ + if( iter_current > private$iter_warmup_ ) { + iter_current <- iter_current - private$iter_warmup_ + } + # Full refresh step on multiples of refresh; remainder on the final iteration. 
+ if ((iter_current %% private$refresh_) == 0) { + progress_amount <- private$refresh_ + } else if (is_final_iter) { + progress_amount <- iter_current %% private$refresh_ + } + } + private$progress_bar_(amount=progress_amount, message=line) + } + # Allow suppression of iteration messages + if (private$suppress_iteration_messages_) { + if(grepl("Iteration:", line, perl = TRUE)) { + ignore_line <- TRUE + } + } if ((state > 1.5 && state < 5 && !ignore_line && private$show_stdout_messages_) || is_verbose_mode()) { if (state == 2) { message("Chain ", id, " ", line) diff --git a/R/utils.R b/R/utils.R index 8abc327f..8c4d89c2 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1072,3 +1072,40 @@ expose_stan_functions <- function(function_env, global = FALSE, verbose = FALSE) } invisible(NULL) } + +#' Register a default progress bar handler for sampling +#' +#' Create a default progress bar for CmdStan sampling operations, and register +#' it as the default global handler for progressr updates. Requires `progressr` +#' for the progress framework, and `cli` for the default progress bar handler. +#' +#' @export +#' +#' @param verbose (logical) Report creation of progress bar to stdout? +#' The default is `TRUE`. +#' +register_default_progress_handler <- function(verbose=TRUE) { + # Require both the progressr and cli packages. + if(requireNamespace("progressr", quietly = TRUE) && requireNamespace("cli", quietly = TRUE)) { + + progressr::handlers(global=TRUE) + progressr::handlers("cli") + + # Progress bar options + options(cli.spinner = "moon", + cli.progress_show_after = 0, + cli.progress_clear = FALSE ) + + # Default informative progress output for sampling + progressr::handlers(progressr::handler_cli( + format = "{cli::pb_spin} Progress: |{cli::pb_bar}| {cli::pb_current}/{cli::pb_total} | {cli::pb_percent} | ETA: {cli::pb_eta}", + clear = FALSE + )) + if(verbose) { + message("Default progress bar registered.") + } + } else { + warning("The 'progressr' library is required to enable a progress bar. 
The default progress bar uses the 'cli' library.") + } + invisible(NULL) +} diff --git a/man/model-method-sample.Rd b/man/model-method-sample.Rd index 2558e630..7bf069d6 100644 --- a/man/model-method-sample.Rd +++ b/man/model-method-sample.Rd @@ -39,6 +39,8 @@ sample( diagnostics = c("divergences", "treedepth", "ebfmi"), save_metric = NULL, save_cmdstan_config = NULL, + show_progress_bar = FALSE, + suppress_iteration_messages = NULL, cores = NULL, num_cores = NULL, num_chains = NULL, @@ -303,6 +305,16 @@ with argument \code{"output save_config=1"} to save a json file which contains the argument tree and extra information (equivalent to the output CSV file header). This option is only available in CmdStan 2.34.0 and later.} +\item{show_progress_bar}{(logical). If TRUE, registers a progress bar to +display sampling progress via the \code{progressr} framework. The user is +responsible for registering a handler to display the progress bar. A +default handler, using the \code{cli} package, can be registered via the +\code{cmdstanr::register_default_progress_handler()}. Default: FALSE.} + +\item{suppress_iteration_messages}{Suppress CmdStan output lines reporting +iterations, intended for use with the \code{show_progress_bar} argument. 
Defaults +to the value of \code{show_progress_bar}.} + \item{cores, num_cores, num_chains, num_warmup, num_samples, save_extra_diagnostics, max_depth, stepsize, validate_csv}{Deprecated and will be removed in a future release.} } \value{ diff --git a/man/register_default_progress_handler.Rd b/man/register_default_progress_handler.Rd new file mode 100644 index 00000000..28709ade --- /dev/null +++ b/man/register_default_progress_handler.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{register_default_progress_handler} +\alias{register_default_progress_handler} +\title{Register a default progress bar handler for sampling} +\usage{ +register_default_progress_handler(verbose = TRUE) +} +\arguments{ +\item{verbose}{(logical) Report creation of progress bar to stdout? +The default is \code{TRUE}.} +} +\description{ +Create a default progress bar for CmdStan sampling operations, and register +it as the default global handler for progressr updates. Requires \code{progressr} +for the progress framework, and \code{cli} for the default progress bar handler. +}