diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index bf99ab4a..77be577b 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -153,6 +153,93 @@ impl MultivariateNormal { } } +/// Check that a covariance is square, perfectly symmetric, and non-NaN +fn check_cov(cov: &OMatrix) -> Result<(), MultivariateNormalError> +where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator, +{ + if !cov.is_square() + || cov.lower_triangle() != cov.upper_triangle().transpose() + || cov.iter().any(|f| f.is_nan()) + { + Err(MultivariateNormalError::CovInvalid) + } else { + Ok(()) + } +} + +/// Check the mean, covariance, and cholesky decomposition for incompatibilities, and return all three. +/// +/// Covariance and cholesky decomposition are computed from each other as necessary. +/// +/// # Panics +/// If both the `cov` and `cholesky` arguments are `None`; at least one must be `Some(_)`. +fn normalize_constructor_arguments( + mean: OVector, + covariance: Option>, + cholesky: Option>, +) -> Result<(OVector, OMatrix, Cholesky), MultivariateNormalError> +where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator, +{ + // Check that mean is valid + if mean.iter().any(|f| f.is_nan()) { + return Err(MultivariateNormalError::MeanInvalid); + } + let n = mean.shape_generic().0; + + let cov: OMatrix; + let chol: Cholesky; + + match (covariance, cholesky) { + (None, None) => { + panic!("Must pass either cov or cholesky to normalize_constructor_arguments") + } + (Some(c), None) => { + // Check covariance and compute Cholesky + check_cov(&c)?; + if c.shape_generic().0 != n { + return Err(MultivariateNormalError::DimensionMismatch); + } + chol = Cholesky::new(c.clone()).ok_or(MultivariateNormalError::CholeskyFailed)?; + cov = c; + } + (None, Some(ch)) => { + // Check cholesky and compute covariance + let ch_inner = ch.unpack_dirty(); + if ch_inner.shape_generic().0 != n { + return Err(MultivariateNormalError::DimensionMismatch); + } + chol = Cholesky::pack_dirty(ch_inner); + + let l = chol.l(); + cov = l.clone() * l.transpose(); + } + (Some(c), Some(ch)) => { + // Check both covariance and cholesky + check_cov(&c)?; + if c.shape_generic().0 != n { + return Err(MultivariateNormalError::DimensionMismatch); + } + cov = c; + + let ch_inner = ch.unpack_dirty(); + if ch_inner.shape_generic().0 != n { + return Err(MultivariateNormalError::DimensionMismatch); + } + chol = Cholesky::pack_dirty(ch_inner); + } + } + + Ok((mean, cov, chol)) +} + impl MultivariateNormal where D: DimMin, @@ -172,37 +259,46 @@ where mean: OVector, cov: OMatrix, ) -> Result { - if mean.iter().any(|f| f.is_nan()) { - return Err(MultivariateNormalError::MeanInvalid); - } - - if !cov.is_square() - || cov.lower_triangle() != cov.upper_triangle().transpose() - || cov.iter().any(|f| f.is_nan()) - { - return Err(MultivariateNormalError::CovInvalid); - } + let (mean, cov, cholesky) = normalize_constructor_arguments(mean, Some(cov), None)?; + Ok(Self::new_unchecked(mean, cov, cholesky)) + } - // Compare number of rows - if mean.shape_generic().0 != cov.shape_generic().0 { - return Err(MultivariateNormalError::DimensionMismatch); - } + /// Construct a new multivariate normal from a mean and an already-computed + /// Cholesky decomposition of the covariance matrix, using `nalgebra` types. + /// + /// Unlike [`MultivariateNormal::new_from_nalgebra`], this does not require + /// the covariance matrix to be perfectly symmetric, since [`Cholesky`] is + /// created with only the lower diagonal. + /// + /// # Errors + /// + /// Returns an error if the mean has any `NaN` values or the + /// mean and cholesky decomposition have a different number of rows. + pub fn new_from_cholesky( + mean: OVector, + cholesky: Cholesky, + ) -> Result { + let (mean, cov, cholesky) = normalize_constructor_arguments(mean, None, Some(cholesky))?; + Ok(Self::new_unchecked(mean, cov, cholesky)) + } - // Store the Cholesky decomposition of the covariance matrix - // for sampling - match Cholesky::new(cov.clone()) { - None => Err(MultivariateNormalError::CholeskyFailed), - Some(cholesky_decomp) => { - let precision = cholesky_decomp.inverse(); - Ok(MultivariateNormal { - // .unwrap() because prerequisites are already checked above - pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(), - cov_chol_decomp: cholesky_decomp.unpack(), - mu: mean, - cov, - precision, - }) - } + /// Construct a multivariate normal without checking the compatibility of the arguments. + /// It is on the caller to ensure that they have the same shape and meet their respective + /// invariants. + fn new_unchecked( + mean: OVector, + cov: OMatrix, + cholesky: Cholesky, + ) -> MultivariateNormal { + // Grab precision + let precision = cholesky.inverse(); + + MultivariateNormal { + pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(), + cov_chol_decomp: cholesky.unpack(), + mu: mean, + cov, + precision, } }