Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ Imports:
rlang (>= 0.4.7)
Suggests:
bayesplot,
cli,
fs,
ggplot2,
knitr (>= 1.37),
loo (>= 2.0.0),
progressr,
qs2,
rmarkdown,
testthat (>= 2.1.0),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export(print_example_program)
export(read_cmdstan_csv)
export(read_sample_csv)
export(rebuild_cmdstan)
export(register_default_progress_handler)
export(register_knitr_engine)
export(set_cmdstan_path)
export(set_num_threads)
Expand Down
47 changes: 45 additions & 2 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,14 @@ CmdStanModel$set("public", name = "format", value = format)
#'
#' @template model-common-args
#' @template model-sample-args
#' @param show_progress_bar (logical). If TRUE, registers a progress bar to
#' display sampling progress via the `progressr` framework. The user is
#' responsible for registering a handler to display the progress bar. A
#' default handler, using the `cli` package, can be registered via the
#' `cmdstanr::register_default_progress_handler()`. Default: FALSE.
#' @param suppress_iteration_messages Suppress CmdStan output lines reporting
#' iterations, intended for use with the `show_progress_bar` argument. Defaults
#' to the value of `show_progress_bar`.
#' @param cores,num_cores,num_chains,num_warmup,num_samples,save_extra_diagnostics,max_depth,stepsize,validate_csv
#' Deprecated and will be removed in a future release.
#'
Expand Down Expand Up @@ -1157,6 +1165,8 @@ sample <- function(data = NULL,
diagnostics = c("divergences", "treedepth", "ebfmi"),
save_metric = NULL,
save_cmdstan_config = NULL,
show_progress_bar = FALSE,
suppress_iteration_messages = NULL,
# deprecated
cores = NULL,
num_cores = NULL,
Expand Down Expand Up @@ -1221,12 +1231,45 @@ sample <- function(data = NULL,
if (fixed_param) {
save_warmup <- FALSE
}
# Check for and create progressr::progressor object for progress reporting, if required.
# Pass default value for refresh
progress_bar <- NULL
if (show_progress_bar) {
if(requireNamespace("progressr", quietly = TRUE)) {

# progressr only supports single-line progress bars at time of writing,
# so all chains must be combined into a single process bar.

# Calculate a total number of steps for progress as
# (chains*(iter_warmup+iter_sampling)).
# We will update the progress bar by 'refresh' steps each time.

# As all the arguments to CmdStan can be NULL, we need to reproduce the
# defaults here manually.

n_samples <- ifelse(is.null(iter_sampling), 1000, iter_sampling)
n_warmup <- ifelse(is.null(iter_warmup), 1000, iter_warmup)
n_chains <- ifelse(is.null(chains), 1, chains)
n_steps <- (n_chains*(n_samples+n_warmup))

progress_bar <- progressr::progressor(steps=n_steps, auto_finish=TRUE)

}
else {
warning("'show_progress_bar=TRUE' requires the 'progressr' package. Please install 'progressr'.")
}
}
procs <- CmdStanMCMCProcs$new(
num_procs = checkmate::assert_integerish(chains, lower = 1, len = 1),
iter_warmup = checkmate::assert_integerish(iter_warmup, lower = 0, len = 1, null.ok = TRUE),
iter_sampling = checkmate::assert_integerish(iter_sampling, lower = 0, len = 1, null.ok = TRUE),
parallel_procs = checkmate::assert_integerish(parallel_chains, lower = 1, null.ok = TRUE),
threads_per_proc = assert_valid_threads(threads_per_chain, self$cpp_options(), multiple_chains = TRUE),
show_stderr_messages = show_exceptions,
show_stdout_messages = show_messages
show_stdout_messages = show_messages,
progress_bar = progress_bar,
suppress_iteration_messages = suppress_iteration_messages,
refresh = refresh
)
model_variables <- NULL
if (is_variables_method_supported(self)) {
Expand Down Expand Up @@ -2375,4 +2418,4 @@ resolve_exe_path <- function(
exe <- self_exe_file
}
exe
}
}
114 changes: 112 additions & 2 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,10 @@ check_target_exe <- function(exe) {
}
procs$check_finished()
}
# Ensure at this point that any created progress bar is closed.
if(!is.null(private$progress_bar_)){
private$progress_bar_(type="finish")
}
procs$set_total_time(as.double((Sys.time() - start_time), units = "secs"))
procs$report_time()
}
Expand Down Expand Up @@ -702,14 +706,32 @@ CmdStanProcs <- R6::R6Class(
classname = "CmdStanProcs",
public = list(
initialize = function(num_procs,
iter_warmup = NULL,
iter_sampling = NULL,
parallel_procs = NULL,
threads_per_proc = NULL,
show_stderr_messages = TRUE,
show_stdout_messages = TRUE) {
show_stdout_messages = TRUE,
progress_bar = NULL,
suppress_iteration_messages = NULL,
refresh = NULL ) {
checkmate::assert_integerish(num_procs, lower = 1, len = 1, any.missing = FALSE)
checkmate::assert_integerish(iter_warmup, lower = 0, len = 1, any.missing = FALSE, null.ok = TRUE )
checkmate::assert_integerish(iter_sampling, lower = 0, len = 1, any.missing = FALSE, null.ok = TRUE )
checkmate::assert_integerish(parallel_procs, lower = 1, len = 1, any.missing = FALSE, null.ok = TRUE)
checkmate::assert_integerish(threads_per_proc, lower = 1, len = 1, null.ok = TRUE)
private$num_procs_ <- as.integer(num_procs)
if (is.null(iter_warmup)) {
private$iter_warmup_ <- 1000
} else {
private$iter_warmup_ <- as.integer(iter_warmup)
}
if (is.null(iter_sampling)) {
private$iter_sampling_ <- 1000
} else {
private$iter_sampling_ <- as.integer(iter_sampling)
}

if (is.null(parallel_procs)) {
private$parallel_procs_ <- private$num_procs_
} else {
Expand All @@ -726,6 +748,27 @@ CmdStanProcs <- R6::R6Class(
private$proc_total_time_ <- zeros
private$show_stderr_messages_ <- show_stderr_messages
private$show_stdout_messages_ <- show_stdout_messages
private$progress_bar_ <- progress_bar

# Defaults when enabling the progress bar:
# - If 'progress_bar' is set, suppress iteration messages;
# - if `progress_bar` is unset, do not suppress iteration messages;
# - if 'suppress_iteration_messages' is set explicitly, honour that setting.
if(is.null(progress_bar)) {
private$suppress_iteration_messages_ <- FALSE
} else {
private$suppress_iteration_messages_ <- TRUE
}
if(!is.null(suppress_iteration_messages)) {
private$suppress_iteration_messages_ <- suppress_iteration_messages
}

if(is.null(refresh)) {
# Default to Stan default of 100 if refresh not set explicitly.
private$refresh_ <- 100
} else {
private$refresh_ <- refresh
}
invisible(self)
},
show_stdout_messages = function () {
Expand All @@ -734,9 +777,24 @@ CmdStanProcs <- R6::R6Class(
show_stderr_messages = function () {
private$show_stderr_messages_
},
progress_bar = function() {
private$progress_bar_
},
suppress_iteration_messages = function () {
private$suppress_iteration_messages_
},
refresh = function () {
private$refresh_
},
num_procs = function() {
private$num_procs_
},
iter_warmup = function() {
privatea$iter_warmup_
},
iter_sampling = function() {
private$iter_sampling_
},
parallel_procs = function() {
private$parallel_procs_
},
Expand Down Expand Up @@ -962,6 +1020,8 @@ CmdStanProcs <- R6::R6Class(
processes_ = NULL, # will be list of processx::process objects
proc_ids_ = integer(),
num_procs_ = integer(),
iter_warmup_ = integer(),
iter_sampling_ = integer(),
parallel_procs_ = integer(),
active_procs_ = integer(),
threads_per_proc_ = integer(),
Expand All @@ -973,7 +1033,10 @@ CmdStanProcs <- R6::R6Class(
proc_error_ouput_ = list(),
total_time_ = numeric(),
show_stderr_messages_ = TRUE,
show_stdout_messages_ = TRUE
show_stdout_messages_ = TRUE,
progress_bar_ = NULL,
suppress_iteration_messages_ = NULL,
refresh_ = 100
)
)

Expand Down Expand Up @@ -1050,6 +1113,53 @@ CmdStanMCMCProcs <- R6::R6Class(
|| grepl("stancflags", line, fixed = TRUE)) {
ignore_line <- TRUE
}
# Update progress bar
if (!ignore_line && !is.null(private$progress_bar_)) {
# Pass the current output line to the progress bar as a message,
# but only update the progress bar if the current line is an
# iteration message.
progress_amount <- 0
if(grepl("Iteration:", line, perl = TRUE)) {
# Calculating the amount by which to increment the progress bar
# is more complicated than it initially seems, due to occasional
# extra or awkward iteration reporting messages when starting
# sampling, moving from warmup to sampling, reaching the end of
# sampling where the number of samples is not a multiple of the
# refresh_rate.

# Strategy:
# If the line's iteration value is divisible by refresh_rate, or
# is the final sampling step, update the progress bar by
# refresh_rate.

# Additionally, when moving from warmup to sampling, iterations
# are reported starting from a baseline of the number of warmup
# iterations. (For example, if refresh_rate is 12 and iter_warmup
# is 100, the first reported iteration for sampling will be 112,
# not 108.)

# Get the current iteration count.
# Subtract iter_warmup if greater than that.
iter_current <- as.numeric(gsub( ".*Iteration:\\s*([0-9]+) \\/.*", "\\1", line, perl=TRUE ))
if( iter_current > private$iter_warmup_ ) {
iter_current <- iter_current - private$iter_warmup_
}

# Update progress bar if the iteration is a multiple of the
# refresh rate, or is the final sampling iteration.
if(((iter_current %% private$refresh_) == 0) |
iter_current == private$iter_warmup_ + private$iter_sampling_) {
progress_amount <- private$refresh_
}
}
private$progress_bar_(amount=progress_amount, message=line)
}
# Allow suppression of iteration messages
if (private$suppress_iteration_messages_) {
if(grepl("Iteration:", line, perl = TRUE)) {
ignore_line <- TRUE
}
}
if ((state > 1.5 && state < 5 && !ignore_line && private$show_stdout_messages_) || is_verbose_mode()) {
if (state == 2) {
message("Chain ", id, " ", line)
Expand Down
37 changes: 37 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -1072,3 +1072,40 @@ expose_stan_functions <- function(function_env, global = FALSE, verbose = FALSE)
}
invisible(NULL)
}

#' Register a default progress bar handler for sampling
#'
#' Create a default progress bar for CmdStan sampling operations, and register
#' it as the default global handler for progressr updates. Requires `progressr`
#' for the progress framework, and `cli` for the default progress bar handler.
#'
#' @export
#'
#' @param verbose (logical) Report creation of progress bar to stdout?
#' The default is `TRUE`.
#'
register_default_progress_handler <- function(verbose=TRUE) {
# Require both the progressr and cli packages.
if(requireNamespace("progressr", quietly = TRUE) && requireNamespace("cli", quietly = TRUE)) {

progressr::handlers(global=TRUE)
progressr::handlers("cli")

# Progress bar options
options(cli.spinner = "moon",
cli.progress_show_after = 0,
cli.progress_clear = FALSE )

# Default informative progress output for sampling
progressr::handlers(progressr::handler_cli(
format = "{cli::pb_spin} Progress: |{cli::pb_bar}| {cli::pb_current}/{cli::pb_total} | {cli::pb_percent} | ETA: {cli::pb_eta}",
clear = FALSE
))
if(verbose) {
message("Default progress bar registered.")
}
} else {
warning("The 'progressr' library is required to enable a progress bar. The default progress bar uses the 'cli' library.")
}
invisible(NULL)
}
12 changes: 12 additions & 0 deletions man/model-method-sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions man/register_default_progress_handler.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading