Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 90 additions & 6 deletions riva/clients/tts/riva_tts_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ namespace nr = nvidia::riva;
namespace nr_tts = nvidia::riva::tts;

// Flags selecting the input text, the output audio file/encoding, and the server address.
DEFINE_string(text, "", "Text to be synthesized");
// main() logs an error and exits when both --text and --text_file are supplied, so the
// help text describes the two flags as mutually exclusive instead of claiming that
// --text_file is silently ignored.
DEFINE_string(
    text_file, "",
    "Text file with list of sentences to be synthesized. Cannot be combined with 'text'.");
DEFINE_string(audio_file, "output.wav", "Output file");
DEFINE_string(audio_encoding, "pcm", "Audio encoding (pcm or opus)");
DEFINE_string(riva_uri, "localhost:50051", "Riva API server URI and port");
Expand All @@ -37,6 +39,7 @@ DEFINE_string(ssl_client_key, "", "Path to SSL client certificates key");
// Client-side SSL certificate path; passed to CreateChannelCredentials in main().
DEFINE_string(ssl_client_cert, "", "Path to SSL client certificates file");
// Sample rate requested for the synthesized audio (defaults to 44100).
DEFINE_int32(rate, 44100, "Sample rate for the TTS output");
// Selects the SynthesizeOnline code paths in main() when true.
DEFINE_bool(online, false, "Whether synthesis should be online or batch");
// In the visible parts of main(), this flag is only consulted together with --online:
// online && !streaming writes a single request, while online && streaming writes one
// request per line read from --text_file.
// NOTE(review): behavior when --streaming is set without --online is not visible here — confirm.
DEFINE_bool(streaming, false, "Whether synthesis should be streaming or batch");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to use the existing `online` flag for this? Separate flags for input and output can be confusing for users. Do we see a case where we want to support non-streaming input with streaming output?

// Language of the text to synthesize, given as a BCP-47 tag; defaults to "en-US".
DEFINE_string(
language, "en-US",
"Language code as per [BCP-47](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) language tag.");
Expand Down Expand Up @@ -101,13 +104,15 @@ main(int argc, char** argv)
std::stringstream str_usage;
str_usage << "Usage: riva_tts_client " << std::endl;
str_usage << " --text=<text> " << std::endl;
str_usage << " --text_file=<filename> " << std::endl;
str_usage << " --audio_file=<filename> " << std::endl;
str_usage << " --audio_encoding=<pcm|opus> " << std::endl;
str_usage << " --riva_uri=<server_name:port> " << std::endl;
str_usage << " --rate=<sample_rate> " << std::endl;
str_usage << " --language=<language-code> " << std::endl;
str_usage << " --voice_name=<voice-name> " << std::endl;
str_usage << " --online=<true|false> " << std::endl;
str_usage << " --streaming=<true|false> " << std::endl;
str_usage << " --ssl_root_cert=<filename>" << std::endl;
str_usage << " --ssl_client_key=<filename>" << std::endl;
str_usage << " --ssl_client_cert=<filename>" << std::endl;
Expand All @@ -134,10 +139,26 @@ main(int argc, char** argv)
}

auto text = FLAGS_text;
if (text.length() == 0) {
LOG(ERROR) << "Input text cannot be empty." << std::endl;
auto text_file = FLAGS_text_file;
std::vector<std::string> text_lines;
if (text.length() == 0 && text_file.length() == 0) {
LOG(ERROR) << "Input text or text file cannot be empty." << std::endl;
return -1;
}
if (text.length() > 0 && text_file.length() > 0) {
LOG(ERROR) << "Only one of text or text file can be provided." << std::endl;
return -1;
}
if (text_file.length() > 0) {
std::ifstream infile(text_file);
if (infile.is_open()) {
std::string line;
while (std::getline(infile, line)) {
text_lines.push_back(line);
text += line + " ";
}
}
}

bool flag_set = gflags::GetCommandLineFlagInfoOrDie("riva_uri").is_default;
const char* riva_uri = getenv("RIVA_URI");
Expand All @@ -152,7 +173,8 @@ main(int argc, char** argv)
auto creds = riva::clients::CreateChannelCredentials(
FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert,
FLAGS_metadata);
grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
grpc_channel = riva::clients::CreateChannelBlocking(
FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
}
catch (const std::exception& e) {
std::cerr << "Error creating GRPC channel: " << e.what() << std::endl;
Expand Down Expand Up @@ -251,7 +273,7 @@ main(int argc, char** argv)
decoder.DeserializeOpus(std::vector<unsigned char>(ptr, ptr + audio.size())));
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm.data(), pcm.size());
}
} else { // online inference
} else if (FLAGS_online && not FLAGS_streaming) { // batch inference
if (not FLAGS_zero_shot_transcript.empty()) {
LOG(ERROR) << "Zero shot transcript is not supported for streaming inference.";
return -1;
Expand All @@ -261,8 +283,11 @@ main(int argc, char** argv)
size_t audio_len = 0;
nr_tts::SynthesizeSpeechResponse chunk;
auto start = std::chrono::steady_clock::now();
std::unique_ptr<grpc::ClientReader<nr_tts::SynthesizeSpeechResponse>> reader(
tts->SynthesizeOnline(&context, request));
std::unique_ptr<
grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
reader(tts->SynthesizeOnline(&context));
reader->Write(request);
reader->WritesDone();
while (reader->Read(&chunk)) {
// Copy chunk to local buffer
if (audio_len == 0) {
Expand Down Expand Up @@ -295,6 +320,65 @@ main(int argc, char** argv)
return -1;
}

if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size());
} else if (FLAGS_audio_encoding == "opus") {
riva::utils::opus::Decoder decoder(rate, 1);
auto pcm = decoder.DecodePcm(decoder.DeserializeOpus(opus_buffer));
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm.data(), pcm.size());
}
} else if (FLAGS_online && FLAGS_streaming) { // streaming inference

std::vector<int16_t> pcm_buffer;
std::vector<unsigned char> opus_buffer;
size_t audio_len = 0;
nr_tts::SynthesizeSpeechResponse chunk;
auto start = std::chrono::steady_clock::now();
std::unique_ptr<
grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
reader(tts->SynthesizeOnline(&context));
for (const auto& line : text_lines) {
if (line.find("|") != std::string::npos) {
request.set_text(line.substr(line.find("|") + 1, line.length()));
} else {
request.set_text(line);
}
reader->Write(request);
}
reader->WritesDone();
while (reader->Read(&chunk)) {
// Copy chunk to local buffer
if (audio_len == 0) {
auto t_first_audio = std::chrono::steady_clock::now();
std::chrono::duration<double> elapsed_first_audio = t_first_audio - start;
LOG(INFO) << "Time to first chunk: " << elapsed_first_audio.count() << " s" << std::endl;
}
LOG(INFO) << "Got chunk: " << chunk.audio().size() << " bytes" << std::endl;
if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
int16_t* audio_data = (int16_t*)chunk.audio().data();
size_t len = chunk.audio().length() / sizeof(int16_t);
std::copy(audio_data, audio_data + len, std::back_inserter(pcm_buffer));
audio_len += len;
} else if (FLAGS_audio_encoding == "opus") {
const unsigned char* opus_data = (unsigned char*)chunk.audio().data();
size_t len = chunk.audio().length();
std::copy(opus_data, opus_data + len, std::back_inserter(opus_buffer));
audio_len += len;
}
}
grpc::Status rpc_status = reader->Finish();
auto end = std::chrono::steady_clock::now();
std::chrono::duration<double> elapsed_total = end - start;
LOG(INFO) << "Total streaming time: " << elapsed_total.count() << " s" << std::endl;

if (!rpc_status.ok()) {
// Report the RPC failure.
LOG(ERROR) << rpc_status.error_message() << std::endl;
LOG(ERROR) << "Input was: " << text_lines.size() << " lines." << std::endl;
LOG(ERROR) << "Input was: \'" << text << "\'" << std::endl;
return -1;
}

if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size());
} else if (FLAGS_audio_encoding == "opus") {
Expand Down
Loading