Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: ggRandomForests
Type: Package
Title: Visually Exploring Random Forests
Version: 2.4.0
Version: 2.4.1
Date: 2025-06-17
Authors@R: person("John", "Ehrlinger",
role = c("aut", "cre"),
Expand Down
4 changes: 3 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ export(kaplan)
export(nelson)
export(quantile_pts)
export(r_data_types)
export(varpro_feature_name)
export(surv_partial.rfsrc)
export(varpro_feature_names)
importFrom(dplyr,across)
importFrom(dplyr,mutate)
importFrom(dplyr,n_distinct)
Expand All @@ -49,6 +50,7 @@ importFrom(ggplot2,labs)
importFrom(ggplot2,theme)
importFrom(parallel,mclapply)
importFrom(randomForest,randomForest)
importFrom(randomForestSRC,partial.rfsrc)
importFrom(randomForestSRC,vimp)
importFrom(stats,median)
importFrom(stats,na.omit)
Expand Down
123 changes: 123 additions & 0 deletions R/surv_partial.rfsrc.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#' Calculate survival curve partial plot.
#'
#' @param rforest the randomForestSrc object
#' @param var_list a list of variables of interest. These variables should be a
#' subset of rforest$xvar.names
#' @param npts the number of points to segment the xvar of interest
#' @param partial.type the return prediction type.
#' For survival forests: type c("surv", "mort", "chf")
#' For competing risk forests: type c("years.lost", "cif", "chf")
#' see \code{randomForestSRC::partial.rfsrc} or more information
#'
#' @importFrom randomForestSRC partial.rfsrc
#' @examples
#' ## ------------------------------------------------------------
#' ## survival
#' ## ------------------------------------------------------------
#'
#' data(veteran, package = "randomForestSRC")
#' v.obj <- randomForestSRC::rfsrc(Surv(time,status)~.,
#' veteran, nsplit = 10, ntree = 100)
#'
#' spart <- surv_partial.rfsrc(v.obj, var_list="age", partial.type = "mort")
#'
#' ## partial effect of age on mortality
#' partial.obj <- partial(v.obj,
#' partial.type = "mort",
#' partial.xvar = "age",
#' partial.values = v.obj$xvar$age,
#' partial.time = v.obj$time.interest)
#' pdta <- get.partial.plot.data(partial.obj)
#'
#' plot(lowess(pdta$x, pdta$yhat, f = 1/3),
#' type = "l", xlab = "age", ylab = "adjusted mortality")
#'
#' ## example where x is discrete - partial effect of age on mortality
#' ## we use the granule=TRUE option
#' partial.obj <- partial(v.obj,
#' partial.type = "mort",
#' partial.xvar = "trt",
#' partial.values = v.obj$xvar$trt,
#' partial.time = v.obj$time.interest)
#' pdta <- get.partial.plot.data(partial.obj, granule = TRUE)
#' boxplot(pdta$yhat ~ pdta$x, xlab = "treatment", ylab = "partial effect")
#'
#'
#' ## partial effects of karnofsky score on survival
#' karno <- quantile(v.obj$xvar$karno)
#' partial.obj <- partial(v.obj,
#' partial.type = "surv",
#' partial.xvar = "karno",
#' partial.values = karno,
#' partial.time = v.obj$time.interest)
#' pdta <- get.partial.plot.data(partial.obj)
#'
#' matplot(pdta$partial.time, t(pdta$yhat), type = "l", lty = 1,
#' xlab = "time", ylab = "karnofsky adjusted survival")
#' legend("topright", legend = paste0("karnofsky = ", karno), fill = 1:5)
#'
#'
#' ## ------------------------------------------------------------
#' ## competing risk
#' ## ------------------------------------------------------------
#'
#' data(follic, package = "randomForestSRC")
#' follic.obj <- rfsrc(Surv(time, status) ~ ., follic, nsplit = 3, ntree = 100)
#'
#' ## partial effect of age on years lost
#' partial.obj <- partial(follic.obj,
#' partial.type = "years.lost",
#' partial.xvar = "age",
#' partial.values = follic.obj$xvar$age,
#' partial.time = follic.obj$time.interest)
#' pdta1 <- get.partial.plot.data(partial.obj, target = 1)
#' pdta2 <- get.partial.plot.data(partial.obj, target = 2)
#'
#' par(mfrow=c(2,2))
#' plot(lowess(pdta1$x, pdta1$yhat),
#' type = "l", xlab = "age", ylab = "adjusted years lost relapse")
#' plot(lowess(pdta2$x, pdta2$yhat),
#' type = "l", xlab = "age", ylab = "adjusted years lost death")
#'
#' ## partial effect of age on cif
#' partial.obj <- partial(follic.obj,
#' partial.type = "cif",
#' partial.xvar = "age",
#' partial.values = quantile(follic.obj$xvar$age),
#' partial.time = follic.obj$time.interest)
#' pdta1 <- get.partial.plot.data(partial.obj, target = 1)
#' pdta2 <- get.partial.plot.data(partial.obj, target = 2)
#'
#' matplot(pdta1$partial.time, t(pdta1$yhat), type = "l", lty = 1,
#' xlab = "time", ylab = "age adjusted cif for relapse")
#' matplot(pdta2$partial.time, t(pdta2$yhat), type = "l", lty = 1,
#' xlab = "time", ylab = "age adjusted cif for death")
#'
#' @export surv_partial.rfsrc
surv_partial.rfsrc <- function(rforest, var_list, npts=25, partial.type = "surv") {
###----------Partial dependency estimation, for each variable, at each time point ----
surv.lst <- lapply(var_list, function(xvar) {
## extract the key variable
cat("partial plot for:", xvar, "\n")

## determine the partial plot data
xv <- sort(unique(rforest$xvar[, xvar]))
xv <- unique(xv[seq(1, length(xv), length = npts)])

## Get the partial.plot.data
partial.dta <- randomForestSRC::get.partial.plot.data(
randomForestSRC::partial.rfsrc(
rforest,
partial.type = partial.type,
partial.xvar = xvar,
partial.values = xv,
partial.time = rforest$time.interest
)
)

list(name=xvar,
dta = partial.dta)

})
return(surv.lst)
}
2 changes: 1 addition & 1 deletion R/varpro_feature_names.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#'
#' @importFrom stringr str_sub
#' @export
varpro_feature_name <- function(varpro_names, dataset) {
varpro_feature_names <- function(varpro_names, dataset) {
inc_set <- varpro_names[which(varpro_names %in% colnames(dataset))]
one_set <- varpro_names[which(!varpro_names %in% colnames(dataset))]
while (length(one_set) > 0) {
Expand Down
108 changes: 108 additions & 0 deletions man/surv_partial.rfsrc.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/varpro_feature_name.Rd → man/varpro_feature_names.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading