Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ channels:
dependencies:
- python <3.13
- cmake >=3.24
- cuda-version 12.*
- gxx >9
# - cuda-version 13.* # uncomment for a specific cuda version
- gxx >9,<15
- cuda-libraries-dev
- cuda-nvcc
- make
Expand Down
1 change: 1 addition & 0 deletions solvers/NBody/extra/NbodyRPY.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "hydrodynamicKernels.cuh"
#include "interface.h"
#include "vector.cuh"
#include <thrust/extrema.h>

namespace nbody_rpy {

Expand Down
106 changes: 56 additions & 50 deletions solvers/NBody/extra/vector.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using real = libmobility::real;
#ifdef DOUBLE_PRECISION
using real2 = double2;
using real3 = double3;
using real4 = double4;
using real4 = double4_16a;
#else
using real2 = float2;
using real3 = float3;
Expand Down Expand Up @@ -297,7 +297,7 @@ VECATTR real4 make_real4(real x, real y, real z, real w) {
#ifdef SINGLE_PRECISION
return make_float4(x, y, z, w);
#else
return make_double4(x, y, z, w);
return make_double4_16a(x, y, z, w);
#endif
}

Expand Down Expand Up @@ -344,7 +344,7 @@ VECATTR real3 make_real3(real3 a) { return make_real3(a.x, a.y, a.z); }

#ifdef SINGLE_PRECISION
VECATTR real3 make_real3(double3 a) { return make_real3(a.x, a.y, a.z); }
VECATTR real3 make_real3(double4 a) { return make_real3(a.x, a.y, a.z); }
VECATTR real3 make_real3(double4_16a a) { return make_real3(a.x, a.y, a.z); }
#else
template <typename T> VECATTR real3 make_real3(float2 a, T b) {
return make_real3(a.x, a.y, b);
Expand Down Expand Up @@ -402,115 +402,121 @@ VECATTR double3 make_double3(nbody_rpy::real4 a) {
return make_double3(a.x, a.y, a.z);
}
#endif
VECATTR float4 make_float4(double4 a) {
VECATTR float4 make_float4(double4_16a a) {
return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
}

VECATTR double4 make_double4(double s) { return make_double4(s, s, s, s); }
VECATTR double4 make_double4(double3 a) {
return make_double4(a.x, a.y, a.z, 0.0f);
VECATTR double4_16a make_double4_16a(double s) {
return make_double4_16a(s, s, s, s);
}
VECATTR double4 make_double4(double3 a, double w) {
return make_double4(a.x, a.y, a.z, w);
VECATTR double4_16a make_double4_16a(double3 a) {
return make_double4_16a(a.x, a.y, a.z, 0.0f);
}
VECATTR double4 make_double4(int4 a) {
return make_double4(double(a.x), double(a.y), double(a.z), double(a.w));
VECATTR double4_16a make_double4_16a(double3 a, double w) {
return make_double4_16a(a.x, a.y, a.z, w);
}
VECATTR double4 make_double4(uint4 a) {
return make_double4(double(a.x), double(a.y), double(a.z), double(a.w));
VECATTR double4_16a make_double4_16a(int4 a) {
return make_double4_16a(double(a.x), double(a.y), double(a.z), double(a.w));
}
VECATTR double4 make_double4(float4 a) {
return make_double4(double(a.x), double(a.y), double(a.z), double(a.w));
VECATTR double4_16a make_double4_16a(uint4 a) {
return make_double4_16a(double(a.x), double(a.y), double(a.z), double(a.w));
}
VECATTR double4_16a make_double4_16a(float4 a) {
return make_double4_16a(double(a.x), double(a.y), double(a.z), double(a.w));
}

//////DOUBLE4///////////////
VECATTR double4 operator+(const double4 &a, const double4 &b) {
return make_double4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
//////double4_16a///////////////
VECATTR double4_16a operator+(const double4_16a &a, const double4_16a &b) {
return make_double4_16a(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
VECATTR void operator+=(double4 &a, const double4 &b) {
VECATTR void operator+=(double4_16a &a, const double4_16a &b) {
a.x += b.x;
a.y += b.y;
a.z += b.z;
a.w += b.w;
}
VECATTR double4 operator+(const double4 &a, const double &b) {
return make_double4(a.x + b, a.y + b, a.z + b, a.w + b);
VECATTR double4_16a operator+(const double4_16a &a, const double &b) {
return make_double4_16a(a.x + b, a.y + b, a.z + b, a.w + b);
}
VECATTR double4_16a operator+(const double &b, const double4_16a &a) {
return a + b;
}
VECATTR double4 operator+(const double &b, const double4 &a) { return a + b; }
VECATTR void operator+=(double4 &a, const double &b) {
VECATTR void operator+=(double4_16a &a, const double &b) {
a.x += b;
a.y += b;
a.z += b;
a.w += b;
}

VECATTR double4 operator-(const double4 &a, const double4 &b) {
return make_double4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
VECATTR double4_16a operator-(const double4_16a &a, const double4_16a &b) {
return make_double4_16a(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
VECATTR void operator-=(double4 &a, const double4 &b) {
VECATTR void operator-=(double4_16a &a, const double4_16a &b) {
a.x -= b.x;
a.y -= b.y;
a.z -= b.z;
a.w -= b.w;
}
VECATTR double4 operator-(const double4 &a, const double &b) {
return make_double4(a.x - b, a.y - b, a.z - b, a.w - b);
VECATTR double4_16a operator-(const double4_16a &a, const double &b) {
return make_double4_16a(a.x - b, a.y - b, a.z - b, a.w - b);
}
VECATTR double4 operator-(const double &b, const double4 &a) {
return make_double4(b - a.x, b - a.y, b - a.z, b - a.w);
VECATTR double4_16a operator-(const double &b, const double4_16a &a) {
return make_double4_16a(b - a.x, b - a.y, b - a.z, b - a.w);
}
VECATTR void operator-=(double4 &a, const double &b) {
VECATTR void operator-=(double4_16a &a, const double &b) {
a.x -= b;
a.y -= b;
a.z -= b;
a.w -= b;
}
VECATTR double4 operator*(const double4 &a, const double4 &b) {
return make_double4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
VECATTR double4_16a operator*(const double4_16a &a, const double4_16a &b) {
return make_double4_16a(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}
VECATTR void operator*=(double4 &a, const double4 &b) {
VECATTR void operator*=(double4_16a &a, const double4_16a &b) {
a.x *= b.x;
a.y *= b.y;
a.z *= b.z;
a.w *= b.w;
}
VECATTR double4 operator*(const double4 &a, const double &b) {
return make_double4(a.x * b, a.y * b, a.z * b, a.w * b);
VECATTR double4_16a operator*(const double4_16a &a, const double &b) {
return make_double4_16a(a.x * b, a.y * b, a.z * b, a.w * b);
}
VECATTR double4_16a operator*(const double &b, const double4_16a &a) {
return a * b;
}
VECATTR double4 operator*(const double &b, const double4 &a) { return a * b; }
VECATTR void operator*=(double4 &a, const double &b) {
VECATTR void operator*=(double4_16a &a, const double &b) {
a.x *= b;
a.y *= b;
a.z *= b;
a.w *= b;
}
VECATTR double4 operator/(const double4 &a, const double4 &b) {
return make_double4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
VECATTR double4_16a operator/(const double4_16a &a, const double4_16a &b) {
return make_double4_16a(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
}
VECATTR void operator/=(double4 &a, const double4 &b) {
VECATTR void operator/=(double4_16a &a, const double4_16a &b) {
a.x /= b.x;
a.y /= b.y;
a.z /= b.z;
a.w /= b.w;
}
VECATTR double4 operator/(const double4 &a, const double &b) {
VECATTR double4_16a operator/(const double4_16a &a, const double &b) {
return (1.0 / b) * a;
}
VECATTR double4 operator/(const double &b, const double4 &a) {
return make_double4(b / a.x, b / a.y, b / a.z, b / a.w);
VECATTR double4_16a operator/(const double &b, const double4_16a &a) {
return make_double4_16a(b / a.x, b / a.y, b / a.z, b / a.w);
}
VECATTR void operator/=(double4 &a, const double &b) { a *= 1.0 / b; }
VECATTR void operator/=(double4_16a &a, const double &b) { a *= 1.0 / b; }

VECATTR double dot(double4 a, double4 b) {
VECATTR double dot(double4_16a a, double4_16a b) {
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}
VECATTR double length(double4 v) { return sqrt(dot(v, v)); }
VECATTR double4 normalize(double4 v) {
VECATTR double length(double4_16a v) { return sqrt(dot(v, v)); }
VECATTR double4_16a normalize(double4_16a v) {
double invLen = 1.0 / sqrt(dot(v, v));
return v * invLen;
}
VECATTR double4 floorf(double4 v) {
return make_double4(floor(v.x), floor(v.y), floor(v.z), floor(v.w));
VECATTR double4_16a floorf(double4_16a v) {
return make_double4_16a(floor(v.x), floor(v.y), floor(v.z), floor(v.w));
}

/////////////////////DOUBLE3///////////////////////////////
Expand Down
Loading