diff --git a/.gitignore b/.gitignore index e4e5f6c8..60e4a961 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ -*~ \ No newline at end of file +*~ +/bazel-* +/lyra/model_coeffs/_models.h diff --git a/README.md b/README.md index c1c39945..168a4b79 100644 --- a/README.md +++ b/README.md @@ -1,354 +1,43 @@ -# Lyra: a generative low bitrate speech codec +# lyra as a DLL -## What is Lyra? +This is originally forked from [Lyra](https://github.com/google/lyra), then after some small changes from the TOFRevive team, I tweaked it to be a simple DLL. +Lyra has amazing quality-to-bitrate ratio and has some features that I don't expose (at present) such as automatically generating pleasant noise when packets are late. -[Lyra](https://ai.googleblog.com/2021/08/soundstream-end-to-end-neural-audio.html) -is a high-quality, low-bitrate speech codec that makes voice communication -available even on the slowest networks. To do this it applies traditional codec -techniques while leveraging advances in machine learning (ML) with models -trained on thousands of hours of data to create a novel method for compressing -and transmitting voice signals. +## Compile (on Windows) -### Overview +I find that Python and Bazel are not to my liking, so I built everything in a Windows VM using VirtualBox. The steps are relatively straightforward and reproducible. -The basic architecture of the Lyra codec is quite simple. Features are extracted -from speech every 20ms and are then compressed for transmission at a desired -bitrate between 3.2kbps and 9.2kbps. On the other end, a generative model uses -those features to recreate the speech signal. +Download and install: https://aka.ms/vs/17/release/vs_BuildTools.exe and select Desktop Development with C++ +Download and install: https://github.com/git-for-windows/git/releases/download/v2.46.0.windows.1/Git-2.46.0-64-bit.exe just take all defaults. 
+Download the Windows exe and rename it to bazel.exe: https://github.com/bazelbuild/bazel/releases/tag/5.3.2 +Download and install: https://www.python.org/ftp/python/3.12.6/python-3.12.6-amd64.exe -Lyra harnesses the power of new natural-sounding generative models to maintain -the low bitrate of parametric codecs while achieving high quality, on par with -state-of-the-art waveform codecs used in most streaming and communication -platforms today. +Open a GIT BASH command window: + git clone https://github.com/google/lyra.git + cd lyra + export PATH=$PATH:/c/Users/User/AppData/Local/Programs/Python/Python312 + python -m pip install setuptools numpy six + python ./models_to_header.py + bazel.exe build -c opt --config=windows --action_env=PYTHON_BIN_PATH="/c/Users/User/AppData/Local/Programs/Python/Python312/python.exe" dll:lyra_dll -Computational complexity is reduced by using a cheaper convolutional generative -model called SoundStream, which enables Lyra to not only run on cloud servers, -but also on-device on low-end phones in real time (with a processing latency of -20ms). This whole system is then trained end-to-end on thousands of hours of -speech data with speakers in over 90 languages and optimized to accurately -recreate the input audio. +You can replace `-c opt` with `-c dbg` to build in debug mode (with asserts enabled). -Lyra is supported on Android, Linux, Mac and Windows. +Final file is at `./bazel-bin/dll/lyra_dll.dll` -## Prerequisites +## Compiling CLI examples -There are a few things you'll need to do to set up your computer to build Lyra. - -### Common setup - -Lyra is built using Google's build system, Bazel. Install it following these -[instructions](https://docs.bazel.build/versions/master/install.html). Bazel -verson 5.0.0 is required, and some Linux distributions may make an older version -available in their application repositories, so make sure you are using the -required version or newer. 
The latest version can be downloaded via -[Github](https://github.com/bazelbuild/bazel/releases). - -You will also need python3 and numpy installed. - -Lyra can be built from Linux using Bazel for an ARM Android target, or a Linux -target, as well as Mac and Windows for native targets. - -### Android requirements - -Building on android requires downloading a specific version of the android NDK -toolchain. If you develop with Android Studio already, you might not need to do -these steps if ANDROID_HOME and ANDROID_NDK_HOME are defined and pointing at the -right version of the NDK. - -1. Download command line tools from https://developer.android.com/studio -2. Unzip and cd to the directory -3. Check the available packages to install in case they don't match the - following steps. - - ```shell - bin/sdkmanager --sdk_root=$HOME/android/sdk --list - ``` - - Some systems will already have the java runtime set up. But if you see an - error here like `ERROR: JAVA_HOME is not set and no 'java' command could be - found on your PATH.`, this means you need to install the java runtime with - `sudo apt install default-jdk` first. You will also need to add `export - JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64` (type `ls /usr/lib/jvm` to see - which path was installed) to your $HOME/.bashrc and reload it with `source - $HOME/.bashrc`. - -4. Install the r21 ndk, android sdk 30, and build tools: - - ```shell - bin/sdkmanager --sdk_root=$HOME/android/sdk --install "platforms;android-30" "build-tools;30.0.3" "ndk;21.4.7075529" - ``` - -5. Add the following to .bashrc (or export the variables) - - ```shell - export ANDROID_NDK_HOME=$HOME/android/sdk/ndk/21.4.7075529 - export ANDROID_HOME=$HOME/android/sdk - ``` - -6. Reload .bashrc (with `source $HOME/.bashrc`) - -## Building - -The building and running process differs slightly depending on the selected -platform. - -### Building for Linux - -You can build the cc_binaries with the default config. 
`encoder_main` is an -example of a file encoder. - -```shell -bazel build -c opt lyra/cli_example:encoder_main -``` - -You can run `encoder_main` to encode a test .wav file with some speech in it, -specified by `--input_path`. The `--output_dir` specifies where to write the -encoded (compressed) representation, and the desired bitrate can be specified -using the `--bitrate` flag. - -```shell -bazel-bin/lyra/cli_example/encoder_main --input_path=lyra/testdata/sample1_16kHz.wav --output_dir=$HOME/temp --bitrate=3200 -``` - -Similarly, you can build decoder_main and use it on the output of encoder_main -to decode the encoded data back into speech. - -```shell -bazel build -c opt lyra/cli_example:decoder_main -bazel-bin/lyra/cli_example/decoder_main --encoded_path=$HOME/temp/sample1_16kHz.lyra --output_dir=$HOME/temp/ --bitrate=3200 -``` - -Note: the default Bazel toolchain is automatically configured and likely uses -gcc/libstdc++ on Linux. This should be satisfactory for most users, but will -differ from the NDK toolchain, which uses clang/libc++. To use a custom clang -toolchain on Linux, see toolchain/README.md and .bazelrc. - -### Building for Android - -#### Android App - -There is an example APK target called `lyra_android_example` that you can build -after you have set up the NDK. - -This example is an app with a minimal GUI that has buttons for two options. One -option is to record from the microphone and encode/decode with Lyra so you can -test what Lyra would sound like for your voice. The other option runs a -benchmark that encodes and decodes in the background and prints the timings to -logcat. - -```shell -bazel build -c opt lyra/android_example:lyra_android_example --config=android_arm64 --copt=-DBENCHMARK -adb install bazel-bin/lyra/android_example/lyra_android_example.apk -``` - -After this you should see an app called "Lyra Example App". - -You can open it, and you will see a simple TextView that says the benchmark is -running, and when it finishes. 
- -Press "Record from microphone", say a few words, and then press "Encode and -decode to speaker". You should hear your voice being played back after being -coded with Lyra. - -If you press 'Benchmark', you should see something like the following in logcat -on a Pixel 6 Pro when running the benchmark: - -```shell -lyra_benchmark: feature_extractor: max: 1.836 ms min: 0.132 ms mean: 0.153 ms stdev: 0.042 ms -lyra_benchmark: quantizer_quantize: max: 1.042 ms min: 0.120 ms mean: 0.130 ms stdev: 0.028 ms -lyra_benchmark: quantizer_decode: max: 0.103 ms min: 0.026 ms mean: 0.029 ms stdev: 0.003 ms -lyra_benchmark: model_decode: max: 0.820 ms min: 0.191 ms mean: 0.212 ms stdev: 0.031 ms -lyra_benchmark: total: max: 2.536 ms min: 0.471 ms mean: 0.525 ms stdev: 0.088 ms -``` - -This shows that decoding a 50Hz frame (each frame is 20 milliseconds) takes -0.525 milliseconds on average. So decoding is performed at around 38 (20/0.525) -times faster than realtime. - -To build your own android app, you can either use the cc_library target outputs -to create a .so that you can use in your own build system. Or you can use it -with an -[`android_binary`](https://docs.bazel.build/versions/master/be/android.html) -rule within bazel to create an .apk file as in this example. - -There is a tutorial on building for android with Bazel in the -[bazel docs](https://docs.bazel.build/versions/master/android-ndk.html). - -#### Android command-line binaries - -There are also the binary targets that you can use to experiment with encoding -and decoding .wav files. 
- -You can build the example cc_binary targets with: - -```shell -bazel build -c opt lyra/cli_example:encoder_main --config=android_arm64 -bazel build -c opt lyra/cli_example:decoder_main --config=android_arm64 +Lyra comes with two sample CLI programs to test out encode/decode, you can compile and use them as follows: ``` +bazel build -c opt --action_env PYTHON_BIN_PATH="C:\\Python311\\python.exe" lyra/cli_example:encoder_main -This builds an executable binary that can be run on android 64-bit arm devices -(not an android app). You can then push it to your android device and run it as -a binary through the shell. - -```shell -# Push the binary and the data it needs, including the model and .wav files: -adb push bazel-bin/lyra/cli_example/encoder_main /data/local/tmp/ -adb push bazel-bin/lyra/cli_example/decoder_main /data/local/tmp/ -adb push lyra/model_coeffs/ /data/local/tmp/ -adb push lyra/testdata/ /data/local/tmp/ - -adb shell -cd /data/local/tmp -./encoder_main --model_path=/data/local/tmp/model_coeffs --output_dir=/data/local/tmp --input_path=testdata/sample1_16kHz.wav -./decoder_main --model_path=/data/local/tmp/model_coeffs --output_dir=/data/local/tmp --encoded_path=sample1_16kHz.lyra +bazel-bin/lyra/cli_example/encoder_main.exe --input_path=lyra/testdata/sample1_16kHz.wav --output_dir=%temp% --bitrate=9200 +bazel-bin/lyra/cli_example/decoder_main.exe --encoded_path=%temp%/sample1_16kHz.lyra --output_dir=%temp%/ --bitrate=9200 ``` -The encoder_main/decoder_main as above should also work. - -### Building for Mac - -You will need to install the XCode command line tools in addition to the -prerequisites common to all platforms. XCode setup is a required step for using -Bazel on Mac. See this [guide](https://bazel.build/install/os-x) for how to -install XCode command line tools. Lyra has been built successfully using XCode -13.3. - -You can follow the instructions in the [Building for Linux](#building-for-linux) -section once this is completed. 
- -### Building for Windows - -You will need to install Build Tools for Visual Studio 2019 in addition to the -prerequisites common to all platforms. Visual Studio setup is a required step -for building C++ for Bazel on Windows. See this -[guide](https://bazel.build/install/windows) for how to install MSVC. You may -also need to install python 3 support, which is also described in the guide. - -You can follow the instructions in the [Building for Linux](#building-for-linux) -section once this is completed. - -## API - -For integrating Lyra into any project only two APIs are relevant: -[LyraEncoder](lyra/lyra_encoder.h) and [LyraDecoder](lyra/lyra_decoder.h). - -> DISCLAIMER: At this time Lyra's API and bit-stream are **not** guaranteed to -> be stable and might change in future versions of the code. - -On the sending side, `LyraEncoder` can be used to encode an audio stream using -the following interface: - -```cpp -class LyraEncoder : public LyraEncoderInterface { - public: - static std::unique_ptr Create( - int sample_rate_hz, int num_channels, int bitrate, bool enable_dtx, - const ghc::filesystem::path& model_path); - - std::optional> Encode( - const absl::Span audio) override; - - bool set_bitrate(int bitrate) override; - - int sample_rate_hz() const override; - - int num_channels() const override; +## Running unit tests - int bitrate() const override; - - int frame_rate() const override; -}; +Running the built-in unit tests of Lyra might be useful, as we did some changes to statically build model files into the binaries themselves. So you can run them with: ``` - -The static `Create` method instantiates a `LyraEncoder` with the desired sample -rate in Hertz, number of channels and bitrate, as long as those parameters are -supported (see `lyra_encoder.h` for supported parameters). Otherwise it returns -a nullptr. The `Create` method also needs to know if DTX should be enabled and -where the model weights are stored. 
It also checks that these weights exist and -are compatible with the current Lyra version. - -Given a `LyraEncoder`, any audio stream can be compressed using the `Encode` -method. The provided span of int16-formatted samples is assumed to contain 20ms -of data at the sample rate chosen at `Create` time. As long as this condition is -met the `Encode` method returns the encoded packet as a vector of bytes that is -ready to be stored or transmitted over the network. - -The bitrate can be dynamically modified using the `set_bitrate` setter. It -returns true if the desired bitrate is supported and correctly set. - -The rest of the `LyraEncoder` methods are just getters for the different -predetermined parameters. - -On the receiving end, `LyraDecoder` can be used to decode the encoded packet -using the following interface: - -```cpp -class LyraDecoder : public LyraDecoderInterface { - public: - static std::unique_ptr Create( - int sample_rate_hz, int num_channels, - const ghc::filesystem::path& model_path); - - bool SetEncodedPacket(absl::Span encoded) override; - - std::optional> DecodeSamples(int num_samples) override; - - int sample_rate_hz() const override; - - int num_channels() const override; - - int frame_rate() const override; - - bool is_comfort_noise() const override; -}; +bazel test --action_env PYTHON_BIN_PATH="C:\\Python311\\python.exe" //lyra:all ``` -Once again, the static `Create` method instantiates a `LyraDecoder` with the -desired sample rate in Hertz and number of channels, as long as those parameters -are supported. Else it returns a `nullptr`. These parameters don't need to be -the same as the ones in `LyraEncoder`. And once again, the `Create` method also -needs to know where the model weights are stored. It also checks that these -weights exist and are compatible with the current Lyra version. 
- -Given a `LyraDecoder`, any packet can be decoded by first feeding it into -`SetEncodedPacket`, which returns true if the provided span of bytes is a valid -Lyra-encoded packet. - -Then the int16-formatted samples can be obtained by calling `DecodeSamples`. If -there isn't a packet available, but samples still need to be generated, the -decoder might switch to a comfort noise generation mode, which can be checked -using `is_comfort_noise`. - -The rest of the `LyraDecoder` methods are just getters for the different -predetermined parameters. - -For an example on how to use `LyraEncoder` and `LyraDecoder` to encode and -decode a stream of audio, please refer to the -[integration test](lyra/lyra_integration_test.cc). - -## License - -Use of this source code is governed by a Apache v2.0 license that can be found -in the LICENSE file. - -## Papers - -1. Kleijn, W. B., Lim, F. S., Luebs, A., Skoglund, J., Stimberg, F., Wang, Q., - & Walters, T. C. (2018, April). - [Wavenet based low rate speech coding](https://arxiv.org/pdf/1712.01120). In - 2018 IEEE international conference on acoustics, speech and signal - processing (ICASSP) (pp. 676-680). IEEE. -2. Denton, T., Luebs, A., Chinen, M., Lim, F. S., Storus, A., Yeh, H., Kleijn, - W. B., & Skoglund, J. (2020, November). - [Handling Background Noise in Neural Speech Generation](https://arxiv.org/pdf/2102.11906). - In 2020 54th Asilomar Conference on Signals, Systems, and Computers (pp. - 667-671). IEEE. -3. Kleijn, W. B., Storus, A., Chinen, M., Denton, T., Lim, F. S., Luebs, A., - Skoglund, J., & Yeh, H. (2021, June). - [Generative speech coding with predictive variance regularization](https://arxiv.org/pdf/2102.09660). - In ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and - Signal Processing (ICASSP) (pp. 6478-6482). IEEE. -4. Zeghidour, N., Luebs, A., Omran, A., Skoglund, J., & Tagliasacchi, M. - (2021). - [SoundStream: An end-to-end neural audio codec](https://arxiv.org/pdf/2107.03312). 
- IEEE/ACM Transactions on Audio, Speech, and Language Processing. diff --git a/WORKSPACE b/WORKSPACE index 388661cb..f938a5c3 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -163,6 +163,16 @@ maven_install( ) +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "gflags", + urls = ["https://github.com/gflags/gflags/archive/refs/tags/v2.2.2.tar.gz"], + strip_prefix = "gflags-2.2.2", + sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf", +) + + # Begin Tensorflow WORKSPACE subset required for TFLite git_repository( diff --git a/dll/BUILD b/dll/BUILD new file mode 100644 index 00000000..361e59f9 --- /dev/null +++ b/dll/BUILD @@ -0,0 +1,33 @@ +cc_binary( + name = "lyra_dll", + srcs = [ + "dllmain.cc", + ], + #data = [":tflite_testdata"], + linkopts = select({ + "//lyra:android_config": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + "//lyra:lyra_config", + "//lyra:lyra_encoder", + "//lyra:lyra_decoder", + "//lyra:wav_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/flags:usage", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_glog//:glog", + "@gulrak_filesystem//:filesystem", + ], + linkshared = 1, + copts = ["/DCOMPILING_DLL"], + target_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], +) \ No newline at end of file diff --git a/dll/dllmain.cc b/dll/dllmain.cc new file mode 100644 index 00000000..da9755a8 --- /dev/null +++ b/dll/dllmain.cc @@ -0,0 +1,96 @@ +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#include "Windows.h" + +#include "lyra/lyra_config.h" +#include "lyra/lyra_encoder.h" +#include "lyra/lyra_decoder.h" +#include "lyra/model_coeffs/_models.h" + +#define BYTES_PER_SAMPLE 2 + +BOOL APIENTRY DllMain(HMODULE hModule, 
DWORD ul_reason_for_call, LPVOID lpReserved)
+{
+    switch (ul_reason_for_call)
+    {
+    case DLL_PROCESS_ATTACH:
+    case DLL_THREAD_ATTACH:
+    case DLL_THREAD_DETACH:
+    case DLL_PROCESS_DETACH:
+        break;
+    }
+    return TRUE;
+}
+
+std::unique_ptr<chromemedia::codec::LyraEncoder> m_Encoder = nullptr;
+std::unique_ptr<chromemedia::codec::LyraDecoder> m_Decoder = nullptr;
+
+extern "C" __declspec(dllexport) bool Initialize()
+{
+    const int samplerate = 16000;
+    const int bitrate = 3200;
+
+    const chromemedia::codec::LyraModels models = GetEmbeddedLyraModels();
+
+    if (!m_Encoder)
+    {
+        m_Encoder = chromemedia::codec::LyraEncoder::Create(samplerate, 1, bitrate, false, models);
+    }
+    if (!m_Decoder)
+    {
+        m_Decoder = chromemedia::codec::LyraDecoder::Create(samplerate, 1, models);
+    }
+    return m_Encoder != nullptr && m_Decoder != nullptr;
+}
+
+extern "C" __declspec(dllexport) void Shutdown()
+{
+    m_Encoder.reset();
+    m_Decoder.reset();
+}
+
+extern "C" __declspec(dllexport) void Encode(const int16_t* uncompressed, size_t uncompressed_size, uint8_t* compressed, size_t compressed_size)
+{
+    const int num_samples_per_packet = m_Encoder->sample_rate_hz() / m_Encoder->frame_rate();
+    const int raw_frame_size = num_samples_per_packet * BYTES_PER_SAMPLE;
+
+    assert(uncompressed_size >= num_samples_per_packet);
+
+    std::vector<int16_t> uncompressed_vector(uncompressed, uncompressed + num_samples_per_packet);
+    std::optional<std::vector<uint8_t>> encoded = m_Encoder->Encode(uncompressed_vector);
+
+    if (!encoded.has_value())
+    {
+        return;
+    }
+
+    assert(encoded->size() == chromemedia::codec::BitrateToPacketSize(m_Encoder->bitrate()));
+    assert(compressed_size >= encoded->size());
+
+    memcpy_s(compressed, compressed_size, encoded->data(), encoded->size());
+}
+
+extern "C" __declspec(dllexport) void Decode(const int8_t* compressed, size_t compressed_size, uint16_t* uncompressed, size_t uncompressed_size)
+{
+    const int num_samples_per_packet = m_Encoder->sample_rate_hz() / m_Encoder->frame_rate();
+    const int packet_size =
chromemedia::codec::BitrateToPacketSize(m_Encoder->bitrate());
+
+    assert(compressed_size == packet_size);
+
+    bool valid = m_Decoder->SetEncodedPacket(absl::MakeSpan(reinterpret_cast<const uint8_t*>(compressed), compressed_size));
+    assert(valid == true);
+    if (!valid) return;
+
+    std::optional<std::vector<int16_t>> decoded = m_Decoder->DecodeSamples(num_samples_per_packet);
+    if (!decoded.has_value())
+    {
+        assert(decoded.has_value());
+        return;
+    }
+
+    assert(decoded->size() == num_samples_per_packet);
+    assert(uncompressed_size >= decoded->size());
+
+    memcpy_s(uncompressed, uncompressed_size * sizeof(uint16_t), decoded->data(), decoded->size() * sizeof(uint16_t));
+}
diff --git a/lyra/BUILD b/lyra/BUILD
index a64673ce..3c7b9f28 100644
--- a/lyra/BUILD
+++ b/lyra/BUILD
@@ -32,18 +32,11 @@ config_setting(
     values = {"crosstool_top": "//external:android/crosstool"},
 )
 
-cc_library(
-    name = "architecture_utils",
-    hdrs = ["architecture_utils.h"],
-    deps = ["@gulrak_filesystem//:filesystem"],
-)
-
 cc_library(
     name = "lyra_benchmark_lib",
     srcs = ["lyra_benchmark_lib.cc"],
     hdrs = ["lyra_benchmark_lib.h"],
     deps = [
-        ":architecture_utils",
         ":dsp_utils",
         ":feature_extractor_interface",
         ":generative_model_interface",
@@ -771,6 +764,7 @@ cc_library(
     ],
     hdrs = [
         "tflite_model_wrapper.h",
+        "lyra_embedded_models.h",
     ],
     deps = [
         "@com_google_absl//absl/memory",
@@ -814,7 +808,10 @@ cc_test(
 
 cc_test(
     name = "tflite_model_wrapper_test",
-    srcs = ["tflite_model_wrapper_test.cc"],
+    srcs = [
+        "tflite_model_wrapper_test.cc",
+        "model_coeffs/_models.h",
+    ],
     data = ["model_coeffs/lyragan.tflite"],
     deps = [
         ":tflite_model_wrapper",
diff --git a/lyra/architecture_utils.h b/lyra/architecture_utils.h
deleted file mode 100644
index c4bddbc5..00000000
--- a/lyra/architecture_utils.h
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Copyright 2021 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LYRA_ARCHITECTURE_UTILS_H_ -#define LYRA_ARCHITECTURE_UTILS_H_ - -// Placeholder for get runfiles header. -#include "include/ghc/filesystem.hpp" - -namespace chromemedia { -namespace codec { - -ghc::filesystem::path GetCompleteArchitecturePath( - const ghc::filesystem::path& model_path) { - return model_path; -} - -} // namespace codec -} // namespace chromemedia - -#endif // LYRA_ARCHITECTURE_UTILS_H_ diff --git a/lyra/cli_example/BUILD b/lyra/cli_example/BUILD index 8c4a9c3b..f7ff3a95 100644 --- a/lyra/cli_example/BUILD +++ b/lyra/cli_example/BUILD @@ -117,7 +117,6 @@ cc_binary( }), deps = [ ":encoder_main_lib", - "//lyra:architecture_utils", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/flags:usage", @@ -139,7 +138,6 @@ cc_binary( }), deps = [ ":decoder_main_lib", - "//lyra:architecture_utils", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/flags:usage", diff --git a/lyra/cli_example/decoder_main.cc b/lyra/cli_example/decoder_main.cc index f29477ed..559ace88 100644 --- a/lyra/cli_example/decoder_main.cc +++ b/lyra/cli_example/decoder_main.cc @@ -22,9 +22,10 @@ #include "absl/strings/string_view.h" #include "glog/logging.h" // IWYU pragma: keep #include "include/ghc/filesystem.hpp" -#include "lyra/architecture_utils.h" #include "lyra/cli_example/decoder_main_lib.h" +#include "lyra/model_coeffs/_models.h" + ABSL_FLAG(std::string, encoded_path, "", "Complete path to the file containing the encoded features."); 
ABSL_FLAG(std::string, output_dir, "", @@ -50,11 +51,6 @@ ABSL_FLAG(chromemedia::codec::PacketLossPattern, fixed_packet_loss_pattern, "bursts will be rounded up to the nearest packet duration boundary. " "If this flag contains a nonzero number of values we ignore " "|packet_loss_rate| and |average_burst_length|."); -ABSL_FLAG(std::string, model_path, "lyra/model_coeffs", - "Path to directory containing TFLite files. For mobile this is the " - "absolute path, like " - "'/data/local/tmp/lyra/model_coeffs/'." - " For desktop this is the path relative to the binary."); int main(int argc, char** argv) { absl::SetProgramUsageMessage(argv[0]); @@ -71,9 +67,7 @@ int main(int argc, char** argv) { const float average_burst_length = absl::GetFlag(FLAGS_average_burst_length); const chromemedia::codec::PacketLossPattern fixed_packet_loss_pattern = absl::GetFlag(FLAGS_fixed_packet_loss_pattern); - const ghc::filesystem::path model_path = - chromemedia::codec::GetCompleteArchitecturePath( - absl::GetFlag(FLAGS_model_path)); + const chromemedia::codec::LyraModels models = GetEmbeddedLyraModels(); if (!fixed_packet_loss_pattern.starts_.empty()) { LOG(INFO) << "Using fixed packet loss pattern instead of gilbert model."; } @@ -102,7 +96,7 @@ int main(int argc, char** argv) { if (!chromemedia::codec::DecodeFile(encoded_path, output_path, sample_rate_hz, bitrate, randomize_num_samples_requested, packet_loss_rate, average_burst_length, - fixed_packet_loss_pattern, model_path)) { + fixed_packet_loss_pattern, models)) { LOG(ERROR) << "Could not decode " << encoded_path; return -1; } diff --git a/lyra/cli_example/decoder_main_lib.cc b/lyra/cli_example/decoder_main_lib.cc index 61974393..39a97e91 100644 --- a/lyra/cli_example/decoder_main_lib.cc +++ b/lyra/cli_example/decoder_main_lib.cc @@ -146,8 +146,8 @@ bool DecodeFile(const ghc::filesystem::path& encoded_path, int bitrate, bool randomize_num_samples_requested, float packet_loss_rate, float average_burst_length, const PacketLossPattern& 
fixed_packet_loss_pattern, - const ghc::filesystem::path& model_path) { - auto decoder = LyraDecoder::Create(sample_rate_hz, kNumChannels, model_path); + const LyraModels& models) { + auto decoder = LyraDecoder::Create(sample_rate_hz, kNumChannels, models); if (decoder == nullptr) { LOG(ERROR) << "Could not create lyra decoder."; return false; diff --git a/lyra/cli_example/decoder_main_lib.h b/lyra/cli_example/decoder_main_lib.h index 81f4c1c0..562e5080 100644 --- a/lyra/cli_example/decoder_main_lib.h +++ b/lyra/cli_example/decoder_main_lib.h @@ -55,7 +55,6 @@ bool DecodeFeatures(const std::vector& packet_stream, int packet_size, std::vector* decoded_audio); // Decodes an encoded features file into a wav file. -// Uses the model and quant files located under |model_path|. // Given the file /tmp/lyra/file1.lyra exists and is a valid encoded file. For: // |encoded_path| = "/tmp/lyra/file1.lyra" // |output_path| = "/tmp/lyra/file1_decoded.lyra" @@ -66,7 +65,7 @@ bool DecodeFile(const ghc::filesystem::path& encoded_path, int bitrate, bool randomize_num_samples_requested, float packet_loss_rate, float average_burst_length, const PacketLossPattern& fixed_packet_loss_pattern, - const ghc::filesystem::path& model_path); + const chromemedia::codec::LyraModels& models); } // namespace codec } // namespace chromemedia diff --git a/lyra/cli_example/decoder_main_lib_test.cc b/lyra/cli_example/decoder_main_lib_test.cc index cf662eb9..99abc785 100644 --- a/lyra/cli_example/decoder_main_lib_test.cc +++ b/lyra/cli_example/decoder_main_lib_test.cc @@ -40,7 +40,7 @@ class DecoderMainLibTest : public testing::TestWithParam { DecoderMainLibTest() : output_dir_(ghc::filesystem::path(testing::TempDir()) / "output/"), testdata_dir_(ghc::filesystem::current_path() / kTestdataDir), - model_path_(ghc::filesystem::current_path() / kExportedModelPath), + models_(GetEmbeddedLyraModels()), sample_rate_hz_(GetParam()), num_samples_in_packet_(GetNumSamplesPerHop(sample_rate_hz_)) {} @@ -67,7 +67,7 
@@ class DecoderMainLibTest : public testing::TestWithParam { const ghc::filesystem::path output_dir_; const ghc::filesystem::path testdata_dir_; - const ghc::filesystem::path model_path_; + const LyraModels models_; ghc::filesystem::path input_path_; ghc::filesystem::path output_path_; const int sample_rate_hz_; @@ -80,7 +80,7 @@ TEST_P(DecoderMainLibTest, NoEncodedPacket) { input_path_, output_path_, sample_rate_hz_, /*bitrate=*/3200, /*randomize_num_samples_requested=*/false, /*packet_loss_rate=*/0.f, - /*average_burst_length=*/1.f, PacketLossPattern({}, {}), model_path_)); + /*average_burst_length=*/1.f, PacketLossPattern({}, {}), models_)); } TEST_P(DecoderMainLibTest, OneEncodedPacket) { @@ -89,7 +89,7 @@ TEST_P(DecoderMainLibTest, OneEncodedPacket) { input_path_, output_path_, sample_rate_hz_, /*bitrate=*/6000, /*randomize_num_samples_requested=*/false, /*packet_loss_rate=*/0.f, - /*average_burst_length=*/1.f, PacketLossPattern({}, {}), model_path_)); + /*average_burst_length=*/1.f, PacketLossPattern({}, {}), models_)); EXPECT_EQ(NumSamplesInWavFile(output_path_), num_samples_in_packet_); } @@ -100,7 +100,7 @@ TEST_P(DecoderMainLibTest, RandomizeSampleRequests) { input_path_, output_path_, sample_rate_hz_, /*bitrate=*/6000, /*randomize_num_samples_requested=*/true, /*packet_loss_rate=*/0.f, - /*average_burst_length=*/1.f, PacketLossPattern({}, {}), model_path_)); + /*average_burst_length=*/1.f, PacketLossPattern({}, {}), models_)); EXPECT_EQ(NumSamplesInWavFile(output_path_), num_samples_in_packet_); } @@ -111,7 +111,7 @@ TEST_P(DecoderMainLibTest, FileDoesNotExist) { input_path_, output_path_, sample_rate_hz_, /*bitrate=*/6000, /*randomize_num_samples_requested=*/false, /*packet_loss_rate=*/0.f, - /*average_burst_length=*/1.f, PacketLossPattern({}, {}), model_path_)); + /*average_burst_length=*/1.f, PacketLossPattern({}, {}), models_)); } // Tests an encoded features file with less than 1 packet's worth of data. 
@@ -122,7 +122,7 @@ TEST_P(DecoderMainLibTest, IncompleteEncodedPacket) { input_path_, output_path_, sample_rate_hz_, /*bitrate=*/6000, /*randomize_num_samples_requested=*/false, /*packet_loss_rate=*/0.f, - /*average_burst_length=*/1.f, PacketLossPattern({}, {}), model_path_)); + /*average_burst_length=*/1.f, PacketLossPattern({}, {}), models_)); } TEST_P(DecoderMainLibTest, TwoEncodedPacketsWithPacketLoss) { @@ -133,14 +133,14 @@ TEST_P(DecoderMainLibTest, TwoEncodedPacketsWithPacketLoss) { input_path_, output_path_, sample_rate_hz_, /*bitrate=*/6000, /*randomize_num_samples_requested=*/false, /*packet_loss_rate=*/0.5f, - /*average_burst_length=*/2.f, PacketLossPattern({}, {}), model_path_)); + /*average_burst_length=*/2.f, PacketLossPattern({}, {}), models_)); EXPECT_EQ(NumSamplesInWavFile(output_path_), expected_num_samples); EXPECT_TRUE(DecodeFile( input_path_, output_path_, sample_rate_hz_, /*bitrate=*/6000, /*randomize_num_samples_requested=*/false, /*packet_loss_rate=*/0.9f, - /*average_burst_length=*/10.f, PacketLossPattern({}, {}), model_path_)); + /*average_burst_length=*/10.f, PacketLossPattern({}, {}), models_)); EXPECT_EQ(NumSamplesInWavFile(output_path_), expected_num_samples); } @@ -152,7 +152,7 @@ TEST_P(DecoderMainLibTest, TwoEncodedPacketsWithFixedPacketLoss) { input_path_, output_path_, sample_rate_hz_, /*bitrate=*/6000, /*randomize_num_samples_requested=*/false, /*packet_loss_rate=*/0.9f, - /*average_burst_length=*/10.f, PacketLossPattern({1}, {0}), model_path_)); + /*average_burst_length=*/10.f, PacketLossPattern({1}, {0}), models_)); EXPECT_EQ(NumSamplesInWavFile(output_path_), expected_num_samples); EXPECT_TRUE(DecodeFile(input_path_, output_path_, sample_rate_hz_, @@ -160,7 +160,7 @@ TEST_P(DecoderMainLibTest, TwoEncodedPacketsWithFixedPacketLoss) { /*randomize_num_samples_requested=*/false, /*packet_loss_rate=*/0.9f, /*average_burst_length=*/10.f, - PacketLossPattern({0}, {100}), model_path_)); + PacketLossPattern({0}, {100}), models_)); 
EXPECT_EQ(NumSamplesInWavFile(output_path_), expected_num_samples); } diff --git a/lyra/cli_example/encoder_main.cc b/lyra/cli_example/encoder_main.cc index d395f156..48edf728 100644 --- a/lyra/cli_example/encoder_main.cc +++ b/lyra/cli_example/encoder_main.cc @@ -21,9 +21,10 @@ #include "absl/strings/string_view.h" #include "glog/logging.h" // IWYU pragma: keep #include "include/ghc/filesystem.hpp" -#include "lyra/architecture_utils.h" #include "lyra/cli_example/encoder_main_lib.h" +#include "lyra/model_coeffs/_models.h" + ABSL_FLAG(std::string, input_path, "", "Complete path to the WAV file to be encoded."); ABSL_FLAG(std::string, output_dir, "", @@ -40,11 +41,6 @@ ABSL_FLAG(bool, enable_preprocessing, false, ABSL_FLAG(bool, enable_dtx, false, "Enables discontinuous transmission (DTX). DTX does not send packets " "when noise is detected."); -ABSL_FLAG(std::string, model_path, "lyra/model_coeffs", - "Path to directory containing TFLite files. For mobile this is the " - "absolute path, like " - "'/data/local/tmp/lyra/model_coeffs/'." 
- " For desktop this is the path relative to the binary."); int main(int argc, char** argv) { absl::SetProgramUsageMessage(argv[0]); @@ -52,12 +48,10 @@ int main(int argc, char** argv) { const ghc::filesystem::path input_path(absl::GetFlag(FLAGS_input_path)); const ghc::filesystem::path output_dir(absl::GetFlag(FLAGS_output_dir)); - const ghc::filesystem::path model_path = - chromemedia::codec::GetCompleteArchitecturePath( - absl::GetFlag(FLAGS_model_path)); const int bitrate = absl::GetFlag(FLAGS_bitrate); const bool enable_preprocessing = absl::GetFlag(FLAGS_enable_preprocessing); const bool enable_dtx = absl::GetFlag(FLAGS_enable_dtx); + const chromemedia::codec::LyraModels models = GetEmbeddedLyraModels(); if (input_path.empty()) { LOG(ERROR) << "Flag --input_path not set."; @@ -82,7 +76,7 @@ int main(int argc, char** argv) { if (!chromemedia::codec::EncodeFile(input_path, output_path, bitrate, enable_preprocessing, enable_dtx, - model_path)) { + models)) { LOG(ERROR) << "Failed to encode " << input_path; return -1; } diff --git a/lyra/cli_example/encoder_main_lib.cc b/lyra/cli_example/encoder_main_lib.cc index 30ee0348..706ca540 100644 --- a/lyra/cli_example/encoder_main_lib.cc +++ b/lyra/cli_example/encoder_main_lib.cc @@ -42,13 +42,13 @@ namespace codec { // starting at index 0. 
bool EncodeWav(const std::vector& wav_data, int num_channels, int sample_rate_hz, int bitrate, bool enable_preprocessing, - bool enable_dtx, const ghc::filesystem::path& model_path, + bool enable_dtx, const chromemedia::codec::LyraModels& models, std::vector* encoded_features) { auto encoder = LyraEncoder::Create(/*sample_rate_hz=*/sample_rate_hz, /*num_channels=*/num_channels, /*bitrate=*/bitrate, /*enable_dtx=*/enable_dtx, - /*model_path=*/model_path); + /*models=*/models); if (encoder == nullptr) { LOG(ERROR) << "Could not create lyra encoder."; return false; @@ -98,7 +98,7 @@ bool EncodeWav(const std::vector& wav_data, int num_channels, bool EncodeFile(const ghc::filesystem::path& wav_path, const ghc::filesystem::path& output_path, int bitrate, bool enable_preprocessing, bool enable_dtx, - const ghc::filesystem::path& model_path) { + const chromemedia::codec::LyraModels& models) { // Reads the entire wav file into memory. absl::StatusOr read_wav_result = Read16BitWavFileToVector(wav_path.string()); @@ -112,7 +112,7 @@ bool EncodeFile(const ghc::filesystem::path& wav_path, std::vector encoded_features; if (!EncodeWav(read_wav_result->samples, read_wav_result->num_channels, read_wav_result->sample_rate_hz, bitrate, enable_preprocessing, - enable_dtx, model_path, &encoded_features)) { + enable_dtx, models, &encoded_features)) { LOG(ERROR) << "Unable to encode features for file " << wav_path; return false; } diff --git a/lyra/cli_example/encoder_main_lib.h b/lyra/cli_example/encoder_main_lib.h index 69d4852b..94031468 100644 --- a/lyra/cli_example/encoder_main_lib.h +++ b/lyra/cli_example/encoder_main_lib.h @@ -22,23 +22,23 @@ #include "include/ghc/filesystem.hpp" +#include "lyra/lyra_embedded_models.h" + namespace chromemedia { namespace codec { // Encodes a vector of wav_data into encoded_features. -// Uses the quant files located under |model_path|. 
bool EncodeWav(const std::vector& wav_data, int num_channels, int sample_rate_hz, int bitrate, bool enable_preprocessing, - bool enable_dtx, const ghc::filesystem::path& model_path, + bool enable_dtx, const chromemedia::codec::LyraModels& models, std::vector* encoded_features); // Encodes a wav file into an encoded feature file. Encodes num_samples from the // file at |wav_path| and writes the encoded features out to |output_path|. -// Uses the quant files located under |model_path|. bool EncodeFile(const ghc::filesystem::path& wav_path, const ghc::filesystem::path& output_path, int bitrate, bool enable_preprocessing, bool enable_dtx, - const ghc::filesystem::path& model_path); + const chromemedia::codec::LyraModels& models); } // namespace codec } // namespace chromemedia diff --git a/lyra/cli_example/encoder_main_lib_test.cc b/lyra/cli_example/encoder_main_lib_test.cc index d40f3d4e..aa8fb258 100644 --- a/lyra/cli_example/encoder_main_lib_test.cc +++ b/lyra/cli_example/encoder_main_lib_test.cc @@ -38,7 +38,7 @@ class EncoderMainLibTest : public testing::Test { EncoderMainLibTest() : output_dir_(ghc::filesystem::path(testing::TempDir()) / "output"), testdata_dir_(ghc::filesystem::current_path() / kTestdataDir), - model_path_(ghc::filesystem::current_path() / "lyra/model_coeffs") {} + models_(GetEmbeddedLyraModels()) {} void SetUp() override { std::error_code error_code; @@ -54,7 +54,7 @@ class EncoderMainLibTest : public testing::Test { const ghc::filesystem::path output_dir_; const ghc::filesystem::path testdata_dir_; - const ghc::filesystem::path model_path_; + const LyraModels models_; }; TEST_F(EncoderMainLibTest, WavFileNotFound) { @@ -63,7 +63,7 @@ TEST_F(EncoderMainLibTest, WavFileNotFound) { EXPECT_FALSE(EncodeFile(kNonExistentWav, kOutputEncoded, /*bitrate=*/3200, /*enable_preprocessing=*/false, - /*enable_dtx=*/false, model_path_)); + /*enable_dtx=*/false, models_)); std::error_code error_code; 
EXPECT_FALSE(ghc::filesystem::is_regular_file(kOutputEncoded, error_code)); @@ -75,7 +75,7 @@ TEST_F(EncoderMainLibTest, EncodeSingleWavFiles) { const auto kOutputEncoded = (output_dir_ / wav_file).concat(".lyra"); EXPECT_TRUE(EncodeFile(kInputWavepath, kOutputEncoded, /*bitrate=*/3200, /*enable_preprocessing=*/false, - /*enable_dtx=*/false, model_path_)); + /*enable_dtx=*/false, models_)); } } diff --git a/lyra/lyra_benchmark.cc b/lyra/lyra_benchmark.cc index 4f5b19ab..a3a9b58d 100644 --- a/lyra/lyra_benchmark.cc +++ b/lyra/lyra_benchmark.cc @@ -25,12 +25,6 @@ ABSL_FLAG(int, num_cond_vectors, 2000, "stack / network. " "Equivalent to the number of calls to Precompute and Run."); -ABSL_FLAG(std::string, model_path, "lyra/model_coeffs", - "Path to directory containing TFLite files. For mobile this is the " - "absolute path, like " - "'/data/local/tmp/lyra/model_coeffs/'." - " For desktop this is the path relative to the binary."); - ABSL_FLAG(bool, benchmark_feature_extraction, true, "Whether to benchmark the feature extraction."); @@ -45,7 +39,7 @@ int main(int argc, char** argv) { absl::ParseCommandLine(argc, argv); return chromemedia::codec::lyra_benchmark( - absl::GetFlag(FLAGS_num_cond_vectors), absl::GetFlag(FLAGS_model_path), + absl::GetFlag(FLAGS_num_cond_vectors), absl::GetFlag(FLAGS_benchmark_feature_extraction), absl::GetFlag(FLAGS_benchmark_quantizer), absl::GetFlag(FLAGS_benchmark_generative_model)); diff --git a/lyra/lyra_benchmark_lib.cc b/lyra/lyra_benchmark_lib.cc index 5d6fc994..ee3b18d9 100644 --- a/lyra/lyra_benchmark_lib.cc +++ b/lyra/lyra_benchmark_lib.cc @@ -39,13 +39,14 @@ #include "audio/dsp/signal_vector_util.h" #include "glog/logging.h" // IWYU pragma: keep #include "include/ghc/filesystem.hpp" -#include "lyra/architecture_utils.h" #include "lyra/dsp_utils.h" #include "lyra/feature_extractor_interface.h" #include "lyra/generative_model_interface.h" #include "lyra/lyra_components.h" #include "lyra/lyra_config.h" +#include 
"model_coeffs/_models.h" + #ifdef BENCHMARK #include "absl/base/thread_annotations.h" #include "absl/time/clock.h" @@ -197,7 +198,6 @@ void PrintStatsAndWriteCSV(const std::vector& timings, } int lyra_benchmark(const int num_cond_vectors, - const std::string& model_base_path, const bool benchmark_feature_extraction, const bool benchmark_quantizer, const bool benchmark_generative_model) { @@ -207,18 +207,18 @@ int lyra_benchmark(const int num_cond_vectors, } const int num_samples_per_hop = GetNumSamplesPerHop(kInternalSampleRateHz); - const std::string model_path = GetCompleteArchitecturePath(model_base_path); + const auto models = GetEmbeddedLyraModels(); std::unique_ptr feature_extractor = - benchmark_feature_extraction ? CreateFeatureExtractor(model_path) + benchmark_feature_extraction ? CreateFeatureExtractor(models) : nullptr; std::unique_ptr vector_quantizer = - benchmark_quantizer ? CreateQuantizer(model_path) : nullptr; + benchmark_quantizer ? CreateQuantizer(models) : nullptr; std::unique_ptr model = benchmark_generative_model - ? CreateGenerativeModel(kNumFeatures, model_path) + ? 
CreateGenerativeModel(kNumFeatures, models) : nullptr; std::vector feature_extractor_timings; diff --git a/lyra/lyra_benchmark_lib.h b/lyra/lyra_benchmark_lib.h index 0b843344..19f9fe90 100644 --- a/lyra/lyra_benchmark_lib.h +++ b/lyra/lyra_benchmark_lib.h @@ -31,7 +31,7 @@ struct TimingStats { float standard_deviation; }; -int lyra_benchmark(int num_cond_vectors, const std::string& model_base_path, +int lyra_benchmark(int num_cond_vectors, bool benchmark_feature_extraction, bool benchmark_quantizer, bool benchmark_generative_model); diff --git a/lyra/lyra_components.cc b/lyra/lyra_components.cc index ed35b65a..8eb1ea0e 100644 --- a/lyra/lyra_components.cc +++ b/lyra/lyra_components.cc @@ -40,18 +40,18 @@ constexpr int kMaxNumPacketBits = 184; } // namespace std::unique_ptr CreateQuantizer( - const ghc::filesystem::path& model_path) { - return ResidualVectorQuantizer::Create(model_path); + const LyraModels& models) { + return ResidualVectorQuantizer::Create(models); } std::unique_ptr CreateGenerativeModel( - int num_output_features, const ghc::filesystem::path& model_path) { - return LyraGanModel::Create(model_path, num_output_features); + int num_output_features, const LyraModels& models) { + return LyraGanModel::Create(models, num_output_features); } std::unique_ptr CreateFeatureExtractor( - const ghc::filesystem::path& model_path) { - return SoundStreamEncoder::Create(model_path); + const LyraModels& models) { + return SoundStreamEncoder::Create(models); } std::unique_ptr CreatePacket(int num_header_bits, diff --git a/lyra/lyra_components.h b/lyra/lyra_components.h index 54ef2ea4..6eaebf09 100644 --- a/lyra/lyra_components.h +++ b/lyra/lyra_components.h @@ -25,18 +25,19 @@ #include "lyra/generative_model_interface.h" #include "lyra/packet_interface.h" #include "lyra/vector_quantizer_interface.h" +#include "lyra/lyra_embedded_models.h" namespace chromemedia { namespace codec { std::unique_ptr CreateQuantizer( - const ghc::filesystem::path& model_path); + const 
LyraModels& models); std::unique_ptr CreateGenerativeModel( - int num_output_features, const ghc::filesystem::path& model_path); + int num_output_features, const LyraModels& models); std::unique_ptr CreateFeatureExtractor( - const ghc::filesystem::path& model_path); + const LyraModels& models); std::unique_ptr CreatePacket(int num_header_bits, int num_quantized_bits); diff --git a/lyra/lyra_config.h b/lyra/lyra_config.h index 955787a6..f03ef95e 100644 --- a/lyra/lyra_config.h +++ b/lyra/lyra_config.h @@ -33,6 +33,7 @@ #include "glog/logging.h" // IWYU pragma: keep #include "include/ghc/filesystem.hpp" #include "lyra/lyra_config.pb.h" +#include "lyra/lyra_embedded_models.h" namespace chromemedia { namespace codec { @@ -118,7 +119,7 @@ std::vector GetAssets(); inline absl::Status AreParamsSupported( int sample_rate_hz, int num_channels, - const ghc::filesystem::path& model_path) { + const LyraModels& models) { if (!IsSampleRateSupported(sample_rate_hz)) { return absl::InvalidArgumentError(absl::StrFormat( "Sample rate %d Hz is not supported by codec.", sample_rate_hz)); @@ -128,36 +129,9 @@ inline absl::Status AreParamsSupported( "Number of channels %d is not supported by codec. 
It needs to be %d.", num_channels, kNumChannels)); } - for (auto asset : GetAssets()) { - std::error_code error; - const bool exists = - ghc::filesystem::exists(model_path / std::string(asset), error); - if (error) { - return absl::UnknownError( - absl::StrFormat("Error when probing for asset %s in %s: %s", asset, - model_path, error.message())); - } - if (!exists) { - return absl::InvalidArgumentError( - absl::StrFormat("Asset %s does not exist in %s.", asset, model_path)); - } - } - const ghc::filesystem::path lyra_config_proto_path = - model_path / "lyra_config.binarypb"; - std::error_code error; - const bool exists = ghc::filesystem::exists(lyra_config_proto_path, error); - if (error) { - return absl::UnknownError( - absl::StrFormat("Error when probing for asset %s: %s", - lyra_config_proto_path.string(), error.message())); - } third_party::lyra_codec::lyra::LyraConfig lyra_config; - if (exists) { - std::ifstream lyra_config_stream(lyra_config_proto_path.string()); - if (!lyra_config.ParseFromIstream(&lyra_config_stream)) { - return absl::UnknownError(absl::StrFormat( - "Error when parsing %s", lyra_config_proto_path.string())); - } + if (!lyra_config.ParseFromArray(models.lyra_config_proto.buffer, static_cast(models.lyra_config_proto.size))) { + return absl::UnknownError("Error when parsing lyra config proto"); } if (lyra_config.identifier() != kVersionMinor) { return absl::InvalidArgumentError(absl::StrFormat( diff --git a/lyra/lyra_config_test.cc b/lyra/lyra_config_test.cc index 036ba410..2863af29 100644 --- a/lyra/lyra_config_test.cc +++ b/lyra/lyra_config_test.cc @@ -25,35 +25,20 @@ #include "include/ghc/filesystem.hpp" #include "lyra/lyra_config.pb.h" +#include "lyra/model_coeffs/_models.h" + namespace chromemedia { namespace codec { namespace { class LyraConfigTest : public testing::Test { protected: - LyraConfigTest() - : source_model_path_(ghc::filesystem::current_path() / - "lyra/model_coeffs") {} + LyraConfigTest() : 
test_models_(GetEmbeddedLyraModels()) {} void SetUp() override { // Create a uniqe sub-directory so tests do not interfere with each other. const testing::TestInfo* const test_info = testing::UnitTest::GetInstance()->current_test_info(); - test_model_path_ = - ghc::filesystem::path(testing::TempDir()) / test_info->name(); - ghc::filesystem::create_directory(test_model_path_, error_code_); - ASSERT_FALSE(error_code_) << error_code_.message(); - ghc::filesystem::permissions( - test_model_path_, ghc::filesystem::perms::owner_write, - ghc::filesystem::perm_options::add, error_code_); - ASSERT_FALSE(error_code_) << error_code_.message(); - - // Copy model files. - ghc::filesystem::copy(source_model_path_, test_model_path_, - ghc::filesystem::copy_options::overwrite_existing | - ghc::filesystem::copy_options::recursive, - error_code_); - ASSERT_FALSE(error_code_) << error_code_.message(); } void DeleteFile(const ghc::filesystem::path& to_delete) { @@ -65,14 +50,8 @@ class LyraConfigTest : public testing::Test { ASSERT_FALSE(error_code_) << error_code_.message(); } - // Folder containing files from |source_model_path_| with some modifications - // to simulate various run-time conditions, e.g. missing some files, having - // mismatched files. 
- ghc::filesystem::path test_model_path_; + LyraModels test_models_; std::error_code error_code_; - - private: - const ghc::filesystem::path source_model_path_; }; TEST(LyraConfig, TestGetVersionString) { @@ -115,46 +94,20 @@ TEST_F(LyraConfigTest, BadBitrateNotSupported) { TEST_F(LyraConfigTest, GoodParamsSupported) { EXPECT_TRUE( - AreParamsSupported(kInternalSampleRateHz, kNumChannels, test_model_path_) + AreParamsSupported(kInternalSampleRateHz, kNumChannels, test_models_) .ok()); } TEST_F(LyraConfigTest, BadParamsNotSupported) { EXPECT_FALSE( - AreParamsSupported(/*sample_rate_hz=*/137, kNumChannels, test_model_path_) + AreParamsSupported(/*sample_rate_hz=*/137, kNumChannels, test_models_) .ok()); EXPECT_FALSE(AreParamsSupported(kInternalSampleRateHz, /*num_channels=*/-1, - test_model_path_) + test_models_) .ok()); } -TEST_F(LyraConfigTest, MissingAssetNotSupported) { - DeleteFile(test_model_path_ / "soundstream_encoder.tflite"); - EXPECT_FALSE( - AreParamsSupported(kInternalSampleRateHz, kNumChannels, test_model_path_) - .ok()); -} - -TEST_F(LyraConfigTest, MismatchedIdentifierNotSupported) { - // Replace the lyra_condfig.binarypb with one that contains an identifier - // that does not match |kVersionMinor|. 
- const ghc::filesystem::path lyra_config_proto_path = - test_model_path_ / "lyra_config.binarypb"; - DeleteFile(lyra_config_proto_path); - third_party::lyra_codec::lyra::LyraConfig lyra_config_proto; - lyra_config_proto.set_identifier(kVersionMinor + 100); - std::ofstream output_proto(lyra_config_proto_path.string(), - std::ofstream::out | std::ofstream::binary); - ASSERT_TRUE(output_proto.is_open()); - ASSERT_TRUE(lyra_config_proto.SerializeToOstream(&output_proto)); - output_proto.close(); - - EXPECT_FALSE( - AreParamsSupported(kInternalSampleRateHz, kNumChannels, test_model_path_) - .ok()); -} - } // namespace } // namespace codec } // namespace chromemedia diff --git a/lyra/lyra_decoder.cc b/lyra/lyra_decoder.cc index 73961b3b..ff780ace 100644 --- a/lyra/lyra_decoder.cc +++ b/lyra/lyra_decoder.cc @@ -94,9 +94,9 @@ int GetNumSamplesToGenerate(int num_samples_requested, std::unique_ptr LyraDecoder::Create( int sample_rate_hz, int num_channels, - const ghc::filesystem::path& model_path) { + const LyraModels& models) { absl::Status are_params_supported = - AreParamsSupported(sample_rate_hz, num_channels, model_path); + AreParamsSupported(sample_rate_hz, num_channels, models); if (!are_params_supported.ok()) { LOG(ERROR) << are_params_supported; return nullptr; @@ -114,7 +114,7 @@ std::unique_ptr LyraDecoder::Create( return nullptr; } // All internal components operate at |kInternalSampleRateHz|. 
- auto model = CreateGenerativeModel(kNumFeatures, model_path); + auto model = CreateGenerativeModel(kNumFeatures, models); if (model == nullptr) { LOG(ERROR) << "New model could not be instantiated."; return nullptr; @@ -133,7 +133,7 @@ std::unique_ptr LyraDecoder::Create( LOG(ERROR) << "Could not create Noise Estimator."; return nullptr; } - auto vector_quantizer = CreateQuantizer(model_path); + auto vector_quantizer = CreateQuantizer(models); if (vector_quantizer == nullptr) { LOG(ERROR) << "Could not create Vector Quantizer."; return nullptr; diff --git a/lyra/lyra_decoder.h b/lyra/lyra_decoder.h index 72ee1e4c..e415f022 100644 --- a/lyra/lyra_decoder.h +++ b/lyra/lyra_decoder.h @@ -30,6 +30,7 @@ #include "lyra/lyra_decoder_interface.h" #include "lyra/noise_estimator_interface.h" #include "lyra/vector_quantizer_interface.h" +#include "lyra/lyra_embedded_models.h" namespace chromemedia { namespace codec { @@ -46,14 +47,14 @@ class LyraDecoder : public LyraDecoderInterface { /// rates are 8000, 16000, 32000 and 48000. /// @param num_channels Desired number of channels. Currently only 1 is /// supported. - /// @param model_path Path to the model weights. The identifier in the + /// @param models Embedded model buffers. The identifier in the /// lyra_config.binarypb has to coincide with the /// |kVersionMinor| constant in lyra_config.cc. /// @return A unique_ptr to a |LyraDecoder| if all desired params are /// supported. Else it returns a nullptr. static std::unique_ptr Create( int sample_rate_hz, int num_channels, - const ghc::filesystem::path& model_path); + const LyraModels& models); /// Parses a packet and prepares to decode samples from the payload. 
/// diff --git a/lyra/lyra_decoder_test.cc b/lyra/lyra_decoder_test.cc index dba72772..742e3071 100644 --- a/lyra/lyra_decoder_test.cc +++ b/lyra/lyra_decoder_test.cc @@ -44,6 +44,8 @@ #include "lyra/testing/mock_vector_quantizer.h" #include "lyra/vector_quantizer_interface.h" +#include "lyra/model_coeffs/_models.h" + namespace chromemedia { namespace codec { @@ -95,7 +97,7 @@ namespace { using testing::Exactly; using testing::Return; -static constexpr absl::string_view kExportedModelPath = "lyra/model_coeffs"; +static chromemedia::codec::LyraModels kExportedModels = GetEmbeddedLyraModels(); // Duration of pure packet loss concealment. inline int GetConcealmentDurationSamples() { @@ -142,7 +144,7 @@ class LyraDecoderTest quantized_zeros_(num_quantized_bits_, '0'), packet_(CreatePacket(kNumHeaderBits, num_quantized_bits_)), encoded_zeros_(packet_->PackQuantized(quantized_zeros_)), - model_path_(ghc::filesystem::current_path() / kExportedModelPath), + models_(GetEmbeddedLyraModels()), mock_features_(kNumFeatures), mock_noise_features_(kNumMelBins), mock_samples_(internal_num_samples_per_hop_), @@ -359,7 +361,7 @@ class LyraDecoderTest const std::string quantized_zeros_; const std::unique_ptr packet_; const std::vector encoded_zeros_; - const ghc::filesystem::path model_path_; + const LyraModels models_; std::vector mock_features_; std::vector mock_noise_features_; std::optional> mock_samples_; @@ -782,14 +784,14 @@ TEST_P(LyraDecoderTest, ArbitraryNumSamplesFadeFromComfortNoise) { TEST_P(LyraDecoderTest, ValidConfig) { EXPECT_NE( - LyraDecoder::Create(external_sample_rate_hz_, kNumChannels, model_path_), + LyraDecoder::Create(external_sample_rate_hz_, kNumChannels, models_), nullptr); } TEST_P(LyraDecoderTest, InvalidConfig) { for (const auto& invalid_num_channels : {-1, 0, 2}) { EXPECT_EQ(LyraDecoder::Create(external_sample_rate_hz_, - invalid_num_channels, model_path_), + invalid_num_channels, models_), nullptr); } } @@ -802,12 +804,12 @@ INSTANTIATE_TEST_SUITE_P( 
TEST(LyraDecoderCreate, InvalidCreateReturnsNullptr) { for (const auto& invalid_sample_rate : {0, -1, 16001}) { EXPECT_EQ(LyraDecoder::Create(invalid_sample_rate, kNumChannels, - kExportedModelPath), + kExportedModels), nullptr); } for (const auto& valid_sample_rate : kSupportedSampleRates) { EXPECT_EQ( - LyraDecoder::Create(valid_sample_rate, kNumChannels, "/does/not/exist"), + LyraDecoder::Create(valid_sample_rate, kNumChannels, GetInvalidEmbeddedLyraModels()), nullptr); } } diff --git a/lyra/lyra_embedded_models.h b/lyra/lyra_embedded_models.h new file mode 100644 index 00000000..560274e5 --- /dev/null +++ b/lyra/lyra_embedded_models.h @@ -0,0 +1,22 @@ +#ifndef LYRA_EMBEDDED_MODELS_H_ +#define LYRA_EMBEDDED_MODELS_H_ + +namespace chromemedia { +namespace codec { + +struct LyraModel { + const char* buffer; + size_t size; +}; + +struct LyraModels { + LyraModel lyra_config_proto; + LyraModel lyragan; + LyraModel quantizer; + LyraModel soundstream_encoder; +}; + +} // namespace codec +} // namespace chromemedia + +#endif // LYRA_EMBEDDED_MODELS_H_ diff --git a/lyra/lyra_encoder.cc b/lyra/lyra_encoder.cc index 3d774c8b..1c926c29 100644 --- a/lyra/lyra_encoder.cc +++ b/lyra/lyra_encoder.cc @@ -42,9 +42,9 @@ namespace codec { std::unique_ptr LyraEncoder::Create( int sample_rate_hz, int num_channels, int bitrate, bool enable_dtx, - const ghc::filesystem::path& model_path) { + const LyraModels& models) { absl::Status are_params_supported = - AreParamsSupported(sample_rate_hz, num_channels, model_path); + AreParamsSupported(sample_rate_hz, num_channels, models); if (!are_params_supported.ok()) { LOG(ERROR) << are_params_supported; return nullptr; @@ -64,13 +64,13 @@ std::unique_ptr LyraEncoder::Create( } } - auto feature_extractor = CreateFeatureExtractor(model_path); + auto feature_extractor = CreateFeatureExtractor(models); if (feature_extractor == nullptr) { LOG(ERROR) << "Could not create Features Extractor."; return nullptr; } - auto vector_quantizer = 
CreateQuantizer(model_path); + auto vector_quantizer = CreateQuantizer(models); if (vector_quantizer == nullptr) { LOG(ERROR) << "Could not create Vector Quantizer."; return nullptr; diff --git a/lyra/lyra_encoder.h b/lyra/lyra_encoder.h index 5aafd249..57e83e1c 100644 --- a/lyra/lyra_encoder.h +++ b/lyra/lyra_encoder.h @@ -29,6 +29,7 @@ #include "lyra/noise_estimator_interface.h" #include "lyra/resampler_interface.h" #include "lyra/vector_quantizer_interface.h" +#include "lyra/lyra_embedded_models.h" namespace chromemedia { namespace codec { @@ -53,14 +54,12 @@ class LyraEncoder : public LyraEncoderInterface { /// and 9200. /// @param enable_dtx Set to true if discontinuous transmission should be /// enabled. - /// @param model_path Path to the model weights. The identifier in the - /// lyra_config.textproto has to coincide with the - /// kVersionMinor constant in lyra_config.cc. + /// @param models Embedded model buffers (see lyra_embedded_models.h). /// @return A unique_ptr to a LyraEncoder if all desired params are supported. /// Else it returns a nullptr. static std::unique_ptr Create( int sample_rate_hz, int num_channels, int bitrate, bool enable_dtx, - const ghc::filesystem::path& model_path); + const LyraModels& models); /// Encodes the audio samples into a vector wrapped byte array. 
/// diff --git a/lyra/lyra_encoder_test.cc b/lyra/lyra_encoder_test.cc index 384c59e4..2f63cc75 100644 --- a/lyra/lyra_encoder_test.cc +++ b/lyra/lyra_encoder_test.cc @@ -301,54 +301,6 @@ TEST_P(LyraEncoderTest, MultipleEncodeCalls) { } } -TEST_P(LyraEncoderTest, GoodCreationParametersReturnNotNullptr) { - const auto valid_model_path = - ghc::filesystem::current_path() / "lyra/model_coeffs"; - - EXPECT_NE(nullptr, - LyraEncoder::Create(external_sample_rate_hz_, kNumChannels, - GetBitrate(num_quantized_bits_), - /*enable_dtx=*/false, valid_model_path)); - EXPECT_NE(nullptr, - LyraEncoder::Create(external_sample_rate_hz_, kNumChannels, - GetBitrate(num_quantized_bits_), - /*enable_dtx=*/true, valid_model_path)); -} - -TEST_P(LyraEncoderTest, BadCreationParametersReturnNullptr) { - const auto valid_model_path = - ghc::filesystem::current_path() / "lyra/model_coeffs"; - - EXPECT_EQ(nullptr, LyraEncoder::Create( - 0, kNumChannels, GetBitrate(num_quantized_bits_), - /*enable_dtx=*/false, valid_model_path)); - EXPECT_EQ(nullptr, LyraEncoder::Create( - 0, kNumChannels, GetBitrate(num_quantized_bits_), - /*enable_dtx=*/true, valid_model_path)); - EXPECT_EQ(nullptr, - LyraEncoder::Create(external_sample_rate_hz_, -3, - GetBitrate(num_quantized_bits_), - /*enable_dtx=*/false, valid_model_path)); - EXPECT_EQ(nullptr, - LyraEncoder::Create(external_sample_rate_hz_, -3, - GetBitrate(num_quantized_bits_), - /*enable_dtx=*/true, valid_model_path)); - EXPECT_EQ(nullptr, - LyraEncoder::Create(external_sample_rate_hz_, kNumChannels, -2, - /*enable_dtx=*/false, valid_model_path)); - EXPECT_EQ(nullptr, - LyraEncoder::Create(external_sample_rate_hz_, kNumChannels, -2, - /*enable_dtx=*/true, valid_model_path)); - EXPECT_EQ(nullptr, - LyraEncoder::Create(external_sample_rate_hz_, kNumChannels, - GetBitrate(num_quantized_bits_), - /*enable_dtx=*/false, "bad_model_path")); - EXPECT_EQ(nullptr, - LyraEncoder::Create(external_sample_rate_hz_, kNumChannels, - GetBitrate(num_quantized_bits_), - 
/*enable_dtx=*/true, "bad_model_path")); -} - TEST_P(LyraEncoderTest, SetBitrateSucceeds) { LyraEncoderPeer encoder_peer(std::move(mock_resampler_), std::move(mock_feature_extractor_), nullptr, diff --git a/lyra/lyra_gan_model.cc b/lyra/lyra_gan_model.cc index 13d6462a..fedd2cbf 100644 --- a/lyra/lyra_gan_model.cc +++ b/lyra/lyra_gan_model.cc @@ -34,9 +34,9 @@ namespace chromemedia { namespace codec { std::unique_ptr LyraGanModel::Create( - const ghc::filesystem::path& model_path, int num_features) { + const LyraModels& models, int num_features) { auto model = - TfLiteModelWrapper::Create(model_path / "lyragan.tflite", + TfLiteModelWrapper::Create(models.lyragan, /*use_xnn=*/true, /*int8_quantized=*/true); if (model == nullptr) { LOG(ERROR) << "Unable to create LyraGAN TFLite model wrapper."; diff --git a/lyra/lyra_gan_model.h b/lyra/lyra_gan_model.h index 38c92524..5a6d4c9e 100644 --- a/lyra/lyra_gan_model.h +++ b/lyra/lyra_gan_model.h @@ -34,7 +34,7 @@ class LyraGanModel : public GenerativeModel { public: // Returns a nullptr on failure. 
static std::unique_ptr Create( - const ghc::filesystem::path& model_path, int num_features); + const LyraModels& models, int num_features); ~LyraGanModel() override {} diff --git a/lyra/lyra_gan_model_test.cc b/lyra/lyra_gan_model_test.cc index 25ffcdd8..17ffbf1a 100644 --- a/lyra/lyra_gan_model_test.cc +++ b/lyra/lyra_gan_model_test.cc @@ -26,6 +26,8 @@ #include "include/ghc/filesystem.hpp" #include "lyra/lyra_config.h" +#include "model_coeffs/_models.h" + namespace chromemedia { namespace codec { namespace { @@ -34,7 +36,7 @@ class LyraGanModelTest : public testing::Test { protected: LyraGanModelTest() : model_(LyraGanModel::Create( - ghc::filesystem::current_path() / "lyra/model_coeffs", + GetEmbeddedLyraModels(), kNumFeatures)), features_(kNumFeatures) {} @@ -49,7 +51,7 @@ class LyraGanModelTest : public testing::Test { }; TEST_F(LyraGanModelTest, CreationFailsWithInvalidModelPath) { - EXPECT_EQ(LyraGanModel::Create("invalid/model/path", features_.size()), + EXPECT_EQ(LyraGanModel::Create(GetInvalidEmbeddedLyraModels(), features_.size()), nullptr); } diff --git a/lyra/lyra_integration_test.cc b/lyra/lyra_integration_test.cc index bad59c18..c530b595 100644 --- a/lyra/lyra_integration_test.cc +++ b/lyra/lyra_integration_test.cc @@ -34,6 +34,8 @@ #include "lyra/lyra_encoder.h" #include "lyra/wav_utils.h" +#include "model_coeffs/_models.h" + namespace chromemedia { namespace codec { namespace { @@ -48,8 +50,7 @@ class LyraIntegrationTest // This tests that decoded audio has similar features as the original. 
TEST_P(LyraIntegrationTest, DecodedAudioHasSimilarFeatures) { const ghc::filesystem::path wav_dir("lyra/testdata"); - const auto model_path = - ghc::filesystem::current_path() / std::string("lyra/model_coeffs"); + const LyraModels models = GetEmbeddedLyraModels(); const auto input_path = ghc::filesystem::current_path() / wav_dir / std::string(std::get<0>(GetParam())); @@ -62,11 +63,11 @@ TEST_P(LyraIntegrationTest, DecodedAudioHasSimilarFeatures) { std::unique_ptr encoder = LyraEncoder::Create(sample_rate_hz, input_wav_result->num_channels, GetBitrate(num_quantized_bits), - /*enable_dtx=*/false, model_path); + /*enable_dtx=*/false, models); ASSERT_NE(encoder, nullptr); std::unique_ptr decoder = LyraDecoder::Create( - sample_rate_hz, input_wav_result->num_channels, model_path); + sample_rate_hz, input_wav_result->num_channels, models); ASSERT_NE(decoder, nullptr); // Keep only 3 seconds to shorten the test duration. diff --git a/lyra/residual_vector_quantizer.cc b/lyra/residual_vector_quantizer.cc index 94041db7..b24a46bb 100644 --- a/lyra/residual_vector_quantizer.cc +++ b/lyra/residual_vector_quantizer.cc @@ -34,9 +34,9 @@ namespace chromemedia { namespace codec { std::unique_ptr ResidualVectorQuantizer::Create( - const ghc::filesystem::path& model_path) { + const LyraModels& models) { auto quantizer_model = - TfLiteModelWrapper::Create(model_path / "quantizer.tflite", + TfLiteModelWrapper::Create(models.quantizer, /*use_xnn=*/false, /*int8_quantized=*/false); if (quantizer_model == nullptr) { LOG(ERROR) << "Unable to create the quantizer TfLite model wrapper."; diff --git a/lyra/residual_vector_quantizer.h b/lyra/residual_vector_quantizer.h index 66fe6886..ea6db5c1 100644 --- a/lyra/residual_vector_quantizer.h +++ b/lyra/residual_vector_quantizer.h @@ -35,7 +35,7 @@ class ResidualVectorQuantizer : public VectorQuantizerInterface { public: // Returns nullptr if the TFLite model can't be built or allocated. 
static std::unique_ptr Create( - const ghc::filesystem::path& model_path); + const LyraModels& models); // Quantizes the features using vector quantization. std::optional Quantize(const std::vector& features, diff --git a/lyra/residual_vector_quantizer_test.cc b/lyra/residual_vector_quantizer_test.cc index d420ce94..15729860 100644 --- a/lyra/residual_vector_quantizer_test.cc +++ b/lyra/residual_vector_quantizer_test.cc @@ -28,6 +28,8 @@ #include "lyra/log_mel_spectrogram_extractor_impl.h" #include "lyra/lyra_config.h" +#include "model_coeffs/_models.h" + namespace chromemedia { namespace codec { namespace { @@ -36,8 +38,7 @@ class ResidualVectorQuantizerTest : public testing::TestWithParam { protected: ResidualVectorQuantizerTest() : num_quantized_bits_(GetParam()), - quantizer_(ResidualVectorQuantizer::Create( - ghc::filesystem::current_path() / "lyra/model_coeffs")), + quantizer_(ResidualVectorQuantizer::Create(GetEmbeddedLyraModels())), // These features correspond to silence run through the // SoundStreamEncoder using the SoundStreamEncoderTest. 
features_{ @@ -70,7 +71,7 @@ class ResidualVectorQuantizerTest : public testing::TestWithParam { }; TEST_P(ResidualVectorQuantizerTest, CreationFailsWithInvalidModelPath) { - EXPECT_EQ(ResidualVectorQuantizer::Create("invalid/model/path"), nullptr); + EXPECT_EQ(ResidualVectorQuantizer::Create(GetInvalidEmbeddedLyraModels()), nullptr); } TEST_P(ResidualVectorQuantizerTest, CreationSucceedsWithValidModelPath) { diff --git a/lyra/soundstream_encoder.cc b/lyra/soundstream_encoder.cc index 5c5c32c3..abe32561 100644 --- a/lyra/soundstream_encoder.cc +++ b/lyra/soundstream_encoder.cc @@ -34,9 +34,9 @@ namespace chromemedia { namespace codec { std::unique_ptr SoundStreamEncoder::Create( - const ghc::filesystem::path& model_path) { + const LyraModels& models) { auto model = - TfLiteModelWrapper::Create(model_path / "soundstream_encoder.tflite", + TfLiteModelWrapper::Create(models.soundstream_encoder, /*use_xnn=*/true, /*int8_quantized=*/true); if (model == nullptr) { LOG(ERROR) << "Unable to create SoundStream encoder TFLite model wrapper."; diff --git a/lyra/soundstream_encoder.h b/lyra/soundstream_encoder.h index f44d6e93..5db77684 100644 --- a/lyra/soundstream_encoder.h +++ b/lyra/soundstream_encoder.h @@ -36,7 +36,7 @@ class SoundStreamEncoder : public FeatureExtractorInterface { public: // Returns a nullptr on failure. 
 static std::unique_ptr<SoundStreamEncoder> Create( - const ghc::filesystem::path& model_path); + const LyraModels& models); ~SoundStreamEncoder() override {} diff --git a/lyra/soundstream_encoder_test.cc b/lyra/soundstream_encoder_test.cc index 40f90e4d..c7b687b5 100644 --- a/lyra/soundstream_encoder_test.cc +++ b/lyra/soundstream_encoder_test.cc @@ -27,6 +27,8 @@ #include "include/ghc/filesystem.hpp" #include "lyra/lyra_config.h" +#include "model_coeffs/_models.h" + namespace chromemedia { namespace codec { namespace { @@ -34,14 +36,13 @@ namespace { class SoundStreamEncoderTest : public testing::Test { protected: SoundStreamEncoderTest() - : encoder_(SoundStreamEncoder::Create(ghc::filesystem::current_path() / - "lyra/model_coeffs")) {} + : encoder_(SoundStreamEncoder::Create(GetEmbeddedLyraModels())) {} std::unique_ptr<SoundStreamEncoder> encoder_; }; TEST_F(SoundStreamEncoderTest, CreationFailsWithInvalidModelPath) { - EXPECT_EQ(SoundStreamEncoder::Create("invalid/model/path"), nullptr); + EXPECT_EQ(SoundStreamEncoder::Create(GetInvalidEmbeddedLyraModels()), nullptr); } TEST_F(SoundStreamEncoderTest, CreationSucceedsWithValidModelPath) { diff --git a/lyra/testing/mock_vector_quantizer.h b/lyra/testing/mock_vector_quantizer.h index 202c7479..03c1cc65 100644 --- a/lyra/testing/mock_vector_quantizer.h +++ b/lyra/testing/mock_vector_quantizer.h @@ -31,9 +31,8 @@ class MockVectorQuantizer : public VectorQuantizerInterface { public: ~MockVectorQuantizer() override {} - MOCK_METHOD(std::optional<std::string>, Quantize, - (const std::vector<float>& features, int num_bits), - (const, override)); + MOCK_CONST_METHOD2_T(Quantize, + std::optional<std::string>(const std::vector<float>& features, int num_bits)); MOCK_METHOD(std::optional<std::vector<float>>, DecodeToLossyFeatures, (const std::string& quantized_features), (const, override)); diff --git a/lyra/tflite_model_wrapper.cc b/lyra/tflite_model_wrapper.cc index fc3b6b2f..2a23b261 100644 --- a/lyra/tflite_model_wrapper.cc +++ b/lyra/tflite_model_wrapper.cc @@ -34,12 +34,11 @@ namespace chromemedia { namespace codec 
{ std::unique_ptr TfLiteModelWrapper::Create( - const ghc::filesystem::path& model_file, bool use_xnn, + const LyraModel& model_file, bool use_xnn, bool int8_quantized) { - auto model = tflite::FlatBufferModel::BuildFromFile(model_file.c_str()); + auto model = tflite::FlatBufferModel::BuildFromBuffer(model_file.buffer, model_file.size); if (model == nullptr) { - LOG(ERROR) << "Could not build TFLite FlatBufferModel for file: " - << model_file; + LOG(ERROR) << "Could not build TFLite FlatBufferModel"; return nullptr; } @@ -55,7 +54,7 @@ std::unique_ptr TfLiteModelWrapper::Create( std::unique_ptr interpreter; if (builder(&interpreter) != kTfLiteOk) { - LOG(ERROR) << "Could not build TFLite Interpreter for file: " << model_file; + LOG(ERROR) << "Could not build TFLite Interpreter"; return nullptr; } @@ -85,8 +84,7 @@ std::unique_ptr TfLiteModelWrapper::Create( // End of XNNPack delegate creation. if (interpreter->AllocateTensors() != kTfLiteOk) { - LOG(ERROR) << "Could not allocate quantize TFLite tensors for file: " - << model_file; + LOG(ERROR) << "Could not allocate quantize TFLite tensors"; return nullptr; } diff --git a/lyra/tflite_model_wrapper.h b/lyra/tflite_model_wrapper.h index 87aba3b6..1b18664c 100644 --- a/lyra/tflite_model_wrapper.h +++ b/lyra/tflite_model_wrapper.h @@ -25,6 +25,7 @@ #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/signature_runner.h" +#include "lyra_embedded_models.h" namespace chromemedia { namespace codec { @@ -32,7 +33,7 @@ namespace codec { class TfLiteModelWrapper { public: static std::unique_ptr Create( - const ghc::filesystem::path& model_file, bool use_xnn, + const LyraModel& model_file, bool use_xnn, bool int8_quantized); bool Invoke(); diff --git a/lyra/tflite_model_wrapper_test.cc b/lyra/tflite_model_wrapper_test.cc index dd4219bd..2ca02f40 100644 --- a/lyra/tflite_model_wrapper_test.cc +++ b/lyra/tflite_model_wrapper_test.cc @@ -25,12 +25,15 @@ #include 
"gtest/gtest.h" #include "include/ghc/filesystem.hpp" +#include "model_coeffs/_models.h" + namespace chromemedia { namespace codec { namespace { TEST(TfLiteModelWrapperTest, CreateFailsWithInvalidModelFile) { - EXPECT_EQ(TfLiteModelWrapper::Create("invalid/model/path", true, false), + const LyraModel invalid = {nullptr, 0}; + EXPECT_EQ(TfLiteModelWrapper::Create(invalid, true, false), nullptr); } @@ -38,8 +41,9 @@ class TfLiteModelWrapperTest : public testing::TestWithParam {}; TEST_P(TfLiteModelWrapperTest, CreateSucceedsAndMethodsRun) { const bool int8_quantized = GetParam(); + const auto lyragan = GetEmbeddedLyraModels().lyragan; auto model_wrapper = TfLiteModelWrapper::Create( - ghc::filesystem::current_path() / "lyra/model_coeffs/lyragan.tflite", + lyragan, true, int8_quantized); ASSERT_NE(model_wrapper, nullptr); absl::Span input = model_wrapper->get_input_tensor(0); diff --git a/models_to_header.py b/models_to_header.py new file mode 100644 index 00000000..2acc8379 --- /dev/null +++ b/models_to_header.py @@ -0,0 +1,51 @@ +# This program will convert the Lyra model files residing at lyra/model_coeffs +# into C++ header files, so that the models can be bundled within the binary itself +# instead of loading them from external .tflite files. 
+ +import argparse +import sys + +def bin2header(data, var_name='var'): + out = [] + out.append('const unsigned char {var_name}[] = {{'.format(var_name=var_name)) + l = [ data[i:i+12] for i in range(0, len(data), 12) ] + for i, x in enumerate(l): + line = ', '.join([ '0x{val:02x}'.format(val=c) for c in x ]) + out.append(' {line}{end_comma}'.format(line=line, end_comma=',' if i(lyra_config_proto), lyra_config_proto_len }, + { reinterpret_cast(lyragan), lyragan_len }, + { reinterpret_cast(quantizer), quantizer_len }, + { reinterpret_cast(soundstream_encoder), soundstream_encoder_len }, + }; +} +inline chromemedia::codec::LyraModels GetInvalidEmbeddedLyraModels() { + return {{nullptr, 0}, {nullptr, 0}, {nullptr, 0}, {nullptr, 0}}; +} +#endif +""" + +with open("lyra/model_coeffs/_models.h", 'w') as f: + f.write(out)
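The conversion that `models_to_header.py` performs (turning a binary model file into a C array plus a `_len` constant, twelve bytes per row) can be sketched in isolation as follows. This is a minimal illustration, not the script from the diff: the function name `bytes_to_c_array`, the `per_line` parameter, the `unsigned int` type of the length constant, and the rule that every row except the last ends with a comma are all assumptions made for this example.

```python
def bytes_to_c_array(data: bytes, var_name: str, per_line: int = 12) -> str:
    """Render `data` as a C array definition plus a matching length constant.

    Hypothetical helper for illustration; names and layout are assumptions.
    """
    # Split the payload into rows of `per_line` bytes each.
    chunks = [data[i:i + per_line] for i in range(0, len(data), per_line)]
    lines = [f"const unsigned char {var_name}[] = {{"]
    for index, chunk in enumerate(chunks):
        body = ", ".join(f"0x{byte:02x}" for byte in chunk)
        # Assumed rule: every row except the last ends with a comma.
        comma = "," if index < len(chunks) - 1 else ""
        lines.append(f"  {body}{comma}")
    lines.append("};")
    # The `<name>_len` suffix mirrors identifiers such as `lyragan_len` in the diff.
    lines.append(f"const unsigned int {var_name}_len = {len(data)};")
    return "\n".join(lines)


if __name__ == "__main__":
    # 14 bytes produce one full row of 12 and one short row of 2.
    print(bytes_to_c_array(bytes(range(14)), "demo_model"))
```

Embedding the bytes this way is what lets `TfLiteModelWrapper::Create` call `BuildFromBuffer` on a pointer and size instead of opening a `.tflite` file at run time.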