Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
extern crate core;

pub mod linalg;


use crate::linalg::rarray::{Dim, Rarray};

pub mod linalg;
1 change: 1 addition & 0 deletions src/linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ pub mod rarray1d_create;
mod rarray2d_impl;
pub mod rarray2d_ops;
mod dimension;
mod numeric_trait;
6 changes: 3 additions & 3 deletions src/linalg/dimension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl PartialEq<Self> for D1 {
}

fn ne(&self, other: &Self) -> bool {
self.width == other.width || self.height == other.height
self.width != other.width || self.height != other.height
}
}

Expand Down Expand Up @@ -60,7 +60,7 @@ impl PartialEq<Self> for D2 {
}

fn ne(&self, other: &Self) -> bool {
self.width == other.width || self.height == other.height
self.width != other.width || self.height != other.height
}
}

Expand Down Expand Up @@ -93,7 +93,7 @@ impl PartialEq<Self> for D3 {
}

fn ne(&self, other: &Self) -> bool {
self.width == other.width || self.height == other.height || self.depth == other.depth
self.width != other.width || self.height != other.height || self.depth != other.depth
}
}

Expand Down
25 changes: 25 additions & 0 deletions src/linalg/numeric_trait.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};

/// Specifies that a generic type is a numerical type
pub trait Numeric:
Add<Output = Self>
+ AddAssign
+ Sub<Output = Self>
+ SubAssign
+ Div<Output = Self>
+ DivAssign
+ Mul<Output = Self>
+ MulAssign
+ Copy
+ Default
+ ToString
+ From<i32>
{}

impl<T> Numeric for T where
T:
Add<Output = T> + AddAssign +
Sub<Output = T> + SubAssign +
Div<Output = T> + DivAssign +
Mul<Output = T> + MulAssign +
ToString + Copy + Default + From<i32> {}
61 changes: 30 additions & 31 deletions src/linalg/rarray.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
use std::fmt::Debug;
use std::ops::{Add, Mul, MulAssign, Sub};
use std::ops::Neg;
use std::usize;
use rand::seq::IndexedRandom;

use crate::linalg::numeric_trait::Numeric;
pub(crate) use super::dimension::{Dim, D1, D2, D3};

// Base array struct
/// Base array struct
/// Consists of data field which is a 1 dimensional `Vec<T>` and of a shape field which is of type `Dim`
#[derive(Debug)]
pub struct Rarray<T, D> {
pub struct Rarray<T, D: Dim> {
pub(crate) data: Vec<T>,
pub(crate) shape: D
}

// Specific implementations
pub type Rarray1D = Rarray<f64, D1>;
pub type Rarray2D = Rarray<f64, D2>;
pub type Rarray3D = Rarray<f64, D3>;
pub type Rarray1D<T> = Rarray<T, D1>;
pub type Rarray2D<T> = Rarray<T, D2>;
pub type Rarray3D<T> = Rarray<T, D3>;

pub trait RarrayCreate<T, V, S> {
fn new(data: &V) -> Self;
fn zeros(shape: T) -> Self;
fn random(shape: T) -> Self;
fn fill(value: S, shape: T) -> Self;
}

Expand All @@ -40,9 +39,11 @@ pub trait RarraySub<T, V, S> {
fn sub(one: &T, other: &V) -> S;
}

impl RarrayMul<Rarray1D, Rarray2D, Rarray1D> for Rarray2D {
impl<T> RarrayMul<Rarray1D<T>, Rarray2D<T>, Rarray1D<T>> for Rarray2D<T> where
T: Numeric
{
/// Performs (1 x n) x (n x m) matrix multiplication
fn mul(one: &Rarray1D, other: &Rarray2D) -> Rarray1D {
fn mul(one: &Rarray1D<T>, other: &Rarray2D<T>) -> Rarray1D<T> {
let mut major: usize = 1;
if one.shape.width > one.shape.height {
assert_eq!(one.shape.width, other.shape.height, "Rarray shape mismatch");
Expand All @@ -54,24 +55,24 @@ impl RarrayMul<Rarray1D, Rarray2D, Rarray1D> for Rarray2D {

let mut result = Rarray1D {
shape: D1 { width: one.shape.width, height: 1 },
data: vec![0.; major]
data: vec![T::default(); major]
};

for i in 0..one.shape.width {
let mut sum: f64 = 0.;
for j in 0..major {
sum += one[j] * other[[i, j]];
result[i] += one[j] * other[[i, j]];
}
result[i] = sum;
}

result
}
}

impl RarrayMul<Rarray2D, Rarray1D, Rarray1D> for Rarray2D {
impl<T> RarrayMul<Rarray2D<T>, Rarray1D<T>, Rarray1D<T>> for Rarray2D<T> where
T: Numeric
{
/// Performs (n x m) x (m x 1) matrix multiplication
fn mul(one: &Rarray2D, other: &Rarray1D) -> Rarray1D {
fn mul(one: &Rarray2D<T>, other: &Rarray1D<T>) -> Rarray1D<T> {
let mut major: usize = 1;
if one.shape.width > one.shape.height {
assert_eq!(one.shape.width, other.shape.height, "Rarray shape mismatch");
Expand All @@ -83,38 +84,36 @@ impl RarrayMul<Rarray2D, Rarray1D, Rarray1D> for Rarray2D {

let mut result = Rarray1D {
shape: D1 { width: one.shape.height, height: 1 },
data: vec![0.; major]
data: vec![T::default(); major]
};

for i in 0..one.shape.height {
let mut sum: f64 = 0.;
for j in 0..major {
sum += one[[i, j]] * other[j];
result[i] += one[[i, j]] * other[j];
}
result[i] = sum;
}

result
}
}

impl RarrayMul<Rarray2D, Rarray2D, Rarray2D> for Rarray2D {
impl<T> RarrayMul<Rarray2D<T>, Rarray2D<T>, Rarray2D<T>> for Rarray2D<T> where
T: Numeric
{
/// Performs (n x m) x (m x l) matrix multiplication
fn mul(one: &Rarray2D, other: &Rarray2D) -> Rarray2D {
fn mul(one: &Rarray2D<T>, other: &Rarray2D<T>) -> Rarray2D<T> {
assert_eq!(one.shape.height, other.shape.width, "Rarray shape mismatch");

let mut result = Rarray2D {
shape: D2 { height: one.shape.height, width: other.shape.width },
data: vec![0.; one.shape.height * other.shape.width]
data: vec![T::default(); one.shape.height * other.shape.width]
};

for i in 0..one.shape.height {
for j in 0..other.shape.width {
let mut sum: f64 = 0.;
for k in 0..one.shape.width {
sum += one.data[i * one.shape.width + k] * other.data[j * other.shape.height + k];
result.data[i * result.shape.width + j * result.shape.height] += one.data[i * one.shape.width + k] * other.data[j * other.shape.height + k];
}
result.data[i * result.shape.width + j * result.shape.height] = sum;
}
}

Expand All @@ -123,7 +122,7 @@ impl RarrayMul<Rarray2D, Rarray2D, Rarray2D> for Rarray2D {
}

impl<T, D> RarrayScalMul<T, Rarray<T, D>> for Rarray<T, D> where
T: Copy + MulAssign,
T: Numeric,
D : Copy + Dim + Debug,
{
fn scal_mul(scal: T, rarray: &Rarray<T, D>) -> Rarray<T, D> {
Expand All @@ -141,7 +140,7 @@ impl<T, D> RarrayScalMul<T, Rarray<T, D>> for Rarray<T, D> where
}

impl<T, D> RarrayAdd<Rarray<T, D>, Rarray<T, D>, Rarray<T, D>> for Rarray<T, D> where
T : Copy + Add<Output = T> + Default,
T : Numeric,
D : Copy + Dim + Debug + Eq
{
fn add(one: &Rarray<T, D>, other: &Rarray<T, D>) -> Rarray<T, D> {
Expand All @@ -161,7 +160,7 @@ impl<T, D> RarrayAdd<Rarray<T, D>, Rarray<T, D>, Rarray<T, D>> for Rarray<T, D>
}

impl<T, D> RarraySub<Rarray<T, D>, Rarray<T, D>, Rarray<T, D>> for Rarray<T, D> where
T : Copy + Sub<Output = T> + Default,
T : Numeric,
D : Copy + Dim + Debug + Eq
{
fn sub(one: &Rarray<T, D>, other: &Rarray<T, D>) -> Rarray<T, D> {
Expand All @@ -178,4 +177,4 @@ impl<T, D> RarraySub<Rarray<T, D>, Rarray<T, D>, Rarray<T, D>> for Rarray<T, D>

result
}
}
}
Loading