diff --git a/rust_snuba/src/processors/eap_items.rs b/rust_snuba/src/processors/eap_items.rs
index 9931e754c4..c313af1c19 100644
--- a/rust_snuba/src/processors/eap_items.rs
+++ b/rust_snuba/src/processors/eap_items.rs
@@ -15,6 +15,11 @@ use crate::config::ProcessorConfig;
 use crate::processors::utils::enforce_retention;
 use crate::types::{InsertBatch, ItemTypeMetrics, KafkaMessageMetadata};
 
+/// Precision factor for sampling_factor calculations to compensate for floating point errors
+const SAMPLING_FACTOR_PRECISION: f64 = 1e9;
+/// Minimum allowed sampling_factor value
+const MIN_SAMPLING_FACTOR: f64 = 1.0 / SAMPLING_FACTOR_PRECISION;
+
 pub fn process_message(
     msg: KafkaPayload,
     _metadata: KafkaMessageMetadata,
@@ -119,8 +124,16 @@ impl TryFrom<TraceItem> for EAPItem {
         }
 
         // Lower precision to compensate floating point errors.
+        eap_item.sampling_factor = (eap_item.sampling_factor * SAMPLING_FACTOR_PRECISION).round()
+            / SAMPLING_FACTOR_PRECISION;
+
+        // Clamp sampling_factor to a minimum so rounding can never produce exactly zero
+        if eap_item.sampling_factor < MIN_SAMPLING_FACTOR {
+            eap_item.sampling_factor = MIN_SAMPLING_FACTOR;
+        }
+
+        // Derive sampling_weight after the clamp so weight and factor stay consistent
         eap_item.sampling_weight = (1.0 / eap_item.sampling_factor).round() as u64;
-        eap_item.sampling_factor = (eap_item.sampling_factor * 1e9).round() / 1e9;
 
         Ok(eap_item)
     }
@@ -462,4 +475,30 @@ mod tests {
             EAPValue::Int(1234567890)
         );
     }
+
+    #[test]
+    fn test_very_low_sample_rates_do_not_result_in_zero_sampling_factor() {
+        let item_id = Uuid::new_v4();
+        let mut trace_item = generate_trace_item(item_id);
+
+        // Set extremely low sample rates whose product is smaller than 1e-9
+        trace_item.client_sample_rate = 0.00001; // 1e-5
+        trace_item.server_sample_rate = 0.00001; // 1e-5
+        // Combined: 1e-10, which is smaller than MIN_SAMPLING_FACTOR (1e-9)
+
+        let eap_item = EAPItem::try_from(trace_item);
+
+        assert!(eap_item.is_ok());
+        let eap_item = eap_item.unwrap();
+
+        // Verify that sampling_factor is not zero and equals the minimum
+        assert_eq!(eap_item.sampling_factor, MIN_SAMPLING_FACTOR);
+        assert!(eap_item.sampling_factor > 0.0);
+
+        // Verify that sampling_weight is derived from the clamped factor
+        assert_eq!(
+            eap_item.sampling_weight,
+            (1.0 / MIN_SAMPLING_FACTOR).round() as u64
+        );
+    }
 }
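
Note on the hunk above: the old code derived sampling_weight from the unrounded
factor and only then rounded sampling_factor, so any combined rate below 5e-10
could be persisted as a factor of exactly 0.0 alongside a huge weight. The new
ordering rounds and clamps first, then derives the weight. A minimal standalone
sketch of that arithmetic (the constants mirror the diff; clamp_sampling_factor
is a hypothetical helper written only to illustrate the math, not a function in
the processor):

    const SAMPLING_FACTOR_PRECISION: f64 = 1e9;
    const MIN_SAMPLING_FACTOR: f64 = 1.0 / SAMPLING_FACTOR_PRECISION;

    /// Hypothetical helper mirroring the new ingestion logic.
    fn clamp_sampling_factor(raw: f64) -> f64 {
        // Round to nine decimal places to absorb floating point noise...
        let rounded = (raw * SAMPLING_FACTOR_PRECISION).round() / SAMPLING_FACTOR_PRECISION;
        // ...then clamp so the stored factor can never be exactly zero.
        rounded.max(MIN_SAMPLING_FACTOR)
    }

    fn main() {
        // client 1e-5 * server 1e-5 = 1e-10, which rounds to 0.0 at 1e9 precision,
        // so the clamp kicks in and the factor lands on the 1e-9 floor.
        let factor = clamp_sampling_factor(0.00001 * 0.00001);
        assert_eq!(factor, MIN_SAMPLING_FACTOR);
        // sampling_weight is derived only after clamping, so 1.0 / factor is
        // bounded by 1e9 instead of blowing up on a zero factor.
        let weight = (1.0 / factor).round() as u64;
        assert_eq!(weight, 1_000_000_000);
    }
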
- """ - granularity_secs = 120 - query_duration = 3600 - - # Store metrics with zero sampling rate (should be discarded) - store_timeseries( - BASE_TIME, - 10, # every 10 seconds - 3600, - metrics=[ - DummyMetric("test_metric", get_value=lambda x: 1000) - ], # Large value to make it obvious - server_sample_rate=0.0000000001, # Zero sampling rate - should be discarded - ) - - # Store metrics with non-zero sampling rate (should be included) - store_timeseries( - BASE_TIME, - 1, # every second - 3600, - metrics=[DummyMetric("test_metric", get_value=lambda x: 10)], - server_sample_rate=1.0, # 100% sampling rate - ) - message = TimeSeriesRequest( - meta=RequestMeta( - project_ids=[1, 2, 3], - organization_id=1, - cogs_category="something", - referrer="something", - start_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp())), - end_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp() + query_duration)), - trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN, - ), - aggregations=[ - AttributeAggregation( - aggregate=Function.FUNCTION_SUM, - key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="test_metric"), - label="sum(test_metric)", - extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, - ), - AttributeAggregation( - aggregate=Function.FUNCTION_COUNT, - key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="test_metric"), - label="count(test_metric)", - extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, - ), - AttributeAggregation( - aggregate=Function.FUNCTION_AVG, - key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="test_metric"), - label="avg(test_metric)", - extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, - ), - ], - granularity_secs=granularity_secs, - ) - response = EndpointTimeSeries().execute(message) - expected_buckets = [ - Timestamp(seconds=int(BASE_TIME.timestamp()) + secs) - for secs in range(0, query_duration, granularity_secs) - ] - - # Zero sample rate items should be discarded - # Only the 120 items per bucket (1/sec * 120sec) with value 10 should be counted - assert sorted(response.result_timeseries, key=lambda x: x.label) == [ - TimeSeries( - label="avg(test_metric)", - buckets=expected_buckets, - data_points=[ - DataPoint( - data=10, # Average should be 10, not influenced by zero-sample items - data_present=True, - reliability=Reliability.RELIABILITY_HIGH, - avg_sampling_rate=1.0, - sample_count=120, - ) - for _ in range(len(expected_buckets)) - ], - ), - TimeSeries( - label="count(test_metric)", - buckets=expected_buckets, - data_points=[ - DataPoint( - data=120, # Count should be 120, not including zero-sample items - data_present=True, - reliability=Reliability.RELIABILITY_HIGH, - avg_sampling_rate=1.0, - sample_count=120, - ) - for _ in range(len(expected_buckets)) - ], - ), - TimeSeries( - label="sum(test_metric)", - buckets=expected_buckets, - data_points=[ - DataPoint( - data=120 * 10, # Sum should be 1200, not influenced by zero-sample items - data_present=True, - reliability=Reliability.RELIABILITY_HIGH, - avg_sampling_rate=1.0, - sample_count=120, - ) - for _ in range(len(expected_buckets)) - ], - ), - ] diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_extrapolation.py b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_extrapolation.py index 36d360c912..c0a4544d20 100644 --- a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_extrapolation.py +++ 
diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_extrapolation.py b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_extrapolation.py
index 36d360c912..c0a4544d20 100644
--- a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_extrapolation.py
+++ b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_extrapolation.py
@@ -1187,106 +1187,3 @@ def test_aggregation_with_nulls(self) -> None:
             ],
         ),
     ]
-
-    def test_aggregation_with_zero_sampling_rate_mixed(self) -> None:
-        """
-        Test that aggregations correctly discard items with zero sampling rate.
-        Items with zero sampling rate should be excluded from extrapolation calculations.
-        """
-        items_storage = get_storage(StorageKey("eap_items"))
-        messages_with_zero_rate = []
-        messages_with_nonzero_rate = []
-
-        # Create 3 items with zero sampling rate (should be discarded)
-        for i in range(3):
-            start_timestamp = BASE_TIME - timedelta(minutes=i + 1)
-            end_timestamp = start_timestamp + timedelta(seconds=1)
-            messages_with_zero_rate.append(
-                gen_item_message(
-                    start_timestamp=start_timestamp,
-                    attributes={
-                        "custom_measurement": AnyValue(
-                            int_value=100
-                        ),  # Large value to make it obvious if included
-                        "custom_tag": AnyValue(string_value="test"),
-                    },
-                    server_sample_rate=0.0000000001,  # Sampling rate so low it gets clipped to 0, should be discarded
-                    end_timestamp=end_timestamp,
-                )
-            )
-
-        # Create 5 items with non-zero sampling rate (should be included)
-        for i in range(5):
-            start_timestamp = BASE_TIME - timedelta(minutes=i + 1)
-            end_timestamp = start_timestamp + timedelta(seconds=1)
-            messages_with_nonzero_rate.append(
-                gen_item_message(
-                    start_timestamp=start_timestamp,
-                    attributes={
-                        "custom_measurement": AnyValue(int_value=i),  # Values 0, 1, 2, 3, 4
-                        "custom_tag": AnyValue(string_value="test"),
-                    },
-                    server_sample_rate=1.0,  # Full sampling rate
-                    end_timestamp=end_timestamp,
-                )
-            )
-
-        write_raw_unprocessed_events(
-            items_storage,  # type: ignore
-            messages_with_zero_rate + messages_with_nonzero_rate,
-        )
-
-        ts = Timestamp(seconds=int(BASE_TIME.timestamp()))
-        hour_ago = int((BASE_TIME - timedelta(hours=1)).timestamp())
-        message = TraceItemTableRequest(
-            meta=RequestMeta(
-                project_ids=[1],
-                organization_id=1,
-                cogs_category="something",
-                referrer="something",
-                start_timestamp=Timestamp(seconds=hour_ago),
-                end_timestamp=ts,
-                trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
-            ),
-            columns=[
-                Column(
-                    aggregation=AttributeAggregation(
-                        aggregate=Function.FUNCTION_SUM,
-                        key=AttributeKey(type=AttributeKey.TYPE_DOUBLE, name="custom_measurement"),
-                        label="sum(custom_measurement)",
-                        extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED,
-                    )
-                ),
-                Column(
-                    aggregation=AttributeAggregation(
-                        aggregate=Function.FUNCTION_COUNT,
-                        key=AttributeKey(type=AttributeKey.TYPE_DOUBLE, name="custom_measurement"),
-                        label="count(custom_measurement)",
-                        extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED,
-                    )
-                ),
-                Column(
-                    aggregation=AttributeAggregation(
-                        aggregate=Function.FUNCTION_AVG,
-                        key=AttributeKey(type=AttributeKey.TYPE_DOUBLE, name="custom_measurement"),
-                        label="avg(custom_measurement)",
-                        extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED,
-                    )
-                ),
-            ],
-            order_by=[],
-            limit=5,
-        )
-        response = EndpointTraceItemTable().execute(message)
-
-        # Zero sample rate items should be discarded, so we should only see results from the 5 non-zero items
-        measurement_sum = [v.val_double for v in response.column_values[0].results][0]
-        measurement_count = [v.val_double for v in response.column_values[1].results][0]
-        measurement_avg = [v.val_double for v in response.column_values[2].results][0]
-
-        # Expected: sum of 0+1+2+3+4 = 10 (zero sample rate items with value 100 should be excluded)
-        assert measurement_sum == 10
-        # Expected: count of 5 items (zero sample rate items should be excluded)
-        assert measurement_count == 5
-        # Expected: average of (0+1+2+3+4)/5 = 2.0
-        assert measurement_avg == 2.0
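
The likely reason both tests are deleted rather than updated: their premise was
that a rate low enough to round to sampling_factor == 0.0 is discarded by the
query layer, and the clamp in eap_items.rs makes that state unreachable. A
1e-10 server sample rate is now stored as the 1e-9 floor, so those items would
contribute (heavily) to sample-weighted aggregates instead of being dropped. A
rough back-of-the-envelope, assuming a plain weight-times-value extrapolated
sum:

    fn main() {
        // After clamping, a 1e-10 rate is stored as sampling_factor = 1e-9,
        // which derives sampling_weight = 1_000_000_000 instead of a discarded 0.
        let weight = (1.0 / 1e-9_f64).round() as u64;
        assert_eq!(weight, 1_000_000_000);
        // The three "zero rate" items in the deleted table test carried value 100
        // each; sample-weighted, they would add 3 * 100 * 1e9 to the sum.
        let skew: u64 = 3 * 100 * weight;
        assert_eq!(skew, 300_000_000_000);
    }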