Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 125 additions & 29 deletions src/distribution/multivariate_normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,93 @@ impl MultivariateNormal<Dyn> {
}
}

/// Check that a covariance is square, perfectly symmetric, and non-NaN
fn check_cov<D>(cov: &OMatrix<f64, D, D>) -> Result<(), MultivariateNormalError>
where
D: DimMin<D, Output = D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>
+ nalgebra::allocator::Allocator<D, D>
+ nalgebra::allocator::Allocator<D>,
{
if !cov.is_square()
|| cov.lower_triangle() != cov.upper_triangle().transpose()
|| cov.iter().any(|f| f.is_nan())
{
Err(MultivariateNormalError::CovInvalid)
} else {
Ok(())
}
}

/// Check the mean, covariance, and cholesky decomposition for incompatibilities, and return all three.
///
/// Covariance and cholesky decomposition are computed from each other as necessary.
///
/// # Panics
/// If both the `cov` and `cholesky` arguments are `None`; at least one must be `Some(_)`.
fn normalize_constructor_arguments<D>(
mean: OVector<f64, D>,
covariance: Option<OMatrix<f64, D, D>>,
cholesky: Option<Cholesky<f64, D>>,
) -> Result<(OVector<f64, D>, OMatrix<f64, D, D>, Cholesky<f64, D>), MultivariateNormalError>
where
D: DimMin<D, Output = D>,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>
+ nalgebra::allocator::Allocator<D, D>
+ nalgebra::allocator::Allocator<D>,
{
// Check that mean is valid
if mean.iter().any(|f| f.is_nan()) {
return Err(MultivariateNormalError::MeanInvalid);
}
let n = mean.shape_generic().0;

let cov: OMatrix<f64, D, D>;
let chol: Cholesky<f64, D>;

match (covariance, cholesky) {
(None, None) => {
panic!("Must pass either cov or cholesky to normalize_constructor_arguments")
}
(Some(c), None) => {
// Check covariance and compute Cholesky
check_cov(&c)?;
if c.shape_generic().0 != n {
return Err(MultivariateNormalError::DimensionMismatch);
}
chol = Cholesky::new(c.clone()).ok_or(MultivariateNormalError::CholeskyFailed)?;
cov = c;
}
(None, Some(ch)) => {
// Check cholesky and compute covariance
let ch_inner = ch.unpack_dirty();
if ch_inner.shape_generic().0 != n {
return Err(MultivariateNormalError::DimensionMismatch);
}
chol = Cholesky::pack_dirty(ch_inner);

let l = chol.l();
cov = l.clone() * l.transpose();
}
(Some(c), Some(ch)) => {
// Check both covariance and cholesky
check_cov(&c)?;
if c.shape_generic().0 != n {
return Err(MultivariateNormalError::DimensionMismatch);
}
cov = c;

let ch_inner = ch.unpack_dirty();
if ch_inner.shape_generic().0 != n {
return Err(MultivariateNormalError::DimensionMismatch);
}
chol = Cholesky::pack_dirty(ch_inner);
}
}

Ok((mean, cov, chol))
}

impl<D> MultivariateNormal<D>
where
D: DimMin<D, Output = D>,
Expand All @@ -172,37 +259,46 @@ where
mean: OVector<f64, D>,
cov: OMatrix<f64, D, D>,
) -> Result<Self, MultivariateNormalError> {
if mean.iter().any(|f| f.is_nan()) {
return Err(MultivariateNormalError::MeanInvalid);
}

if !cov.is_square()
|| cov.lower_triangle() != cov.upper_triangle().transpose()
|| cov.iter().any(|f| f.is_nan())
{
return Err(MultivariateNormalError::CovInvalid);
}
let (mean, cov, cholesky) = normalize_constructor_arguments(mean, Some(cov), None)?;
Ok(Self::new_unchecked(mean, cov, cholesky))
}

// Compare number of rows
if mean.shape_generic().0 != cov.shape_generic().0 {
return Err(MultivariateNormalError::DimensionMismatch);
}
/// Construct a new multivariate normal from a mean and an already-computed
/// Cholesky decomposition of the covariance matrix, using `nalgebra` types.
///
/// Unlike [`MultivariateNormal::new_from_nalgebra`], this does not require
/// the covariance matrix to be perfectly symmetric, since [`Cholesky`] is
/// created with only the lower diagonal.
///
/// # Errors
///
/// Returns an error if the mean has any `NaN` values or the
/// mean and cholesky decomposition have a different number of rows.
pub fn new_from_cholesky(
mean: OVector<f64, D>,
cholesky: Cholesky<f64, D>,
) -> Result<Self, MultivariateNormalError> {
let (mean, cov, cholesky) = normalize_constructor_arguments(mean, None, Some(cholesky))?;
Ok(Self::new_unchecked(mean, cov, cholesky))
}

// Store the Cholesky decomposition of the covariance matrix
// for sampling
match Cholesky::new(cov.clone()) {
None => Err(MultivariateNormalError::CholeskyFailed),
Some(cholesky_decomp) => {
let precision = cholesky_decomp.inverse();
Ok(MultivariateNormal {
// .unwrap() because prerequisites are already checked above
pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(),
cov_chol_decomp: cholesky_decomp.unpack(),
mu: mean,
cov,
precision,
})
}
/// Construct a multivariate normal without checking the compatibility of the arguments.
/// It is on the caller to ensure that they have the same shape and meet their respective
/// invariants.
fn new_unchecked(
mean: OVector<f64, D>,
cov: OMatrix<f64, D, D>,
cholesky: Cholesky<f64, D>,
) -> MultivariateNormal<D> {
// Grab precision
let precision = cholesky.inverse();

MultivariateNormal {
pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(),
cov_chol_decomp: cholesky.unpack(),
mu: mean,
cov,
precision,
}
}

Expand Down