diff --git a/DESCRIPTION b/DESCRIPTION index 5f5f607..7dcb6b5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Type: Package Package: sparsegl Title: Sparse Group Lasso -Version: 1.1.1.9001 +Version: 1.1.1.9002 Authors@R: c( person("Daniel J.", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre", "cph")), person("Xiaoxuan", "Liang", , "xiaoxuan.liang@stat.ubc.ca", role = "aut"), diff --git a/NEWS.md b/NEWS.md index a57c5d3..af376a9 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,7 @@ * Force `weights` to sum to `nobs` for all IRWLS cases. * Remove `magrittr` from imports * Add `auc` option for CV and binomial +* Address #59 (@kaichen) # sparsegl 1.1.1 diff --git a/R/sparsegl.R b/R/sparsegl.R index f60d3dc..6327b44 100644 --- a/R/sparsegl.R +++ b/R/sparsegl.R @@ -140,31 +140,50 @@ #' yp <- rpois(n, abs(X %*% beta_star)) #' fit_pois <- sparsegl(X, yp, group = groups, family = poisson()) sparsegl <- function( - x, y, group = NULL, family = c("gaussian", "binomial"), - nlambda = 100, lambda.factor = ifelse(nobs < nvars, 0.01, 1e-04), - lambda = NULL, pf_group = sqrt(bs), pf_sparse = rep(1, nvars), - intercept = TRUE, asparse = 0.05, standardize = TRUE, - lower_bnd = -Inf, upper_bnd = Inf, - weights = NULL, offset = NULL, warm = NULL, - trace_it = 0, - dfmax = as.integer(max(group)) + 1L, - pmax = min(dfmax * 1.2, as.integer(max(group))), - eps = 1e-08, maxit = 3e+06) { + x, + y, + group = NULL, + family = c("gaussian", "binomial"), + nlambda = 100, + lambda.factor = ifelse(nobs < nvars, 0.01, 1e-04), + lambda = NULL, + pf_group = sqrt(bs), + pf_sparse = rep(1, nvars), + intercept = TRUE, + asparse = 0.05, + standardize = TRUE, + lower_bnd = -Inf, + upper_bnd = Inf, + weights = NULL, + offset = NULL, + warm = NULL, + trace_it = 0, + dfmax = as.integer(max(group)) + 1L, + pmax = min(dfmax * 1.2, as.integer(max(group))), + eps = 1e-08, + maxit = 3e+06 +) { this.call <- match.call() if (!is.matrix(x) && !inherits(x, "sparseMatrix")) { cli_abort("`x` must be a matrix.") } - if (any(is.na(x))) cli_abort("Missing values in `x` are not supported.") + if (any(is.na(x))) { + cli_abort("Missing values in `x` are not supported.") + } y <- drop(y) - if (!is.null(dim(y))) cli_abort("`y` must be a vector or 1-column matrix.") + if (!is.null(dim(y))) { + cli_abort("`y` must be a vector or 1-column matrix.") + } np <- dim(x) nobs <- as.integer(np[1]) nvars <- as.integer(np[2]) vnames <- colnames(x) - if (is.null(vnames)) vnames <- paste("V", seq(nvars), sep = "") + if (is.null(vnames)) { + vnames <- paste("V", seq(nvars), sep = "") + } if (length(y) != nobs) { cli_abort("`x` has {nobs} rows while `y` has {length(y)}.") @@ -185,7 +204,10 @@ sparsegl <- function( bn <- as.integer(max(group)) # number of groups bs <- as.integer(as.numeric(table(group))) # number of elements in each group - if (!identical(as.integer(sort(unique(group))), as.integer(1:bn))) { + if (is.unsorted(group)) { + cli_abort("`group` must be sorted in increasing order.") + } + if (!identical(as.integer(unique(group)), 1:bn)) { cli_abort("Groups must be consecutively numbered 1, 2, 3, ...") } @@ -198,15 +220,21 @@ sparsegl <- function( if (asparse < 0) { asparse <- 0 - cli_warn("`asparse` must be in {.val [0, 1]}, running ordinary group lasso.") + cli_warn( + "`asparse` must be in {.val [0, 1]}, running ordinary group lasso." + ) + } + if (any(pf_sparse < 0)) { + cli::cli_abort("`pf_sparse` must be non-negative.") } - if (any(pf_sparse < 0)) cli::cli_abort("`pf_sparse` must be non-negative.") if (any(is.infinite(pf_sparse))) { cli_abort( "`pf_sparse` may not be infinite. Simply remove the column from `x`." ) } - if (any(pf_group < 0)) cli_abort("`pf_group` must be non-negative.") + if (any(pf_group < 0)) { + cli_abort("`pf_group` must be non-negative.") + } if (any(is.infinite(pf_group))) { cli_abort(c( "`pf_group` must be finite.", @@ -264,15 +292,21 @@ sparsegl <- function( } else { # flmin = 1 if user define lambda flmin <- as.double(1) - if (any(lambda < 0)) cli_abort("`lambda` must be non-negative.") + if (any(lambda < 0)) { + cli_abort("`lambda` must be non-negative.") + } ulam <- as.double(rev(sort(lambda))) nlam <- as.integer(length(lambda)) } intr <- as.integer(intercept) ### check on upper/lower bounds - if (any(lower_bnd > 0)) cli_abort("`lower_bnd` must be non-positive.") - if (any(upper_bnd < 0)) cli_abort("`upper_bnd` must be non-negative.") + if (any(lower_bnd > 0)) { + cli_abort("`lower_bnd` must be non-positive.") + } + if (any(upper_bnd < 0)) { + cli_abort("`upper_bnd` must be non-negative.") + } lower_bnd[lower_bnd == -Inf] <- -9.9e30 upper_bnd[upper_bnd == Inf] <- 9.9e30 if (length(lower_bnd) < bn) { @@ -312,30 +346,100 @@ sparsegl <- function( i = "Estimating sparse group lasso without any offset. See {.fn sparsegl::sparsegl}." )) } - fit <- switch(family, + fit <- switch( + family, gaussian = sgl_ls( - bn, bs, ix, iy, nobs, nvars, x, y, pf_group, pf_sparse, - dfmax, pmax, nlam, flmin, ulam, eps, maxit, vnames, group, intr, - as.double(asparse), standardize, lower_bnd, upper_bnd + bn, + bs, + ix, + iy, + nobs, + nvars, + x, + y, + pf_group, + pf_sparse, + dfmax, + pmax, + nlam, + flmin, + ulam, + eps, + maxit, + vnames, + group, + intr, + as.double(asparse), + standardize, + lower_bnd, + upper_bnd ), binomial = sgl_logit( - bn, bs, ix, iy, nobs, nvars, x, y, pf_group, pf_sparse, - dfmax, pmax, nlam, flmin, ulam, eps, maxit, vnames, group, intr, - as.double(asparse), standardize, lower_bnd, upper_bnd + bn, + bs, + ix, + iy, + nobs, + nvars, + x, + y, + pf_group, + pf_sparse, + dfmax, + pmax, + nlam, + flmin, + ulam, + eps, + maxit, + vnames, + group, + intr, + as.double(asparse), + standardize, + lower_bnd, + upper_bnd ) ) } if (fam$check == "fam") { fit <- sgl_irwls( - bn, bs, ix, iy, nobs, nvars, x, y, pf_group, pf_sparse, - dfmax, pmax, nlam, flmin, ulam, eps, maxit, vnames, group, intr, - as.double(asparse), standardize, lower_bnd, upper_bnd, weights, - offset, fam$family, trace_it, warm + bn, + bs, + ix, + iy, + nobs, + nvars, + x, + y, + pf_group, + pf_sparse, + dfmax, + pmax, + nlam, + flmin, + ulam, + eps, + maxit, + vnames, + group, + intr, + as.double(asparse), + standardize, + lower_bnd, + upper_bnd, + weights, + offset, + fam$family, + trace_it, + warm ) } # output - if (is.null(lambda)) fit$lambda <- lamfix(fit$lambda) + if (is.null(lambda)) { + fit$lambda <- lamfix(fit$lambda) + } fit$call <- this.call fit$asparse <- asparse fit$nobs <- nobs diff --git a/tests/testthat/_snaps/assertions.md b/tests/testthat/_snaps/assertions.md new file mode 100644 index 0000000..6a6a86b --- /dev/null +++ b/tests/testthat/_snaps/assertions.md @@ -0,0 +1,8 @@ +# `group` vctr is sorted + + Code + sparsegl(X, y, group = group1, asparse = 0, pf_group = pf1) + Condition + Error in `sparsegl()`: + ! `group` must be sorted in increasing order. + diff --git a/tests/testthat/test-assertions.R b/tests/testthat/test-assertions.R new file mode 100644 index 0000000..6120fd1 --- /dev/null +++ b/tests/testthat/test-assertions.R @@ -0,0 +1,11 @@ +test_that("`group` vctr is sorted", { + X <- matrix(rnorm(20 * 5), nrow = 20) + beta <- c(1, 1, 0, 0, 1) + y <- X %*% beta + rnorm(20) + group1 <- c(3, 3, 3, 1, 2) + pf1 <- c(0, 2, 0) + expect_snapshot( + error = TRUE, + sparsegl(X, y, group = group1, asparse = 0, pf_group = pf1) + ) +})