diff --git a/tdigest/include/tdigest.hpp b/tdigest/include/tdigest.hpp
index cc7898e3..2d3620b1 100644
--- a/tdigest/include/tdigest.hpp
+++ b/tdigest/include/tdigest.hpp
@@ -108,6 +108,7 @@ class tdigest {
/**
* Update this t-Digest with the given value
+ * NaN and infinite values are ignored.
* @param value to update the t-Digest with
*/
void update(T value);
@@ -153,6 +154,7 @@ class tdigest {
* Compute approximate normalized rank of the given value.
*
*
If the sketch is empty this throws std::runtime_error.
+ *
A NaN value throws std::invalid_argument.
*
* @param value to be ranked
* @return normalized rank (from 0 to 1 inclusive)
diff --git a/tdigest/include/tdigest_impl.hpp b/tdigest/include/tdigest_impl.hpp
index b8fab38d..065e3ef1 100644
--- a/tdigest/include/tdigest_impl.hpp
+++ b/tdigest/include/tdigest_impl.hpp
@@ -23,12 +23,35 @@
#include
#include
#include
+#include
#include "common_defs.hpp"
#include "memory_operations.hpp"
namespace datasketches {
+template
+inline void check_not_nan(T value, const char* name) { // validation helper: throws std::invalid_argument when value is NaN
+ if (std::isnan(value)) {
+ throw std::invalid_argument(std::string(name) + " must not be NaN"); // name labels the offending field in the message
+ }
+}
+
+template
+inline void check_not_infinite(T value, const char* name) { // validation helper: throws std::invalid_argument when value is +/- infinity
+ if (std::isinf(value)) {
+ throw std::invalid_argument(std::string(name) + " must not be infinite"); // name labels the offending field in the message
+ }
+}
+
+template
+inline void check_non_zero(T value, const char* name) { // validation helper: throws std::invalid_argument when value compares equal to 0
+ static_assert(std::is_arithmetic::value, "T must be an arithmetic type"); // compile-time guard on the value type
+ if (value == 0) {
+ throw std::invalid_argument(std::string(name) + " must not be zero"); // name labels the offending field in the message
+ }
+}
+
template
tdigest::tdigest(uint16_t k, const A& allocator):
tdigest(false, k, std::numeric_limits::infinity(), -std::numeric_limits::infinity(), vector_centroid(allocator), 0, vector_t(allocator))
@@ -37,6 +60,7 @@ tdigest(false, k, std::numeric_limits::infinity(), -std::numeric_limits::i
template
void tdigest::update(T value) {
if (std::isnan(value)) return;
+ if (std::isinf(value)) return;
if (buffer_.size() == centroids_capacity_ * BUFFER_MULTIPLIER) compress();
buffer_.push_back(value);
min_ = std::min(min_, value);
@@ -400,6 +424,8 @@ tdigest tdigest::deserialize(std::istream& is, const A& allocator) {
const bool reverse_merge = flags_byte & (1 << flags::REVERSE_MERGE);
if (is_single_value) {
const T value = read(is);
+ check_not_nan(value, "single_value");
+ check_not_infinite(value, "single_value");
return tdigest(reverse_merge, k, value, value, vector_centroid(1, centroid(value, 1), allocator), 1, vector_t(allocator));
}
@@ -408,12 +434,26 @@ tdigest tdigest::deserialize(std::istream& is, const A& allocator) {
const T min = read(is);
const T max = read(is);
+ check_not_nan(min, "min");
+ check_not_infinite(min, "min");
+ check_not_nan(max, "max");
+ check_not_infinite(max, "max");
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
if (num_centroids > 0) read(is, centroids.data(), num_centroids * sizeof(centroid));
vector_t buffer(num_buffered, 0, allocator);
if (num_buffered > 0) read(is, buffer.data(), num_buffered * sizeof(T));
uint64_t weight = 0;
- for (const auto& c: centroids) weight += c.get_weight();
+ for (const auto& c: centroids) {
+ check_not_nan(c.get_mean(), "centroid mean");
+ check_not_infinite(c.get_mean(), "centroid mean");
+ check_non_zero(c.get_weight(), "centroid weight");
+
+ weight += c.get_weight();
+ }
+ for (const auto& value: buffer) {
+ check_not_nan(value, "buffered_value");
+ check_not_infinite(value, "buffered_value");
+ }
return tdigest(reverse_merge, k, min, max, std::move(centroids), weight, std::move(buffer));
}
@@ -451,6 +491,8 @@ tdigest tdigest::deserialize(const void* bytes, size_t size, const A
ensure_minimum_memory(end_ptr - ptr, sizeof(T));
T value;
ptr += copy_from_mem(ptr, value);
+ check_not_nan(value, "single_value");
+ check_not_infinite(value, "single_value");
return tdigest(reverse_merge, k, value, value, vector_centroid(1, centroid(value, 1), allocator), 1, vector_t(allocator));
}
@@ -465,12 +507,26 @@ tdigest tdigest::deserialize(const void* bytes, size_t size, const A
ptr += copy_from_mem(ptr, min);
T max;
ptr += copy_from_mem(ptr, max);
+ check_not_nan(min, "min");
+ check_not_infinite(min, "min");
+ check_not_nan(max, "max");
+ check_not_infinite(max, "max");
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
if (num_centroids > 0) ptr += copy_from_mem(ptr, centroids.data(), num_centroids * sizeof(centroid));
vector_t buffer(num_buffered, 0, allocator);
if (num_buffered > 0) copy_from_mem(ptr, buffer.data(), num_buffered * sizeof(T));
uint64_t weight = 0;
- for (const auto& c: centroids) weight += c.get_weight();
+ for (const auto& c: centroids) {
+ check_not_nan(c.get_mean(), "centroid mean");
+ check_not_infinite(c.get_mean(), "centroid mean");
+ check_non_zero(c.get_weight(), "centroid weight");
+
+ weight += c.get_weight();
+ }
+ for (const auto& value: buffer) {
+ check_not_nan(value, "buffered_value");
+ check_not_infinite(value, "buffered_value");
+ }
return tdigest(reverse_merge, k, min, max, std::move(centroids), weight, std::move(buffer));
}
@@ -487,13 +543,24 @@ tdigest tdigest::deserialize_compat(std::istream& is, const A& alloc
if (type == COMPAT_DOUBLE) { // compatibility with asBytes()
const auto min = read_big_endian(is);
const auto max = read_big_endian(is);
+ check_not_nan(min, "min");
+ check_not_infinite(min, "min");
+ check_not_nan(max, "max");
+ check_not_infinite(max, "max");
const auto k = static_cast(read_big_endian(is));
const auto num_centroids = read_big_endian(is);
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
uint64_t total_weight = 0;
for (auto& c: centroids) {
- const W weight = static_cast(read_big_endian(is));
+ const auto weight_double = read_big_endian(is);
+ check_not_nan(weight_double, "centroid weight");
+ check_not_infinite(weight_double, "centroid weight");
+ check_non_zero(weight_double, "centroid weight");
+
const auto mean = read_big_endian(is);
+ check_not_nan(mean, "centroid mean");
+ check_not_infinite(mean, "centroid mean");
+ const W weight = static_cast(weight_double);
c = centroid(mean, weight);
total_weight += weight;
}
@@ -502,6 +569,10 @@ tdigest tdigest::deserialize_compat(std::istream& is, const A& alloc
// COMPAT_FLOAT: compatibility with asSmallBytes()
const auto min = read_big_endian(is); // reference implementation uses doubles for min and max
const auto max = read_big_endian(is);
+ check_not_nan(min, "min");
+ check_not_infinite(min, "min");
+ check_not_nan(max, "max");
+ check_not_infinite(max, "max");
const auto k = static_cast(read_big_endian(is));
// reference implementation stores capacities of the array of centroids and the buffer as shorts
// they can be derived from k in the constructor
@@ -510,8 +581,14 @@ tdigest tdigest::deserialize_compat(std::istream& is, const A& alloc
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
uint64_t total_weight = 0;
for (auto& c: centroids) {
- const W weight = static_cast(read_big_endian(is));
+ const auto weight_float = read_big_endian(is);
+ check_not_nan(weight_float, "centroid weight");
+ check_not_infinite(weight_float, "centroid weight");
+ check_non_zero(weight_float, "centroid weight"); // zero-weight centroids are rejected here too, as in the COMPAT_DOUBLE branch
const auto mean = read_big_endian(is);
+ check_not_nan(mean, "centroid mean");
+ check_not_infinite(mean, "centroid mean");
+ const W weight = static_cast(weight_float);
c = centroid(mean, weight);
total_weight += weight;
}
@@ -538,6 +614,10 @@ tdigest tdigest::deserialize_compat(const void* bytes, size_t size,
double max;
ptr += copy_from_mem(ptr, max);
max = byteswap(max);
+ check_not_nan(min, "min");
+ check_not_infinite(min, "min");
+ check_not_nan(max, "max");
+ check_not_infinite(max, "max");
double k_double;
ptr += copy_from_mem(ptr, k_double);
const uint16_t k = static_cast(byteswap(k_double));
@@ -554,6 +634,11 @@ tdigest tdigest::deserialize_compat(const void* bytes, size_t size,
double mean;
ptr += copy_from_mem(ptr, mean);
mean = byteswap(mean);
+ check_not_nan(weight, "centroid weight");
+ check_not_infinite(weight, "centroid weight");
+ check_non_zero(weight, "centroid weight"); // reject zero-weight centroids, matching deserialize() and the stream COMPAT_DOUBLE branch
+ check_not_nan(mean, "centroid mean");
+ check_not_infinite(mean, "centroid mean");
c = centroid(mean, static_cast(weight));
total_weight += static_cast(weight);
}
@@ -567,6 +651,10 @@ tdigest tdigest::deserialize_compat(const void* bytes, size_t size,
double max;
ptr += copy_from_mem(ptr, max);
max = byteswap(max);
+ check_not_nan(min, "min");
+ check_not_infinite(min, "min");
+ check_not_nan(max, "max");
+ check_not_infinite(max, "max");
float k_float;
ptr += copy_from_mem(ptr, k_float);
const uint16_t k = static_cast(byteswap(k_float));
@@ -586,6 +674,11 @@ tdigest tdigest::deserialize_compat(const void* bytes, size_t size,
float mean;
ptr += copy_from_mem(ptr, mean);
mean = byteswap(mean);
+ check_not_nan(weight, "centroid weight");
+ check_not_infinite(weight, "centroid weight");
+ check_non_zero(weight, "centroid weight"); // reject zero-weight centroids, matching deserialize() and the stream COMPAT_DOUBLE branch
+ check_not_nan(mean, "centroid mean");
+ check_not_infinite(mean, "centroid mean");
c = centroid(mean, static_cast(weight));
total_weight += static_cast(weight);
}
diff --git a/tdigest/test/tdigest_test.cpp b/tdigest/test/tdigest_test.cpp
index 9f92094d..07d6185f 100644
--- a/tdigest/test/tdigest_test.cpp
+++ b/tdigest/test/tdigest_test.cpp
@@ -18,13 +18,36 @@
*/
#include
+#include
#include
#include
+#include
#include "tdigest.hpp"
namespace datasketches {
+namespace { // file-local byte offsets and helpers the tests below use to corrupt serialized images
+constexpr size_t header_size = 8; // NOTE(review): presumably the preamble length of the serialized form — confirm against tdigest::serialize
+constexpr size_t counts_size = 8; // NOTE(review): presumably the count fields that follow the preamble — confirm
+constexpr size_t min_offset = header_size + counts_size; // min is taken to follow the header and the counts
+constexpr size_t max_offset = min_offset + sizeof(double); // max immediately follows min
+constexpr size_t first_centroid_mean_offset = min_offset + sizeof(double) * 2; // first centroid starts right after min and max
+constexpr size_t first_centroid_weight_offset = first_centroid_mean_offset + sizeof(double); // NOTE(review): assumes the weight follows a double mean within a centroid — confirm
+constexpr size_t first_buffered_value_offset = first_centroid_mean_offset; // NOTE(review): same position as the first centroid, so this assumes no centroids precede the buffer — confirm
+constexpr size_t single_value_offset = header_size; // in the single-value form the value is taken to follow the header directly
+
+template
+void write_bytes(std::vector& bytes, size_t offset, T value) { // overwrites sizeof(T) bytes at offset with value's object representation; no bounds check
+ std::memcpy(bytes.data() + offset, &value, sizeof(T));
+}
+
+template
+void write_bytes(std::string& data, size_t offset, T value) { // std::string overload of the same in-place overwrite; no bounds check
+ std::memcpy(&data[offset], &value, sizeof(T));
+}
+} // namespace
+
TEST_CASE("empty", "[tdigest]") {
tdigest_double td(10);
// std::cout << td.to_string();
@@ -470,4 +493,112 @@ TEST_CASE("iterate centroids", "[tdigest]") {
REQUIRE(td.get_total_weight() == total_weight);
}
+TEST_CASE("update rejects positive infinity", "[tdigest]") {
+ tdigest_double td(100);
+ td.update(1.0);
+ td.update(2.0);
+ td.update(std::numeric_limits::infinity()); // expected to be dropped without changing the sketch
+ REQUIRE(td.get_total_weight() == 2); // only the two finite updates are counted
+ REQUIRE(td.get_max_value() == 2.0); // max is not raised to +infinity
+}
+
+TEST_CASE("update rejects negative infinity", "[tdigest]") {
+ tdigest_double td(100);
+ td.update(1.0);
+ td.update(2.0);
+ td.update(-std::numeric_limits::infinity()); // expected to be dropped without changing the sketch
+ REQUIRE(td.get_total_weight() == 2); // only the two finite updates are counted
+ REQUIRE(td.get_min_value() == 1.0); // min is not lowered to -infinity
+}
+
+TEST_CASE("deserialize bytes rejects NaN single value", "[tdigest]") {
+ tdigest_double td(100);
+ td.update(1.0); // NOTE(review): a single update is expected to produce the single-value serialized form — confirm
+ auto bytes = td.serialize();
+ write_bytes(bytes, single_value_offset, std::numeric_limits::quiet_NaN()); // corrupt the stored value in place
+ REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
+}
+
+TEST_CASE("deserialize stream rejects infinity min", "[tdigest]") {
+ tdigest_double td(100);
+ td.update(1.0);
+ td.update(2.0);
+ td.update(3.0);
+ auto bytes = td.serialize();
+ std::string data(reinterpret_cast(bytes.data()), bytes.size()); // copy the image into a string so it can be read back through an istream
+ write_bytes(data, min_offset, std::numeric_limits::infinity()); // corrupt the stored min in place
+ std::istringstream is(data, std::ios::binary);
+ REQUIRE_THROWS_AS(tdigest_double::deserialize(is), std::invalid_argument);
+}
+
+TEST_CASE("deserialize bytes rejects NaN centroid mean", "[tdigest]") {
+ tdigest_double td(100);
+ for (int i = 0; i < 10; ++i) td.update(i);
+ auto bytes = td.serialize(); // NOTE(review): relies on serialize() emitting centroids rather than raw buffered values here — confirm
+ write_bytes(bytes, first_centroid_mean_offset, std::numeric_limits::quiet_NaN()); // corrupt the first centroid's mean in place
+ REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
+}
+
+TEST_CASE("deserialize bytes rejects NaN buffered value", "[tdigest]") {
+ tdigest_double td(100);
+ td.update(1.0);
+ td.update(2.0);
+ auto bytes = td.serialize(0, true); // NOTE(review): second argument presumably asks for buffered values to be serialized as-is — confirm
+ write_bytes(bytes, first_buffered_value_offset, std::numeric_limits::quiet_NaN()); // corrupt the first buffered value in place
+ REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
+}
+
+TEST_CASE("deserialize bytes rejects infinity single value", "[tdigest]") {
+ tdigest_double td(100);
+ td.update(1.0); // NOTE(review): a single update is expected to produce the single-value serialized form — confirm
+ auto bytes = td.serialize();
+ write_bytes(bytes, single_value_offset, std::numeric_limits::infinity()); // corrupt the stored value in place
+ REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
+}
+
+TEST_CASE("deserialize bytes rejects NaN max", "[tdigest]") {
+ tdigest_double td(100);
+ td.update(1.0);
+ td.update(2.0);
+ auto bytes = td.serialize();
+ write_bytes(bytes, max_offset, std::numeric_limits::quiet_NaN()); // corrupt the stored max in place
+ REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
+}
+
+TEST_CASE("deserialize bytes rejects infinity max", "[tdigest]") {
+ tdigest_double td(100);
+ td.update(1.0);
+ td.update(2.0);
+ auto bytes = td.serialize();
+ write_bytes(bytes, max_offset, std::numeric_limits::infinity()); // corrupt the stored max in place
+ REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
+}
+
+TEST_CASE("deserialize bytes rejects infinity buffered value", "[tdigest]") {
+ tdigest_double td(100);
+ td.update(1.0);
+ td.update(2.0);
+ auto bytes = td.serialize(0, true); // NOTE(review): second argument presumably asks for buffered values to be serialized as-is — confirm
+ write_bytes(bytes, first_buffered_value_offset, std::numeric_limits::infinity()); // corrupt the first buffered value in place
+ REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
+}
+
+TEST_CASE("deserialize bytes rejects zero centroid weight", "[tdigest]") {
+ tdigest_double td(100);
+ for (int i = 0; i < 10; ++i) td.update(i);
+ auto bytes = td.serialize(); // NOTE(review): relies on serialize() emitting centroids rather than raw buffered values here — confirm
+ write_bytes(bytes, first_centroid_weight_offset, static_cast(0)); // zero out the first centroid's weight in place
+ REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
+}
+
+TEST_CASE("deserialize stream rejects zero centroid weight", "[tdigest]") {
+ tdigest_double td(100);
+ for (int i = 0; i < 10; ++i) td.update(i);
+ auto bytes = td.serialize(); // NOTE(review): relies on serialize() emitting centroids rather than raw buffered values here — confirm
+ std::string data(reinterpret_cast(bytes.data()), bytes.size()); // copy the image into a string so it can be read back through an istream
+ write_bytes(data, first_centroid_weight_offset, static_cast(0)); // zero out the first centroid's weight in place
+ std::istringstream is(data, std::ios::binary);
+ REQUIRE_THROWS_AS(tdigest_double::deserialize(is), std::invalid_argument);
+}
+
} /* namespace datasketches */