diff --git a/tdigest/include/tdigest.hpp b/tdigest/include/tdigest.hpp index cc7898e3..2d3620b1 100644 --- a/tdigest/include/tdigest.hpp +++ b/tdigest/include/tdigest.hpp @@ -108,6 +108,7 @@ class tdigest { /** * Update this t-Digest with the given value + * NaN and infinity values are ignored * @param value to update the t-Digest with */ void update(T value); @@ -153,6 +154,7 @@ class tdigest { * Compute approximate normalized rank of the given value. * *

If the sketch is empty this throws std::runtime_error. + *

NaN value throw std::invalid_argument. * * @param value to be ranked * @return normalized rank (from 0 to 1 inclusive) diff --git a/tdigest/include/tdigest_impl.hpp b/tdigest/include/tdigest_impl.hpp index b8fab38d..065e3ef1 100644 --- a/tdigest/include/tdigest_impl.hpp +++ b/tdigest/include/tdigest_impl.hpp @@ -23,12 +23,35 @@ #include #include #include +#include #include "common_defs.hpp" #include "memory_operations.hpp" namespace datasketches { +template +inline void check_not_nan(T value, const char* name) { + if (std::isnan(value)) { + throw std::invalid_argument(std::string(name) + " must not be NaN"); + } +} + +template +inline void check_not_infinite(T value, const char* name) { + if (std::isinf(value)) { + throw std::invalid_argument(std::string(name) + " must not be infinite"); + } +} + +template +inline void check_non_zero(T value, const char* name) { + static_assert(std::is_arithmetic::value, "T must be an arithmetic type"); + if (value == 0) { + throw std::invalid_argument(std::string(name) + " must not be zero"); + } +} + template tdigest::tdigest(uint16_t k, const A& allocator): tdigest(false, k, std::numeric_limits::infinity(), -std::numeric_limits::infinity(), vector_centroid(allocator), 0, vector_t(allocator)) @@ -37,6 +60,7 @@ tdigest(false, k, std::numeric_limits::infinity(), -std::numeric_limits::i template void tdigest::update(T value) { if (std::isnan(value)) return; + if (std::isinf(value)) return; if (buffer_.size() == centroids_capacity_ * BUFFER_MULTIPLIER) compress(); buffer_.push_back(value); min_ = std::min(min_, value); @@ -400,6 +424,8 @@ tdigest tdigest::deserialize(std::istream& is, const A& allocator) { const bool reverse_merge = flags_byte & (1 << flags::REVERSE_MERGE); if (is_single_value) { const T value = read(is); + check_not_nan(value, "single_value"); + check_not_infinite(value, "single_value"); return tdigest(reverse_merge, k, value, value, vector_centroid(1, centroid(value, 1), allocator), 1, vector_t(allocator)); } @@ -408,12 +434,26 @@ tdigest tdigest::deserialize(std::istream& is, const A& allocator) { const T min = read(is); const T max = read(is); + check_not_nan(min, "min"); + check_not_infinite(min, "min"); + check_not_nan(max, "max"); + check_not_infinite(max, "max"); vector_centroid centroids(num_centroids, centroid(0, 0), allocator); if (num_centroids > 0) read(is, centroids.data(), num_centroids * sizeof(centroid)); vector_t buffer(num_buffered, 0, allocator); if (num_buffered > 0) read(is, buffer.data(), num_buffered * sizeof(T)); uint64_t weight = 0; - for (const auto& c: centroids) weight += c.get_weight(); + for (const auto& c: centroids) { + check_not_nan(c.get_mean(), "centroid mean"); + check_not_infinite(c.get_mean(), "centroid mean"); + check_non_zero(c.get_weight(), "centroid weight"); + + weight += c.get_weight(); + } + for (const auto& value: buffer) { + check_not_nan(value, "buffered_value"); + check_not_infinite(value, "buffered_value"); + } return tdigest(reverse_merge, k, min, max, std::move(centroids), weight, std::move(buffer)); } @@ -451,6 +491,8 @@ tdigest tdigest::deserialize(const void* bytes, size_t size, const A ensure_minimum_memory(end_ptr - ptr, sizeof(T)); T value; ptr += copy_from_mem(ptr, value); + check_not_nan(value, "single_value"); + check_not_infinite(value, "single_value"); return tdigest(reverse_merge, k, value, value, vector_centroid(1, centroid(value, 1), allocator), 1, vector_t(allocator)); } @@ -465,12 +507,26 @@ tdigest tdigest::deserialize(const void* bytes, size_t size, const A ptr += copy_from_mem(ptr, min); T max; ptr += copy_from_mem(ptr, max); + check_not_nan(min, "min"); + check_not_infinite(min, "min"); + check_not_nan(max, "max"); + check_not_infinite(max, "max"); vector_centroid centroids(num_centroids, centroid(0, 0), allocator); if (num_centroids > 0) ptr += copy_from_mem(ptr, centroids.data(), num_centroids * sizeof(centroid)); vector_t buffer(num_buffered, 0, allocator); if (num_buffered > 0) copy_from_mem(ptr, buffer.data(), num_buffered * sizeof(T)); uint64_t weight = 0; - for (const auto& c: centroids) weight += c.get_weight(); + for (const auto& c: centroids) { + check_not_nan(c.get_mean(), "centroid mean"); + check_not_infinite(c.get_mean(), "centroid mean"); + check_non_zero(c.get_weight(), "centroid weight"); + + weight += c.get_weight(); + } + for (const auto& value: buffer) { + check_not_nan(value, "buffered_value"); + check_not_infinite(value, "buffered_value"); + } return tdigest(reverse_merge, k, min, max, std::move(centroids), weight, std::move(buffer)); } @@ -487,13 +543,24 @@ tdigest tdigest::deserialize_compat(std::istream& is, const A& alloc if (type == COMPAT_DOUBLE) { // compatibility with asBytes() const auto min = read_big_endian(is); const auto max = read_big_endian(is); + check_not_nan(min, "min"); + check_not_infinite(min, "min"); + check_not_nan(max, "max"); + check_not_infinite(max, "max"); const auto k = static_cast(read_big_endian(is)); const auto num_centroids = read_big_endian(is); vector_centroid centroids(num_centroids, centroid(0, 0), allocator); uint64_t total_weight = 0; for (auto& c: centroids) { - const W weight = static_cast(read_big_endian(is)); + const auto weight_double = read_big_endian(is); + check_not_nan(weight_double, "centroid weight"); + check_not_infinite(weight_double, "centroid weight"); + check_non_zero(weight_double, "centroid weight"); + const auto mean = read_big_endian(is); + check_not_nan(mean, "centroid mean"); + check_not_infinite(mean, "centroid mean"); + const W weight = static_cast(weight_double); c = centroid(mean, weight); total_weight += weight; } @@ -502,6 +569,10 @@ tdigest tdigest::deserialize_compat(std::istream& is, const A& alloc // COMPAT_FLOAT: compatibility with asSmallBytes() const auto min = read_big_endian(is); // reference implementation uses doubles for min and max const auto max = read_big_endian(is); + check_not_nan(min, "min"); + check_not_infinite(min, "min"); + check_not_nan(max, "max"); + check_not_infinite(max, "max"); const auto k = static_cast(read_big_endian(is)); // reference implementation stores capacities of the array of centroids and the buffer as shorts // they can be derived from k in the constructor @@ -510,8 +581,13 @@ tdigest tdigest::deserialize_compat(std::istream& is, const A& alloc vector_centroid centroids(num_centroids, centroid(0, 0), allocator); uint64_t total_weight = 0; for (auto& c: centroids) { - const W weight = static_cast(read_big_endian(is)); + const auto weight_float = read_big_endian(is); + check_not_nan(weight_float, "centroid weight"); + check_not_infinite(weight_float, "centroid weight"); const auto mean = read_big_endian(is); + check_not_nan(mean, "centroid mean"); + check_not_infinite(mean, "centroid mean"); + const W weight = static_cast(weight_float); c = centroid(mean, weight); total_weight += weight; } @@ -538,6 +614,10 @@ tdigest tdigest::deserialize_compat(const void* bytes, size_t size, double max; ptr += copy_from_mem(ptr, max); max = byteswap(max); + check_not_nan(min, "min"); + check_not_infinite(min, "min"); + check_not_nan(max, "max"); + check_not_infinite(max, "max"); double k_double; ptr += copy_from_mem(ptr, k_double); const uint16_t k = static_cast(byteswap(k_double)); @@ -554,6 +634,10 @@ tdigest tdigest::deserialize_compat(const void* bytes, size_t size, double mean; ptr += copy_from_mem(ptr, mean); mean = byteswap(mean); + check_not_nan(weight, "centroid weight"); + check_not_infinite(weight, "centroid weight"); + check_not_nan(mean, "centroid mean"); + check_not_infinite(mean, "centroid mean"); c = centroid(mean, static_cast(weight)); total_weight += static_cast(weight); } @@ -567,6 +651,10 @@ tdigest tdigest::deserialize_compat(const void* bytes, size_t size, double max; ptr += copy_from_mem(ptr, max); max = byteswap(max); + check_not_nan(min, "min"); + check_not_infinite(min, "min"); + check_not_nan(max, "max"); + check_not_infinite(max, "max"); float k_float; ptr += copy_from_mem(ptr, k_float); const uint16_t k = static_cast(byteswap(k_float)); @@ -586,6 +674,10 @@ tdigest tdigest::deserialize_compat(const void* bytes, size_t size, float mean; ptr += copy_from_mem(ptr, mean); mean = byteswap(mean); + check_not_nan(weight, "centroid weight"); + check_not_infinite(weight, "centroid weight"); + check_not_nan(mean, "centroid mean"); + check_not_infinite(mean, "centroid mean"); c = centroid(mean, static_cast(weight)); total_weight += static_cast(weight); } diff --git a/tdigest/test/tdigest_test.cpp b/tdigest/test/tdigest_test.cpp index 9f92094d..07d6185f 100644 --- a/tdigest/test/tdigest_test.cpp +++ b/tdigest/test/tdigest_test.cpp @@ -18,13 +18,36 @@ */ #include +#include #include #include +#include #include "tdigest.hpp" namespace datasketches { +namespace { +constexpr size_t header_size = 8; +constexpr size_t counts_size = 8; +constexpr size_t min_offset = header_size + counts_size; +constexpr size_t max_offset = min_offset + sizeof(double); +constexpr size_t first_centroid_mean_offset = min_offset + sizeof(double) * 2; +constexpr size_t first_centroid_weight_offset = first_centroid_mean_offset + sizeof(double); +constexpr size_t first_buffered_value_offset = first_centroid_mean_offset; +constexpr size_t single_value_offset = header_size; + +template +void write_bytes(std::vector& bytes, size_t offset, T value) { + std::memcpy(bytes.data() + offset, &value, sizeof(T)); +} + +template +void write_bytes(std::string& data, size_t offset, T value) { + std::memcpy(&data[offset], &value, sizeof(T)); +} +} // namespace + TEST_CASE("empty", "[tdigest]") { tdigest_double td(10); // std::cout << td.to_string(); @@ -470,4 +493,112 @@ TEST_CASE("iterate centroids", "[tdigest]") { REQUIRE(td.get_total_weight() == total_weight); } +TEST_CASE("update rejects positive infinity", "[tdigest]") { + tdigest_double td(100); + td.update(1.0); + td.update(2.0); + td.update(std::numeric_limits::infinity()); + REQUIRE(td.get_total_weight() == 2); + REQUIRE(td.get_max_value() == 2.0); +} + +TEST_CASE("update rejects negative infinity", "[tdigest]") { + tdigest_double td(100); + td.update(1.0); + td.update(2.0); + td.update(-std::numeric_limits::infinity()); + REQUIRE(td.get_total_weight() == 2); + REQUIRE(td.get_min_value() == 1.0); +} + +TEST_CASE("deserialize bytes rejects NaN single value", "[tdigest]") { + tdigest_double td(100); + td.update(1.0); + auto bytes = td.serialize(); + write_bytes(bytes, single_value_offset, std::numeric_limits::quiet_NaN()); + REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument); +} + +TEST_CASE("deserialize stream rejects infinity min", "[tdigest]") { + tdigest_double td(100); + td.update(1.0); + td.update(2.0); + td.update(3.0); + auto bytes = td.serialize(); + std::string data(reinterpret_cast(bytes.data()), bytes.size()); + write_bytes(data, min_offset, std::numeric_limits::infinity()); + std::istringstream is(data, std::ios::binary); + REQUIRE_THROWS_AS(tdigest_double::deserialize(is), std::invalid_argument); +} + +TEST_CASE("deserialize bytes rejects NaN centroid mean", "[tdigest]") { + tdigest_double td(100); + for (int i = 0; i < 10; ++i) td.update(i); + auto bytes = td.serialize(); + write_bytes(bytes, first_centroid_mean_offset, std::numeric_limits::quiet_NaN()); + REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument); +} + +TEST_CASE("deserialize bytes rejects NaN buffered value", "[tdigest]") { + tdigest_double td(100); + td.update(1.0); + td.update(2.0); + auto bytes = td.serialize(0, true); + write_bytes(bytes, first_buffered_value_offset, std::numeric_limits::quiet_NaN()); + REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument); +} + +TEST_CASE("deserialize bytes rejects infinity single value", "[tdigest]") { + tdigest_double td(100); + td.update(1.0); + auto bytes = td.serialize(); + write_bytes(bytes, single_value_offset, std::numeric_limits::infinity()); + REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument); +} + +TEST_CASE("deserialize bytes rejects NaN max", "[tdigest]") { + tdigest_double td(100); + td.update(1.0); + td.update(2.0); + auto bytes = td.serialize(); + write_bytes(bytes, max_offset, std::numeric_limits::quiet_NaN()); + REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument); +} + +TEST_CASE("deserialize bytes rejects infinity max", "[tdigest]") { + tdigest_double td(100); + td.update(1.0); + td.update(2.0); + auto bytes = td.serialize(); + write_bytes(bytes, max_offset, std::numeric_limits::infinity()); + REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument); +} + +TEST_CASE("deserialize bytes rejects infinity buffered value", "[tdigest]") { + tdigest_double td(100); + td.update(1.0); + td.update(2.0); + auto bytes = td.serialize(0, true); + write_bytes(bytes, first_buffered_value_offset, std::numeric_limits::infinity()); + REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument); +} + +TEST_CASE("deserialize bytes rejects zero centroid weight", "[tdigest]") { + tdigest_double td(100); + for (int i = 0; i < 10; ++i) td.update(i); + auto bytes = td.serialize(); + write_bytes(bytes, first_centroid_weight_offset, static_cast(0)); + REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument); +} + +TEST_CASE("deserialize stream rejects zero centroid weight", "[tdigest]") { + tdigest_double td(100); + for (int i = 0; i < 10; ++i) td.update(i); + auto bytes = td.serialize(); + std::string data(reinterpret_cast(bytes.data()), bytes.size()); + write_bytes(data, first_centroid_weight_offset, static_cast(0)); + std::istringstream is(data, std::ios::binary); + REQUIRE_THROWS_AS(tdigest_double::deserialize(is), std::invalid_argument); +} + } /* namespace datasketches */