Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tdigest/include/tdigest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class tdigest {

/**
* Update this t-Digest with the given value
* NaN and infinity values are ignored
* @param value to update the t-Digest with
*/
void update(T value);
Expand Down Expand Up @@ -153,6 +154,7 @@ class tdigest {
* Compute approximate normalized rank of the given value.
*
* <p>If the sketch is empty this throws std::runtime_error.
* <p>NaN value throw std::invalid_argument.
*
* @param value to be ranked
* @return normalized rank (from 0 to 1 inclusive)
Expand Down
100 changes: 96 additions & 4 deletions tdigest/include/tdigest_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,35 @@
#include <algorithm>
#include <cmath>
#include <sstream>
#include <type_traits>

#include "common_defs.hpp"
#include "memory_operations.hpp"

namespace datasketches {

template<typename T>
inline void check_not_nan(T value, const char* name) {
if (std::isnan(value)) {
throw std::invalid_argument(std::string(name) + " must not be NaN");
}
}

template<typename T>
inline void check_not_infinite(T value, const char* name) {
if (std::isinf(value)) {
throw std::invalid_argument(std::string(name) + " must not be infinite");
}
}

template<typename T>
inline void check_non_zero(T value, const char* name) {
static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
if (value == 0) {
throw std::invalid_argument(std::string(name) + " must not be zero");
}
}

template<typename T, typename A>
tdigest<T, A>::tdigest(uint16_t k, const A& allocator):
tdigest(false, k, std::numeric_limits<T>::infinity(), -std::numeric_limits<T>::infinity(), vector_centroid(allocator), 0, vector_t(allocator))
Expand All @@ -37,6 +60,7 @@ tdigest(false, k, std::numeric_limits<T>::infinity(), -std::numeric_limits<T>::i
template<typename T, typename A>
void tdigest<T, A>::update(T value) {
if (std::isnan(value)) return;
if (std::isinf(value)) return;
if (buffer_.size() == centroids_capacity_ * BUFFER_MULTIPLIER) compress();
buffer_.push_back(value);
min_ = std::min(min_, value);
Expand Down Expand Up @@ -400,6 +424,8 @@ tdigest<T, A> tdigest<T, A>::deserialize(std::istream& is, const A& allocator) {
const bool reverse_merge = flags_byte & (1 << flags::REVERSE_MERGE);
if (is_single_value) {
const T value = read<T>(is);
check_not_nan(value, "single_value");
check_not_infinite(value, "single_value");
return tdigest(reverse_merge, k, value, value, vector_centroid(1, centroid(value, 1), allocator), 1, vector_t(allocator));
}

Expand All @@ -408,12 +434,26 @@ tdigest<T, A> tdigest<T, A>::deserialize(std::istream& is, const A& allocator) {

const T min = read<T>(is);
const T max = read<T>(is);
check_not_nan(min, "min");
check_not_infinite(min, "min");
check_not_nan(max, "max");
check_not_infinite(max, "max");
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
if (num_centroids > 0) read(is, centroids.data(), num_centroids * sizeof(centroid));
vector_t buffer(num_buffered, 0, allocator);
if (num_buffered > 0) read(is, buffer.data(), num_buffered * sizeof(T));
uint64_t weight = 0;
for (const auto& c: centroids) weight += c.get_weight();
for (const auto& c: centroids) {
check_not_nan(c.get_mean(), "centroid mean");
check_not_infinite(c.get_mean(), "centroid mean");
check_non_zero(c.get_weight(), "centroid weight");

weight += c.get_weight();
}
for (const auto& value: buffer) {
check_not_nan(value, "buffered_value");
check_not_infinite(value, "buffered_value");
}
return tdigest(reverse_merge, k, min, max, std::move(centroids), weight, std::move(buffer));
}

Expand Down Expand Up @@ -451,6 +491,8 @@ tdigest<T, A> tdigest<T, A>::deserialize(const void* bytes, size_t size, const A
ensure_minimum_memory(end_ptr - ptr, sizeof(T));
T value;
ptr += copy_from_mem(ptr, value);
check_not_nan(value, "single_value");
check_not_infinite(value, "single_value");
return tdigest(reverse_merge, k, value, value, vector_centroid(1, centroid(value, 1), allocator), 1, vector_t(allocator));
}

Expand All @@ -465,12 +507,26 @@ tdigest<T, A> tdigest<T, A>::deserialize(const void* bytes, size_t size, const A
ptr += copy_from_mem(ptr, min);
T max;
ptr += copy_from_mem(ptr, max);
check_not_nan(min, "min");
check_not_infinite(min, "min");
check_not_nan(max, "max");
check_not_infinite(max, "max");
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
if (num_centroids > 0) ptr += copy_from_mem(ptr, centroids.data(), num_centroids * sizeof(centroid));
vector_t buffer(num_buffered, 0, allocator);
if (num_buffered > 0) copy_from_mem(ptr, buffer.data(), num_buffered * sizeof(T));
uint64_t weight = 0;
for (const auto& c: centroids) weight += c.get_weight();
for (const auto& c: centroids) {
check_not_nan(c.get_mean(), "centroid mean");
check_not_infinite(c.get_mean(), "centroid mean");
check_non_zero(c.get_weight(), "centroid weight");

weight += c.get_weight();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: check weight is not zero

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. bded7aa

}
for (const auto& value: buffer) {
check_not_nan(value, "buffered_value");
check_not_infinite(value, "buffered_value");
}
return tdigest(reverse_merge, k, min, max, std::move(centroids), weight, std::move(buffer));
}

Expand All @@ -487,13 +543,24 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
if (type == COMPAT_DOUBLE) { // compatibility with asBytes()
const auto min = read_big_endian<double>(is);
const auto max = read_big_endian<double>(is);
check_not_nan(min, "min");
check_not_infinite(min, "min");
check_not_nan(max, "max");
check_not_infinite(max, "max");
const auto k = static_cast<uint16_t>(read_big_endian<double>(is));
const auto num_centroids = read_big_endian<uint32_t>(is);
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
uint64_t total_weight = 0;
for (auto& c: centroids) {
const W weight = static_cast<W>(read_big_endian<double>(is));
const auto weight_double = read_big_endian<double>(is);
check_not_nan(weight_double, "centroid weight");
check_not_infinite(weight_double, "centroid weight");
check_non_zero(weight_double, "centroid weight");

const auto mean = read_big_endian<double>(is);
check_not_nan(mean, "centroid mean");
check_not_infinite(mean, "centroid mean");
const W weight = static_cast<W>(weight_double);
c = centroid(mean, weight);
total_weight += weight;
}
Expand All @@ -502,6 +569,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
// COMPAT_FLOAT: compatibility with asSmallBytes()
const auto min = read_big_endian<double>(is); // reference implementation uses doubles for min and max
const auto max = read_big_endian<double>(is);
check_not_nan(min, "min");
check_not_infinite(min, "min");
check_not_nan(max, "max");
check_not_infinite(max, "max");
const auto k = static_cast<uint16_t>(read_big_endian<float>(is));
// reference implementation stores capacities of the array of centroids and the buffer as shorts
// they can be derived from k in the constructor
Expand All @@ -510,8 +581,13 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
uint64_t total_weight = 0;
for (auto& c: centroids) {
const W weight = static_cast<W>(read_big_endian<float>(is));
const auto weight_float = read_big_endian<float>(is);
check_not_nan(weight_float, "centroid weight");
check_not_infinite(weight_float, "centroid weight");
const auto mean = read_big_endian<float>(is);
check_not_nan(mean, "centroid mean");
check_not_infinite(mean, "centroid mean");
const W weight = static_cast<W>(weight_float);
c = centroid(mean, weight);
total_weight += weight;
}
Expand All @@ -538,6 +614,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
double max;
ptr += copy_from_mem(ptr, max);
max = byteswap(max);
check_not_nan(min, "min");
check_not_infinite(min, "min");
check_not_nan(max, "max");
check_not_infinite(max, "max");
double k_double;
ptr += copy_from_mem(ptr, k_double);
const uint16_t k = static_cast<uint16_t>(byteswap(k_double));
Expand All @@ -554,6 +634,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
double mean;
ptr += copy_from_mem(ptr, mean);
mean = byteswap(mean);
check_not_nan(weight, "centroid weight");
check_not_infinite(weight, "centroid weight");
check_not_nan(mean, "centroid mean");
check_not_infinite(mean, "centroid mean");
c = centroid(mean, static_cast<W>(weight));
total_weight += static_cast<uint64_t>(weight);
}
Expand All @@ -567,6 +651,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
double max;
ptr += copy_from_mem(ptr, max);
max = byteswap(max);
check_not_nan(min, "min");
check_not_infinite(min, "min");
check_not_nan(max, "max");
check_not_infinite(max, "max");
float k_float;
ptr += copy_from_mem(ptr, k_float);
const uint16_t k = static_cast<uint16_t>(byteswap(k_float));
Expand All @@ -586,6 +674,10 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
float mean;
ptr += copy_from_mem(ptr, mean);
mean = byteswap(mean);
check_not_nan(weight, "centroid weight");
check_not_infinite(weight, "centroid weight");
check_not_nan(mean, "centroid mean");
check_not_infinite(mean, "centroid mean");
c = centroid(mean, static_cast<W>(weight));
total_weight += static_cast<uint64_t>(weight);
}
Expand Down
131 changes: 131 additions & 0 deletions tdigest/test/tdigest_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,36 @@
*/

#include <catch2/catch.hpp>
#include <cstring>
#include <iostream>
#include <fstream>
#include <sstream>

#include "tdigest.hpp"

namespace datasketches {

namespace {
constexpr size_t header_size = 8;
constexpr size_t counts_size = 8;
constexpr size_t min_offset = header_size + counts_size;
constexpr size_t max_offset = min_offset + sizeof(double);
constexpr size_t first_centroid_mean_offset = min_offset + sizeof(double) * 2;
constexpr size_t first_centroid_weight_offset = first_centroid_mean_offset + sizeof(double);
constexpr size_t first_buffered_value_offset = first_centroid_mean_offset;
constexpr size_t single_value_offset = header_size;

template <typename T>
void write_bytes(std::vector<uint8_t>& bytes, size_t offset, T value) {
std::memcpy(bytes.data() + offset, &value, sizeof(T));
}

template <typename T>
void write_bytes(std::string& data, size_t offset, T value) {
std::memcpy(&data[offset], &value, sizeof(T));
}
} // namespace

TEST_CASE("empty", "[tdigest]") {
tdigest_double td(10);
// std::cout << td.to_string();
Expand Down Expand Up @@ -470,4 +493,112 @@ TEST_CASE("iterate centroids", "[tdigest]") {
REQUIRE(td.get_total_weight() == total_weight);
}

TEST_CASE("update rejects positive infinity", "[tdigest]") {
tdigest_double td(100);
td.update(1.0);
td.update(2.0);
td.update(std::numeric_limits<double>::infinity());
REQUIRE(td.get_total_weight() == 2);
REQUIRE(td.get_max_value() == 2.0);
}

TEST_CASE("update rejects negative infinity", "[tdigest]") {
tdigest_double td(100);
td.update(1.0);
td.update(2.0);
td.update(-std::numeric_limits<double>::infinity());
REQUIRE(td.get_total_weight() == 2);
REQUIRE(td.get_min_value() == 1.0);
}

TEST_CASE("deserialize bytes rejects NaN single value", "[tdigest]") {
tdigest_double td(100);
td.update(1.0);
auto bytes = td.serialize();
write_bytes(bytes, single_value_offset, std::numeric_limits<double>::quiet_NaN());
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
}

TEST_CASE("deserialize stream rejects infinity min", "[tdigest]") {
tdigest_double td(100);
td.update(1.0);
td.update(2.0);
td.update(3.0);
auto bytes = td.serialize();
std::string data(reinterpret_cast<const char*>(bytes.data()), bytes.size());
write_bytes(data, min_offset, std::numeric_limits<double>::infinity());
std::istringstream is(data, std::ios::binary);
REQUIRE_THROWS_AS(tdigest_double::deserialize(is), std::invalid_argument);
}

TEST_CASE("deserialize bytes rejects NaN centroid mean", "[tdigest]") {
tdigest_double td(100);
for (int i = 0; i < 10; ++i) td.update(i);
auto bytes = td.serialize();
write_bytes(bytes, first_centroid_mean_offset, std::numeric_limits<double>::quiet_NaN());
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
}

TEST_CASE("deserialize bytes rejects NaN buffered value", "[tdigest]") {
tdigest_double td(100);
td.update(1.0);
td.update(2.0);
auto bytes = td.serialize(0, true);
write_bytes(bytes, first_buffered_value_offset, std::numeric_limits<double>::quiet_NaN());
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
}

TEST_CASE("deserialize bytes rejects infinity single value", "[tdigest]") {
tdigest_double td(100);
td.update(1.0);
auto bytes = td.serialize();
write_bytes(bytes, single_value_offset, std::numeric_limits<double>::infinity());
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
}

TEST_CASE("deserialize bytes rejects NaN max", "[tdigest]") {
tdigest_double td(100);
td.update(1.0);
td.update(2.0);
auto bytes = td.serialize();
write_bytes(bytes, max_offset, std::numeric_limits<double>::quiet_NaN());
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
}

TEST_CASE("deserialize bytes rejects infinity max", "[tdigest]") {
tdigest_double td(100);
td.update(1.0);
td.update(2.0);
auto bytes = td.serialize();
write_bytes(bytes, max_offset, std::numeric_limits<double>::infinity());
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
}

TEST_CASE("deserialize bytes rejects infinity buffered value", "[tdigest]") {
tdigest_double td(100);
td.update(1.0);
td.update(2.0);
auto bytes = td.serialize(0, true);
write_bytes(bytes, first_buffered_value_offset, std::numeric_limits<double>::infinity());
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
}

TEST_CASE("deserialize bytes rejects zero centroid weight", "[tdigest]") {
tdigest_double td(100);
for (int i = 0; i < 10; ++i) td.update(i);
auto bytes = td.serialize();
write_bytes(bytes, first_centroid_weight_offset, static_cast<uint64_t>(0));
REQUIRE_THROWS_AS(tdigest_double::deserialize(bytes.data(), bytes.size()), std::invalid_argument);
}

TEST_CASE("deserialize stream rejects zero centroid weight", "[tdigest]") {
tdigest_double td(100);
for (int i = 0; i < 10; ++i) td.update(i);
auto bytes = td.serialize();
std::string data(reinterpret_cast<const char*>(bytes.data()), bytes.size());
write_bytes(data, first_centroid_weight_offset, static_cast<uint64_t>(0));
std::istringstream is(data, std::ios::binary);
REQUIRE_THROWS_AS(tdigest_double::deserialize(is), std::invalid_argument);
}

} /* namespace datasketches */