diff --git a/src/linalg/rarray.rs b/src/linalg/rarray.rs index faa4555..3f078fb 100644 --- a/src/linalg/rarray.rs +++ b/src/linalg/rarray.rs @@ -67,7 +67,7 @@ impl Dimension for D3 {} // Base array struct #[derive(Debug)] pub struct Rarray { - pub(crate) data: Vec, + pub(crate) data: Tec, pub(crate) shape: D } @@ -76,3 +76,93 @@ pub type Rarray1D = Rarray; pub type Rarray2D = Rarray; pub type Rarray3D = Rarray; +pub trait RarrayCreate { + fn new(data: T) -> V; + fn zeros(shape: T) -> V; + fn ones(shape: T) -> V; + fn random(shape: T) -> V; +} + +pub trait RarrayMul { + fn mul(one: T, other: V) -> S; +} + +pub trait RarrayAdd { + fn add(one: T, other: V) -> S; +} + +pub trait RarraySub { + fn sub(one: T, other: V) -> S; +} + +impl RarrayMul for Rarray2D { + fn mul(one: Rarray1D, other: Rarray2D) -> Rarray1D { + let mut major = 1; + if one.shape.width > one.shape.height { + assert_eq!(one.shape.width, other.shape.height, "Rarray shape mismatch"); + major = one.shape.width; + } else { + assert_eq!(one.shape.height, other.shape.height, "Rarray shape mismatch"); + major = one.shape.height; + } + + let result = Rarray1D { + shape: D1 { width: major, height: 1 }, + data: vec![0.; major] + }; + + result + } +} + +impl RarrayMul for Rarray2D { + fn mul(one: Rarray2D, other: Rarray1D) -> Rarray1D { + other + } +} + +impl RarrayMul for Rarray2D { + fn mul(one: Rarray2D, other: Rarray2D) -> Rarray2D { + one + } +} + +impl RarrayAdd for Rarray2D { + fn add(one: Rarray2D, other: Rarray2D) -> Rarray2D { + assert_eq!(one.shape.width, other.shape.width, "Rarray shape mismatch"); + assert_eq!(one.shape.height, other.shape.height, "Rarray shape mismatch"); + + let mut result = Rarray2D { + shape: one.shape, + data: vec![0; one.shape.width * one.shape.height] + }; + + for i in 0..one.shape.width { + for j in 0..one.shape.height { + result[[i, j]] = one[[i, j]] + other[[i, j]]; + } + } + + result + } +} + +impl RarraySub for Rarray2D { + fn sub(one: Rarray2D, other: Rarray2D) -> Rarray2D { + assert_eq!(one.shape.width, other.shape.width, "Rarray shape mismatch"); + assert_eq!(one.shape.height, other.shape.height, "Rarray shape mismatch"); + + let mut result = Rarray2D { + shape: one.shape, + data: vec![0; one.shape.width * one.shape.height] + }; + + for i in 0..one.shape.width { + for j in 0..one.shape.height { + result[[i, j]] = one[[i, j]] - other[[i, j]]; + } + } + + result + } +} diff --git a/src/linalg/ndrarray/ndrarray.rs b/tests/rarray2d_mul_tests similarity index 100% rename from src/linalg/ndrarray/ndrarray.rs rename to tests/rarray2d_mul_tests diff --git a/tests/rarray2d_mul_tests.rs b/tests/rarray2d_mul_tests.rs new file mode 100644 index 0000000..7f247b6 --- /dev/null +++ b/tests/rarray2d_mul_tests.rs @@ -0,0 +1,21 @@ +mod test { + use rumpy::linalg::rarray::{MatOper, Rarray1D, Rarray2D}; + use rstest::rstest; + + #[test] + #[should_panic] + fn rarray2d_mul(){ + let x = Rarray1D::new(&vec![1., 1.]); + let y = Rarray2D::new(&vec![vec![1., 1.], vec![1., 1.]]); + + let z = Rarray2D::mul(x, y); + + let x = Rarray1D::new(&vec![1., 1.]); + let y = Rarray2D::new(&vec![vec![1., 1.], vec![1., 1.]]); + let y1 = Rarray2D::new(&vec![vec![1., 1.], vec![1., 1.]]); + + let s = Rarray2D::mul(y, y1); + println!("{:?}", z); + println!("{:?}", s); + } +}