From 23a865d66dd25908091d81966083e6eb919c5eeb Mon Sep 17 00:00:00 2001 From: Manisha Johnson Date: Mon, 15 Apr 2024 15:09:47 +0530 Subject: [PATCH 1/4] Include delay in asr clients --- riva/clients/asr/riva_streaming_asr_client.cc | 6 ++-- .../clients/asr/streaming_recognize_client.cc | 30 +++++++++++++++---- riva/clients/asr/streaming_recognize_client.h | 7 +++-- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/riva/clients/asr/riva_streaming_asr_client.cc b/riva/clients/asr/riva_streaming_asr_client.cc index 938b47d..62653f9 100644 --- a/riva/clients/asr/riva_streaming_asr_client.cc +++ b/riva/clients/asr/riva_streaming_asr_client.cc @@ -76,6 +76,7 @@ DEFINE_bool( "Whether to use SSL credentials or not. If ssl_cert is specified, " "this is assumed to be true"); DEFINE_string(metadata, "", "Comma separated key-value pair(s) of metadata to be sent to server"); +DEFINE_int32(async_delay_ms, 0, "Delay to start parallel request asynchronously in milliseconds"); void signal_handler(int signal_num) @@ -118,6 +119,7 @@ main(int argc, char** argv) str_usage << " --boosted_words_score=" << std::endl; str_usage << " --ssl_cert=" << std::endl; str_usage << " --metadata=" << std::endl; + str_usage << " --async_delay_ms=" << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -164,7 +166,7 @@ main(int argc, char** argv) FLAGS_profanity_filter, FLAGS_word_time_offsets, FLAGS_automatic_punctuation, /* separate_recognition_per_channel*/ false, FLAGS_print_transcripts, FLAGS_chunk_duration_ms, FLAGS_interim_results, FLAGS_output_filename, FLAGS_model_name, FLAGS_simulate_realtime, - FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score); + FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score, FLAGS_async_delay_ms); if (FLAGS_audio_file.size()) { return recognize_client.DoStreamingFromFile( @@ -205,4 +207,4 @@ main(int argc, char** argv) } return 0; -} +} \ No newline at end of file diff --git a/riva/clients/asr/streaming_recognize_client.cc b/riva/clients/asr/streaming_recognize_client.cc index ef5299a..03d36c3 100644 --- a/riva/clients/asr/streaming_recognize_client.cc +++ b/riva/clients/asr/streaming_recognize_client.cc @@ -57,7 +57,8 @@ StreamingRecognizeClient::StreamingRecognizeClient( bool word_time_offsets, bool automatic_punctuation, bool separate_recognition_per_channel, bool print_transcripts, int32_t chunk_duration_ms, bool interim_results, std::string output_filename, std::string model_name, bool simulate_realtime, - bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score) + bool verbatim_transcripts, const std::string& boosted_phrases_file, + float boosted_phrases_score, int32_t async_delay_ms) : print_latency_stats_(true), stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), @@ -66,7 +67,8 @@ StreamingRecognizeClient::StreamingRecognizeClient( print_transcripts_(print_transcripts), chunk_duration_ms_(chunk_duration_ms), interim_results_(interim_results), total_audio_processed_(0.), num_streams_started_(0), model_name_(model_name), simulate_realtime_(simulate_realtime), - verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score) + verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score), + async_delay_ms_(async_delay_ms) { num_active_streams_.store(0); num_streams_finished_.store(0); @@ -195,6 +197,7 @@ int StreamingRecognizeClient::DoStreamingFromFile( std::string& audio_file, int32_t num_iterations, int32_t num_parallel_requests) { + std::cout << "check async_delay_ms is : " << async_delay_ms_; // Preload all wav files, sort by size to reduce tail effects std::vector> all_wav; try { @@ -218,17 +221,32 @@ StreamingRecognizeClient::DoStreamingFromFile( } } + ; + // Ensure there's also num_parallel_requests in flight uint32_t all_wav_i = 0; auto start_time = std::chrono::steady_clock::now(); + + uint32_t initial_streams = 0; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(1, async_delay_ms_); + while (true) { while (NumActiveStreams() < (uint32_t)num_parallel_requests && all_wav_i < all_wav_max) { + if(async_delay_ms_>0){ + if(initial_streams < (uint32_t)num_parallel_requests) { + std::this_thread::sleep_for(std::chrono::milliseconds(dis(gen))); + initial_streams++; + } + } + std::unique_ptr stream(new Stream(all_wav_repeated[all_wav_i], all_wav_i)); StartNewStream(std::move(stream)); ++all_wav_i; } - - // Break if no more tasks to add + + // Break if no more tasks to add if (NumStreamsFinished() == all_wav_max) { break; } @@ -424,6 +442,8 @@ StreamingRecognizeClient::PrintLatencies(std::vector& latencies, const s std::cout << "\t\tMedian\t\t90th\t\t95th\t\t99th\t\tAvg\n"; std::cout << "\t\t" << median << "\t\t" << lat_90 << "\t\t" << lat_95 << "\t\t" << lat_99 << "\t\t" << avg << std::endl; + + //std::cout << "MKJ-check build"; } } @@ -444,4 +464,4 @@ StreamingRecognizeClient::PrintStats() << std::endl; return 1; } -} +} \ No newline at end of file diff --git a/riva/clients/asr/streaming_recognize_client.h b/riva/clients/asr/streaming_recognize_client.h index 14e7d17..b7c5d5a 100644 --- a/riva/clients/asr/streaming_recognize_client.h +++ b/riva/clients/asr/streaming_recognize_client.h @@ -25,7 +25,7 @@ #include #include #include - +#include #include "client_call.h" #include "riva/proto/riva_asr.grpc.pb.h" #include "riva/utils/thread_pool.h" @@ -47,7 +47,7 @@ class StreamingRecognizeClient { bool print_transcripts, int32_t chunk_duration_ms, bool interim_results, std::string output_filename, std::string model_name, bool simulate_realtime, bool verbatim_transcripts, const std::string& boosted_phrases_file, - float boosted_phrases_score); + float boosted_phrases_score, int32_t async_delay_ms); ~StreamingRecognizeClient(); @@ -114,4 +114,5 @@ class StreamingRecognizeClient { std::vector boosted_phrases_; float boosted_phrases_score_; -}; + int32_t async_delay_ms_; +}; \ No newline at end of file From beac208df8d3468189b46a7d1a28208de7d3aadb Mon Sep 17 00:00:00 2001 From: Manisha Johnson Date: Mon, 15 Apr 2024 15:58:45 +0530 Subject: [PATCH 2/4] include delay in asr clients --- riva/clients/asr/streaming_recognize_client.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/riva/clients/asr/streaming_recognize_client.cc b/riva/clients/asr/streaming_recognize_client.cc index 03d36c3..b513ded 100644 --- a/riva/clients/asr/streaming_recognize_client.cc +++ b/riva/clients/asr/streaming_recognize_client.cc @@ -442,8 +442,6 @@ StreamingRecognizeClient::PrintLatencies(std::vector& latencies, const s std::cout << "\t\tMedian\t\t90th\t\t95th\t\t99th\t\tAvg\n"; std::cout << "\t\t" << median << "\t\t" << lat_90 << "\t\t" << lat_95 << "\t\t" << lat_99 << "\t\t" << avg << std::endl; - - //std::cout << "MKJ-check build"; } } From 916a6993c1123f41de58fbe3929b56d20487ed4a Mon Sep 17 00:00:00 2001 From: Manisha J <167071960+manishaj-nv@users.noreply.github.com> Date: Mon, 15 Apr 2024 19:28:52 +0530 Subject: [PATCH 3/4] Update unit test --- riva/clients/asr/streaming_recognize_client_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/riva/clients/asr/streaming_recognize_client_test.cc b/riva/clients/asr/streaming_recognize_client_test.cc index 38aa185..51ecc3f 100644 --- a/riva/clients/asr/streaming_recognize_client_test.cc +++ b/riva/clients/asr/streaming_recognize_client_test.cc @@ -20,7 +20,7 @@ TEST(StreamingRecognizeClient, num_responses_requests) StreamingRecognizeClient recognize_client( grpc_channel, 1, "en-US", 1, false, false, false, false, false, 800, false, "dummy.txt", - "dummy", true, true, "", 10.); + "dummy", true, true, "", 10., 100); std::shared_ptr call = std::make_shared(1, true); uint32_t num_sends = 10; From 8e0dc7e20ec2624cb2a2b6946968cfac78117769 Mon Sep 17 00:00:00 2001 From: Manisha Johnson Date: Wed, 17 Apr 2024 11:47:50 +0530 Subject: [PATCH 4/4] include delay in asr clients --- riva/clients/asr/riva_streaming_asr_client.cc | 3 ++- riva/clients/asr/streaming_recognize_client.cc | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/riva/clients/asr/riva_streaming_asr_client.cc b/riva/clients/asr/riva_streaming_asr_client.cc index 62653f9..4815640 100644 --- a/riva/clients/asr/riva_streaming_asr_client.cc +++ b/riva/clients/asr/riva_streaming_asr_client.cc @@ -166,7 +166,8 @@ main(int argc, char** argv) FLAGS_profanity_filter, FLAGS_word_time_offsets, FLAGS_automatic_punctuation, /* separate_recognition_per_channel*/ false, FLAGS_print_transcripts, FLAGS_chunk_duration_ms, FLAGS_interim_results, FLAGS_output_filename, FLAGS_model_name, FLAGS_simulate_realtime, - FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score, FLAGS_async_delay_ms); + FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score, + FLAGS_async_delay_ms); if (FLAGS_audio_file.size()) { return recognize_client.DoStreamingFromFile( diff --git a/riva/clients/asr/streaming_recognize_client.cc b/riva/clients/asr/streaming_recognize_client.cc index b513ded..3117827 100644 --- a/riva/clients/asr/streaming_recognize_client.cc +++ b/riva/clients/asr/streaming_recognize_client.cc @@ -197,7 +197,6 @@ int StreamingRecognizeClient::DoStreamingFromFile( std::string& audio_file, int32_t num_iterations, int32_t num_parallel_requests) { - std::cout << "check async_delay_ms is : " << async_delay_ms_; // Preload all wav files, sort by size to reduce tail effects std::vector> all_wav; try {