41 changes: 40 additions & 1 deletion rust_snuba/src/processors/eap_items.rs
@@ -15,6 +15,11 @@ use crate::config::ProcessorConfig;
use crate::processors::utils::enforce_retention;
use crate::types::{InsertBatch, ItemTypeMetrics, KafkaMessageMetadata};

/// Precision factor for sampling_factor calculations to compensate for floating point errors
const SAMPLING_FACTOR_PRECISION: f64 = 1e9;
/// Minimum allowed sampling_factor value
const MIN_SAMPLING_FACTOR: f64 = 1.0 / SAMPLING_FACTOR_PRECISION;

pub fn process_message(
msg: KafkaPayload,
_metadata: KafkaMessageMetadata,
@@ -119,8 +124,16 @@ impl TryFrom<TraceItem> for EAPItem {
}

// Lower precision to compensate for floating point errors.
eap_item.sampling_factor = (eap_item.sampling_factor * SAMPLING_FACTOR_PRECISION).round()
/ SAMPLING_FACTOR_PRECISION;

// Clamp sampling_factor to a minimum value so it can never round down to zero
if eap_item.sampling_factor < MIN_SAMPLING_FACTOR {
eap_item.sampling_factor = MIN_SAMPLING_FACTOR;
}

// Calculate sampling_weight only after the minimum is applied, so the division below never sees zero
eap_item.sampling_weight = (1.0 / eap_item.sampling_factor).round() as u64;
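
// Worked example (matches the test below): client_sample_rate = 1e-5 and
// server_sample_rate = 1e-5 combine to a factor of 1e-10; rounding at
// SAMPLING_FACTOR_PRECISION turns that into 0.0, the clamp raises it back to
// MIN_SAMPLING_FACTOR (1e-9), and sampling_weight becomes 1_000_000_000.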

Ok(eap_item)
}
@@ -462,4 +475,30 @@ mod tests {
EAPValue::Int(1234567890)
);
}

#[test]
fn test_very_low_sample_rates_do_not_result_in_zero_sampling_factor() {
let item_id = Uuid::new_v4();
let mut trace_item = generate_trace_item(item_id);

// Set extremely low sample rates that would multiply to a value smaller than 1e-9
trace_item.client_sample_rate = 0.00001; // 1e-5
trace_item.server_sample_rate = 0.00001; // 1e-5
// Combined: 1e-10, which is smaller than MIN_SAMPLING_FACTOR (1e-9)

let eap_item = EAPItem::try_from(trace_item);

assert!(eap_item.is_ok());
let eap_item = eap_item.unwrap();

// Verify that sampling_factor is not zero and equals the minimum
assert_eq!(eap_item.sampling_factor, MIN_SAMPLING_FACTOR);
assert!(eap_item.sampling_factor > 0.0);

// Verify that sampling_weight is calculated correctly
assert_eq!(
eap_item.sampling_weight,
(1.0 / MIN_SAMPLING_FACTOR).round() as u64
);
}
}
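
For reference, a minimal standalone sketch of the round-then-clamp logic added above, assuming nothing beyond plain f64 arithmetic; the free function normalize_sampling and its signature are illustrative only and not part of the processor:

/// Hypothetical helper mirroring the rounding and clamping shown above.
fn normalize_sampling(factor: f64) -> (f64, u64) {
    const PRECISION: f64 = 1e9;
    const MIN: f64 = 1.0 / PRECISION;
    // Round to 1e-9 precision, then clamp so the factor can never reach zero.
    let factor = ((factor * PRECISION).round() / PRECISION).max(MIN);
    // The weight is the rounded reciprocal of the clamped factor.
    let weight = (1.0 / factor).round() as u64;
    (factor, weight)
}

// normalize_sampling(1e-10) -> (1e-9, 1_000_000_000): clamped to the minimum.
// normalize_sampling(0.25)  -> (0.25, 4): unchanged, weight is the reciprocal.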
@@ -1035,115 +1035,3 @@ def test_formula_reliability_nested(self) -> None:
),
]
assert actual == expected

def test_aggregation_with_zero_sampling_rate_mixed(self) -> None:
"""
Test that aggregations correctly discard items with zero sampling rate.
Items with zero sampling rate should be excluded from extrapolation calculations.
"""
granularity_secs = 120
query_duration = 3600

# Store metrics with zero sampling rate (should be discarded)
store_timeseries(
BASE_TIME,
10, # every 10 seconds
3600,
metrics=[
DummyMetric("test_metric", get_value=lambda x: 1000)
], # Large value to make it obvious
server_sample_rate=0.0000000001, # Zero sampling rate - should be discarded
)

# Store metrics with non-zero sampling rate (should be included)
store_timeseries(
BASE_TIME,
1, # every second
3600,
metrics=[DummyMetric("test_metric", get_value=lambda x: 10)],
server_sample_rate=1.0, # 100% sampling rate
)
message = TimeSeriesRequest(
meta=RequestMeta(
project_ids=[1, 2, 3],
organization_id=1,
cogs_category="something",
referrer="something",
start_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp())),
end_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp() + query_duration)),
trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
),
aggregations=[
AttributeAggregation(
aggregate=Function.FUNCTION_SUM,
key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="test_metric"),
label="sum(test_metric)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED,
),
AttributeAggregation(
aggregate=Function.FUNCTION_COUNT,
key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="test_metric"),
label="count(test_metric)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED,
),
AttributeAggregation(
aggregate=Function.FUNCTION_AVG,
key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="test_metric"),
label="avg(test_metric)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED,
),
],
granularity_secs=granularity_secs,
)
response = EndpointTimeSeries().execute(message)
expected_buckets = [
Timestamp(seconds=int(BASE_TIME.timestamp()) + secs)
for secs in range(0, query_duration, granularity_secs)
]

# Zero sample rate items should be discarded
# Only the 120 items per bucket (1/sec * 120sec) with value 10 should be counted
assert sorted(response.result_timeseries, key=lambda x: x.label) == [
TimeSeries(
label="avg(test_metric)",
buckets=expected_buckets,
data_points=[
DataPoint(
data=10, # Average should be 10, not influenced by zero-sample items
data_present=True,
reliability=Reliability.RELIABILITY_HIGH,
avg_sampling_rate=1.0,
sample_count=120,
)
for _ in range(len(expected_buckets))
],
),
TimeSeries(
label="count(test_metric)",
buckets=expected_buckets,
data_points=[
DataPoint(
data=120, # Count should be 120, not including zero-sample items
data_present=True,
reliability=Reliability.RELIABILITY_HIGH,
avg_sampling_rate=1.0,
sample_count=120,
)
for _ in range(len(expected_buckets))
],
),
TimeSeries(
label="sum(test_metric)",
buckets=expected_buckets,
data_points=[
DataPoint(
data=120 * 10, # Sum should be 1200, not influenced by zero-sample items
data_present=True,
reliability=Reliability.RELIABILITY_HIGH,
avg_sampling_rate=1.0,
sample_count=120,
)
for _ in range(len(expected_buckets))
],
),
]
@@ -1187,106 +1187,3 @@ def test_aggregation_with_nulls(self) -> None:
],
),
]

def test_aggregation_with_zero_sampling_rate_mixed(self) -> None:
"""
Test that aggregations correctly discard items with zero sampling rate.
Items with zero sampling rate should be excluded from extrapolation calculations.
"""
items_storage = get_storage(StorageKey("eap_items"))
messages_with_zero_rate = []
messages_with_nonzero_rate = []

# Create 3 items with zero sampling rate (should be discarded)
for i in range(3):
start_timestamp = BASE_TIME - timedelta(minutes=i + 1)
end_timestamp = start_timestamp + timedelta(seconds=1)
messages_with_zero_rate.append(
gen_item_message(
start_timestamp=start_timestamp,
attributes={
"custom_measurement": AnyValue(
int_value=100
), # Large value to make it obvious if included
"custom_tag": AnyValue(string_value="test"),
},
server_sample_rate=0.0000000001, # Sampling rate so low it gets clipped to 0, should be discarded
end_timestamp=end_timestamp,
)
)

# Create 5 items with non-zero sampling rate (should be included)
for i in range(5):
start_timestamp = BASE_TIME - timedelta(minutes=i + 1)
end_timestamp = start_timestamp + timedelta(seconds=1)
messages_with_nonzero_rate.append(
gen_item_message(
start_timestamp=start_timestamp,
attributes={
"custom_measurement": AnyValue(int_value=i), # Values 0, 1, 2, 3, 4
"custom_tag": AnyValue(string_value="test"),
},
server_sample_rate=1.0, # Full sampling rate
end_timestamp=end_timestamp,
)
)

write_raw_unprocessed_events(
items_storage, # type: ignore
messages_with_zero_rate + messages_with_nonzero_rate,
)

ts = Timestamp(seconds=int(BASE_TIME.timestamp()))
hour_ago = int((BASE_TIME - timedelta(hours=1)).timestamp())
message = TraceItemTableRequest(
meta=RequestMeta(
project_ids=[1],
organization_id=1,
cogs_category="something",
referrer="something",
start_timestamp=Timestamp(seconds=hour_ago),
end_timestamp=ts,
trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
),
columns=[
Column(
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_SUM,
key=AttributeKey(type=AttributeKey.TYPE_DOUBLE, name="custom_measurement"),
label="sum(custom_measurement)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED,
)
),
Column(
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_COUNT,
key=AttributeKey(type=AttributeKey.TYPE_DOUBLE, name="custom_measurement"),
label="count(custom_measurement)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED,
)
),
Column(
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_AVG,
key=AttributeKey(type=AttributeKey.TYPE_DOUBLE, name="custom_measurement"),
label="avg(custom_measurement)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED,
)
),
],
order_by=[],
limit=5,
)
response = EndpointTraceItemTable().execute(message)

# Zero sample rate items should be discarded, so we should only see results from the 5 non-zero items
measurement_sum = [v.val_double for v in response.column_values[0].results][0]
measurement_count = [v.val_double for v in response.column_values[1].results][0]
measurement_avg = [v.val_double for v in response.column_values[2].results][0]

# Expected: sum of 0+1+2+3+4 = 10 (zero sample rate items with value 100 should be excluded)
assert measurement_sum == 10
# Expected: count of 5 items (zero sample rate items should be excluded)
assert measurement_count == 5
# Expected: average of (0+1+2+3+4)/5 = 2.0
assert measurement_avg == 2.0